import hashlib
import logging
import time
from functools import wraps
from typing import Optional, Tuple

import redis
from flask import jsonify, make_response, request

logger = logging.getLogger(__name__)


class RateLimiter:
    """Rate limiting middleware using Redis.

    Each client is counted in two fixed windows: the current clock minute and
    the current clock hour. If Redis is unreachable -- either when the limiter
    is constructed or while a request is being checked -- rate limiting is
    skipped (fail open) instead of failing the request.
    """

    def __init__(self, redis_url: str, trust_forwarded_for: bool = True):
        """Connect to Redis.

        Args:
            redis_url: URL passed to ``redis.from_url``.
            trust_forwarded_for: When True (the historical behaviour), the
                first address in ``X-Forwarded-For`` is used as the client IP
                for unauthenticated requests. That is only safe behind a proxy
                that overwrites the header; otherwise a client can send a new
                value on every request and evade the limit. Pass False to
                always use ``request.remote_addr``.
        """
        self.redis_client = None
        self.trust_forwarded_for = trust_forwarded_for
        try:
            client = redis.from_url(redis_url, decode_responses=True)
            client.ping()
            # Only publish the client once it is known to be reachable, so the
            # "Redis not available" skip path in limit() actually triggers.
            self.redis_client = client
            logger.info("Connected to Redis for rate limiting")
        except Exception as e:
            logger.warning("Redis not available for rate limiting: %s", e)

    def limit(self, requests_per_minute: int = 60, requests_per_hour: int = 1000):
        """Return a decorator that rate limits a Flask view.

        Args:
            requests_per_minute: Maximum requests per client per clock minute.
            requests_per_hour: Maximum requests per client per clock hour.

        Rejected requests receive a 429 JSON body with a ``Retry-After``
        header. Allowed responses get ``X-RateLimit-Limit`` and
        ``X-RateLimit-Remaining`` headers (minute window).
        """
        def decorator(f):
            @wraps(f)
            def decorated_function(*args, **kwargs):
                if not self.redis_client:
                    # Redis not available, skip rate limiting
                    return f(*args, **kwargs)

                # Get client identifier (API key, token, or IP)
                client_id = self._get_client_id()

                # Check rate limits
                is_allowed, retry_after = self._check_rate_limit(
                    client_id, requests_per_minute, requests_per_hour
                )

                if not is_allowed:
                    response = jsonify({
                        "error": "Rate limit exceeded",
                        "retry_after": retry_after,
                    })
                    response.status_code = 429
                    response.headers['Retry-After'] = str(retry_after)
                    response.headers['X-RateLimit-Limit'] = str(requests_per_minute)
                    return response

                # make_response normalises tuples/strings returned by the view
                # into a Response, so the headers are never silently skipped.
                response = make_response(f(*args, **kwargs))
                response.headers['X-RateLimit-Limit'] = str(requests_per_minute)
                response.headers['X-RateLimit-Remaining'] = str(
                    self._get_remaining_requests(client_id, requests_per_minute)
                )
                return response
            return decorated_function
        return decorator

    @staticmethod
    def _hash_credential(value: str) -> str:
        """Return a SHA-256 hex digest of a credential for use in Redis keys."""
        return hashlib.sha256(value.encode("utf-8")).hexdigest()

    def _get_client_id(self) -> str:
        """Build the per-client bucket identifier from the current request.

        Preference order: ``X-API-Key`` header, ``Authorization: Bearer``
        token, then client IP. Credentials are hashed so raw secrets never
        appear in Redis key names.
        """
        api_key = request.headers.get('X-API-Key')
        if api_key:
            return f"api_key:{self._hash_credential(api_key)}"

        auth_header = request.headers.get('Authorization')
        if auth_header and auth_header.startswith('Bearer '):
            # Hash the whole token: a fixed-length prefix is not unique (e.g.
            # JWTs share the same encoded header), which merged distinct
            # users into a single bucket.
            return f"token:{self._hash_credential(auth_header[7:])}"

        # Fallback to IP
        ip = None
        if self.trust_forwarded_for:
            forwarded = request.headers.get('X-Forwarded-For')
            if forwarded:
                ip = forwarded.split(',')[0].strip() or None
        if not ip:
            ip = request.remote_addr
        return f"ip:{ip}"

    def _increment_window(self, key: str, ttl: int) -> int:
        """Atomically increment ``key``, (re)set its TTL, and return the count."""
        # redis-py pipelines are transactional (MULTI/EXEC) by default, so a
        # crash can no longer leave a counter without an expiry.
        pipe = self.redis_client.pipeline()
        pipe.incr(key)
        pipe.expire(key, ttl)
        count, _ = pipe.execute()
        return int(count)

    def _check_rate_limit(self, client_id: str, requests_per_minute: int,
                          requests_per_hour: int) -> Tuple[bool, Optional[int]]:
        """Count this request and report whether it is within the limits.

        Returns:
            ``(True, None)`` when allowed, otherwise ``(False, seconds)`` where
            ``seconds`` is the time until the exceeded window rolls over.
            Redis errors are logged and the request is allowed (fail open).
        """
        now = int(time.time())
        try:
            # Check minute limit
            minute_key = f"rate_limit:minute:{client_id}:{now // 60}"
            if self._increment_window(minute_key, 60) > requests_per_minute:
                return False, 60 - (now % 60)

            # Check hour limit
            hour_key = f"rate_limit:hour:{client_id}:{now // 3600}"
            if self._increment_window(hour_key, 3600) > requests_per_hour:
                return False, 3600 - (now % 3600)
        except redis.RedisError as e:
            logger.warning("Rate limit check failed, allowing request: %s", e)
        return True, None

    def _get_remaining_requests(self, client_id: str, limit: int) -> int:
        """Return how many requests remain in the current minute window.

        Falls back to ``limit`` if the counter cannot be read.
        """
        now = int(time.time())
        minute_key = f"rate_limit:minute:{client_id}:{now // 60}"
        try:
            current_count = int(self.redis_client.get(minute_key) or 0)
        except (redis.RedisError, TypeError, ValueError):
            return limit
        return max(0, limit - current_count)


