"""Error metrics, alerting and performance monitoring helpers.

The module exposes:

* ``ErrorMetrics`` -- Prometheus collectors plus a short in-memory history of
  error timestamps used to derive per-minute error counts.
* ``AlertManager`` -- threshold checks over those per-minute counts.
* ``init_monitoring`` -- registers request timing hooks and the ``/metrics``
  and ``/api/alerts`` routes on a Flask application.
* ``monitor_performance`` -- decorator that logs slow function calls.
* ``track_error`` -- records an application error and logs server errors.
"""

import functools
import logging
import time
from collections import defaultdict, deque
from datetime import datetime, timedelta
from threading import Lock
from typing import Any, Dict, List, Optional

from flask import Response, g, has_app_context, has_request_context, request
from prometheus_client import (
    CONTENT_TYPE_LATEST,
    Counter,
    Gauge,
    Histogram,
    generate_latest,
)

from .exceptions import BaseApplicationException
from .logging_config import log_security_event

logger = logging.getLogger(__name__)

# Sliding window over which "per minute" error counts are computed.
_RATE_WINDOW = timedelta(minutes=1)
# Alerts older than this are dropped from the in-memory alert list.
_ALERT_RETENTION = timedelta(hours=24)
# Upper bound on remembered timestamps per error code. Because only this many
# timestamps are kept, the reported per-minute count saturates at this value.
_MAX_TRACKED_EVENTS = 60
# Calls slower than this many seconds are logged by ``monitor_performance``.
_SLOW_CALL_SECONDS = 1.0

_VALIDATION_ERROR_CODE = 'VALIDATION_ERROR'
_AUTH_ERROR_CODE = 'AUTHENTICATION_ERROR'
_DATABASE_ERROR_CODE = 'DATABASE_ERROR'


def _count_since(timestamps, cutoff: datetime) -> int:
    """Return how many of *timestamps* are at or after *cutoff*."""
    return sum(1 for ts in timestamps if ts >= cutoff)


class ErrorMetrics:
    """Prometheus collectors and a rolling per-error-code timestamp history."""

    def __init__(self):
        self.error_counter = Counter(
            'app_errors_total',
            'Total number of errors',
            ['error_code', 'status_code', 'endpoint'],
        )
        self.error_rate = Gauge(
            'app_error_rate',
            'Error rate per minute',
            ['error_code'],
        )
        self.request_duration = Histogram(
            'app_request_duration_seconds',
            'Request duration in seconds',
            ['method', 'endpoint', 'status_code'],
        )
        self.validation_errors = Counter(
            'app_validation_errors_total',
            'Total validation errors',
            ['field', 'endpoint'],
        )
        self.auth_failures = Counter(
            'app_auth_failures_total',
            'Total authentication failures',
            ['reason', 'endpoint'],
        )
        self.db_errors = Counter(
            'app_database_errors_total',
            'Total database errors',
            ['error_type', 'operation'],
        )
        # error code -> bounded deque of naive-UTC timestamps of occurrences.
        self._error_history: Dict[str, deque] = defaultdict(
            lambda: deque(maxlen=_MAX_TRACKED_EVENTS)
        )
        # Guards ``_error_history`` and the gauge refresh derived from it.
        self._lock = Lock()

    def record_error(self, error: BaseApplicationException,
                     endpoint: Optional[str] = None):
        """Record one occurrence of *error* in all relevant collectors.

        Args:
            error: The application error; its ``code``, ``status_code`` and
                ``details`` attributes are read.
            endpoint: Endpoint label to use. When omitted, the current Flask
                request's endpoint is used if a request is active, otherwise
                ``'unknown'``.
        """
        if not endpoint:
            # ``request`` is only usable inside a request context; errors may
            # also be recorded outside of one, where touching it would raise.
            current = request.endpoint if has_request_context() else None
            endpoint = current or 'unknown'

        # Treat a missing or ``None`` details payload as empty.
        details = getattr(error, 'details', None) or {}

        self.error_counter.labels(
            error_code=error.code,
            status_code=error.status_code,
            endpoint=endpoint,
        ).inc()

        with self._lock:
            self._error_history[error.code].append(datetime.utcnow())
            self._update_error_rates()

        if error.code == _VALIDATION_ERROR_CODE and 'field' in details:
            self.validation_errors.labels(
                field=details['field'],
                endpoint=endpoint,
            ).inc()
        elif error.code == _AUTH_ERROR_CODE:
            self.auth_failures.labels(
                reason=error.__class__.__name__,
                endpoint=endpoint,
            ).inc()
        elif error.code == _DATABASE_ERROR_CODE:
            self.db_errors.labels(
                error_type=error.__class__.__name__,
                operation=details.get('operation', 'unknown'),
            ).inc()

    def recent_error_counts(self) -> Dict[str, int]:
        """Return the number of errors per code within the last minute."""
        with self._lock:
            cutoff = datetime.utcnow() - _RATE_WINDOW
            return {
                code: _count_since(stamps, cutoff)
                for code, stamps in self._error_history.items()
            }

    def refresh_error_rates(self):
        """Recompute the ``app_error_rate`` gauge from the current history.

        Without this the gauge is only updated when an error is recorded, so
        it would keep its last value after errors stop occurring.
        """
        with self._lock:
            self._update_error_rates()

    def _update_error_rates(self):
        """Set the rate gauge per error code. Caller must hold ``_lock``."""
        cutoff = datetime.utcnow() - _RATE_WINDOW
        for error_code, stamps in self._error_history.items():
            self.error_rate.labels(error_code=error_code).set(
                _count_since(stamps, cutoff)
            )