class APIKeyRateLimiter(RateLimiter):
    """Extended rate limiter with per-API-key limits loaded from the database."""

    def __init__(self, redis_url: str, db_repo, trust_forwarded_for: bool = True):
        """Args:
            redis_url: URL passed to ``redis.from_url``.
            db_repo: Object exposing ``execute_one(query, params)``; assumed to
                return a dict-like row or a falsy value -- confirm with caller.
            trust_forwarded_for: See ``RateLimiter.__init__``.
        """
        super().__init__(redis_url, trust_forwarded_for=trust_forwarded_for)
        self.db_repo = db_repo

    def limit_by_api_key(self):
        """Return a decorator that applies the caller's API-key-specific limits.

        Requests without ``X-API-Key`` use the default limits. Unknown or
        inactive keys get a 401. NULL limit columns fall back to 60/min and
        1000/hour.
        """
        def decorator(f):
            # Built once per decorated view rather than on every request.
            default_limited = self.limit()(f)

            @wraps(f)
            def decorated_function(*args, **kwargs):
                api_key = request.headers.get('X-API-Key')
                if not api_key:
                    # Use default limits for non-API key requests
                    return default_limited(*args, **kwargs)

                # Get API key configuration from database
                query = """
                    SELECT rate_limit_per_minute, rate_limit_per_hour
                    FROM api_clients
                    WHERE api_key = %s AND is_active = true
                """
                client = self.db_repo.execute_one(query, (api_key,))
                if not client:
                    return jsonify({"error": "Invalid API key"}), 401

                # A NULL column comes back as None (the key exists), so
                # dict.get(key, default) would not apply the default.
                rpm = client.get('rate_limit_per_minute')
                rph = client.get('rate_limit_per_hour')
                if rpm is None:
                    rpm = 60
                if rph is None:
                    rph = 1000
                return self.limit(rpm, rph)(f)(*args, **kwargs)
            return decorated_function
        return decorator