class AlertManager:
    """Evaluates alert thresholds against ``ErrorMetrics`` and keeps alerts."""

    def __init__(self):
        self.alerts = []
        # NOTE: only 'error_rate' and 'auth_failure_rate' are evaluated by
        # ``check_alerts``; the other two thresholds are not checked here.
        self.alert_thresholds = {
            'error_rate': 10,
            'auth_failure_rate': 5,
            'db_error_rate': 3,
            'response_time_95th': 2.0,
        }
        self._lock = Lock()

    def check_alerts(self, metrics: ErrorMetrics):
        """Evaluate thresholds and return the alerts raised by this call.

        Newly raised alerts are also appended to the stored alert list, and
        stored alerts older than 24 hours are discarded.
        """
        now = datetime.utcnow()
        error_threshold = self.alert_thresholds['error_rate']
        auth_threshold = self.alert_thresholds['auth_failure_rate']

        new_alerts = [
            {
                'type': 'high_error_rate',
                'severity': 'critical',
                'error_code': error_code,
                'rate': rate,
                'threshold': error_threshold,
                'message': f'High error rate for {error_code}: {rate}/min',
                'timestamp': now,
            }
            for error_code, rate in self._get_current_error_rates(metrics).items()
            if rate > error_threshold
        ]

        auth_failure_rate = self._get_auth_failure_rate(metrics)
        if auth_failure_rate > auth_threshold:
            new_alerts.append({
                'type': 'auth_failures',
                'severity': 'warning',
                'rate': auth_failure_rate,
                'threshold': auth_threshold,
                'message': f'High authentication failure rate: {auth_failure_rate}/min',
                'timestamp': now,
            })
            log_security_event(
                'HIGH_AUTH_FAILURE_RATE',
                'Authentication failure rate exceeded threshold',
                rate=auth_failure_rate,
                threshold=auth_threshold,
            )

        cutoff = now - _ALERT_RETENTION
        with self._lock:
            self.alerts.extend(new_alerts)
            self.alerts = [a for a in self.alerts if a['timestamp'] > cutoff]

        return new_alerts

    def _get_current_error_rates(self, metrics: ErrorMetrics) -> Dict[str, float]:
        """Return errors-per-minute keyed by error code."""
        return dict(metrics.recent_error_counts())

    def _get_auth_failure_rate(self, metrics: ErrorMetrics) -> float:
        """Return authentication failures recorded within the last minute.

        The value comes from the timestamp history kept for the
        authentication error code, so it is an actual per-minute figure and
        does not depend on prometheus_client internals.
        """
        return float(metrics.recent_error_counts().get(_AUTH_ERROR_CODE, 0))

    def get_active_alerts(self) -> List[Dict[str, Any]]:
        """Return a snapshot copy of the stored alerts."""
        with self._lock:
            return list(self.alerts)


error_metrics = ErrorMetrics()
alert_manager = AlertManager()


def init_monitoring(app):
    """Register timing hooks and monitoring routes on the Flask *app*."""

    @app.before_request
    def before_request():
        g.start_time = time.time()

    @app.after_request
    def after_request(response):
        # ``start_time`` is absent if ``before_request`` did not run.
        if hasattr(g, 'start_time'):
            elapsed = time.time() - g.start_time
            error_metrics.request_duration.labels(
                method=request.method,
                endpoint=request.endpoint or 'unknown',
                status_code=response.status_code,
            ).observe(elapsed)
        return response

    @app.route('/metrics')
    def metrics():
        # Refresh gauges first so the scrape reflects the current window
        # rather than the state at the time of the last recorded error.
        error_metrics.refresh_error_rates()
        alert_manager.check_alerts(error_metrics)
        # ``content_type`` (not ``mimetype``) is used because the constant
        # already carries its own parameters, including the charset.
        return Response(generate_latest(), content_type=CONTENT_TYPE_LATEST)

    @app.route('/api/alerts')
    def get_alerts():
        alerts = alert_manager.get_active_alerts()
        return {
            'alerts': alerts,
            'total': len(alerts),
            'critical': sum(1 for a in alerts if a['severity'] == 'critical'),
            'warning': sum(1 for a in alerts if a['severity'] == 'warning'),
        }


def monitor_performance(func):
    """Decorator that logs a warning when *func* takes longer than 1 second.

    The wrapped function's return value or exception is passed through
    unchanged.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        started = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            duration = time.perf_counter() - started
            if duration > _SLOW_CALL_SECONDS:
                # ``g`` raises outside an application context; an exception
                # here would replace the wrapped call's own outcome.
                if has_app_context():
                    request_id = getattr(g, 'request_id', 'unknown')
                else:
                    request_id = 'unknown'
                logger.warning(
                    "Slow function execution: %s",
                    func.__name__,
                    extra={
                        'function': func.__name__,
                        'duration': duration,
                        'request_id': request_id,
                    },
                )

    return wrapper


def track_error(error: BaseApplicationException):
    """Record *error* in the metrics and log it when it is a server error."""
    error_metrics.record_error(error)

    if error.status_code >= 500:
        logger.error(
            "Critical error occurred: %s",
            error.code,
            extra={
                'error_code': error.code,
                # Not named 'message': that key is reserved on LogRecord and
                # passing it via ``extra`` makes the logging call raise.
                'error_message': error.message,
                'details': error.details,
                'request_id': error.request_id,
            },
        )