158 Zeilen
5.8 KiB
Python
158 Zeilen
5.8 KiB
Python
import hashlib
import logging
import time
from functools import wraps
from typing import Optional, Tuple

import redis
from flask import jsonify, make_response, request
|
|
|
|
# Module-level logger; the rate limiter reports Redis connectivity through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
class RateLimiter:
    """Rate limiting middleware using Redis.

    Counts requests per client in fixed one-minute and one-hour windows.
    If Redis is unreachable -- at construction time or while checking a
    request -- the limiter fails open and lets the request through.
    """

    def __init__(self, redis_url: str):
        """Connect to Redis at *redis_url*; disable rate limiting on failure.

        ``self.redis_client`` is only set once ``ping()`` succeeds, so a dead
        server leaves it as ``None`` and the decorators skip limiting.
        """
        self.redis_client = None
        try:
            # from_url() is lazy; ping() is what actually touches the server.
            client = redis.from_url(redis_url, decode_responses=True)
            client.ping()
        except Exception as e:
            logger.warning("Redis not available for rate limiting: %s", e)
        else:
            self.redis_client = client
            logger.info("Connected to Redis for rate limiting")

    def limit(self, requests_per_minute: int = 60, requests_per_hour: int = 1000):
        """Decorator enforcing per-client minute and hour request limits.

        Args:
            requests_per_minute: Max requests per client per minute window.
            requests_per_hour: Max requests per client per hour window.

        Over-limit requests receive a 429 JSON body with ``Retry-After``.
        Allowed responses get ``X-RateLimit-Limit`` and
        ``X-RateLimit-Remaining`` headers for the minute window.
        """
        def decorator(f):
            @wraps(f)
            def decorated_function(*args, **kwargs):
                if not self.redis_client:
                    # Redis not available, skip rate limiting
                    return f(*args, **kwargs)

                # Get client identifier (API key, bearer token or IP)
                client_id = self._get_client_id()

                try:
                    is_allowed, retry_after = self._check_rate_limit(
                        client_id,
                        requests_per_minute,
                        requests_per_hour,
                    )
                except redis.RedisError as e:
                    # Fail open, matching the policy when Redis is down at startup.
                    logger.warning("Rate limit check failed, allowing request: %s", e)
                    return f(*args, **kwargs)

                if not is_allowed:
                    response = jsonify({
                        "error": "Rate limit exceeded",
                        "retry_after": retry_after
                    })
                    response.status_code = 429
                    response.headers['Retry-After'] = str(retry_after)
                    response.headers['X-RateLimit-Limit'] = str(requests_per_minute)
                    response.headers['X-RateLimit-Remaining'] = '0'
                    return response

                # Normalise tuple/str/dict view returns into a Response so the
                # rate-limit headers are never silently dropped.
                response = make_response(f(*args, **kwargs))
                response.headers['X-RateLimit-Limit'] = str(requests_per_minute)
                response.headers['X-RateLimit-Remaining'] = str(
                    self._get_remaining_requests(client_id, requests_per_minute)
                )
                return response

            return decorated_function

        return decorator

    @staticmethod
    def _fingerprint(secret: str) -> str:
        """Return a 32-hex-char SHA-256 digest of *secret* for use in Redis keys."""
        return hashlib.sha256(secret.encode("utf-8")).hexdigest()[:32]

    def _get_client_id(self) -> str:
        """Identify the caller: X-API-Key, then bearer token, then client IP."""
        api_key = request.headers.get('X-API-Key')
        if api_key:
            # NOTE(review): the raw API key ends up in Redis key names.
            return f"api_key:{api_key}"

        auth_header = request.headers.get('Authorization')
        if auth_header and auth_header.startswith('Bearer '):
            token = auth_header[len('Bearer '):].strip()
            if token:
                # Hash the whole token: JWTs share a long common header prefix,
                # so keying on a truncated prefix lumps every user together.
                return f"token:{self._fingerprint(token)}"

        # NOTE: X-Forwarded-For is client-controlled unless a trusted proxy
        # overwrites it; without one, callers can spoof it to dodge IP limits.
        forwarded_for = request.headers.get('X-Forwarded-For')
        if forwarded_for:
            ip = forwarded_for.split(',')[0].strip() or request.remote_addr
        else:
            ip = request.remote_addr

        return f"ip:{ip}"

    def _increment_window(self, key: str, ttl: int) -> int:
        """Atomically INCR *key*, (re)set its TTL to *ttl* seconds, return the count."""
        # MULTI/EXEC pipeline: a failure between a separate INCR and EXPIRE
        # would leave a counter without a TTL that never resets.
        pipe = self.redis_client.pipeline()
        pipe.incr(key)
        pipe.expire(key, ttl)
        count, _ = pipe.execute()
        return int(count)

    def _check_rate_limit(self, client_id: str,
                          requests_per_minute: int,
                          requests_per_hour: int) -> Tuple[bool, Optional[int]]:
        """Count this request and check it against both windows.

        Returns ``(True, None)`` when allowed, else ``(False, seconds)`` until
        the exceeded window rolls over. The hour counter is only incremented
        once the minute check has passed.
        """
        now = int(time.time())

        minute_count = self._increment_window(
            f"rate_limit:minute:{client_id}:{now // 60}", 60
        )
        if minute_count > requests_per_minute:
            return False, 60 - (now % 60)

        hour_count = self._increment_window(
            f"rate_limit:hour:{client_id}:{now // 3600}", 3600
        )
        if hour_count > requests_per_hour:
            return False, 3600 - (now % 3600)

        return True, None

    def _get_remaining_requests(self, client_id: str, limit: int) -> int:
        """Return requests left for *client_id* in the current minute window.

        Falls back to *limit* when the counter cannot be read or parsed.
        """
        now = int(time.time())
        minute_key = f"rate_limit:minute:{client_id}:{now // 60}"

        try:
            current_count = int(self.redis_client.get(minute_key) or 0)
        except (redis.RedisError, ValueError):
            return limit
        return max(0, limit - current_count)
|
|
|
|
class APIKeyRateLimiter(RateLimiter):
    """Rate limiter whose limits come from the caller's API key record.

    On each request the key from the ``X-API-Key`` header is looked up in the
    ``api_clients`` table via ``db_repo``; requests without that header fall
    back to the default limits of ``RateLimiter.limit``.
    """

    DEFAULT_REQUESTS_PER_MINUTE = 60
    DEFAULT_REQUESTS_PER_HOUR = 1000

    def __init__(self, redis_url: str, db_repo):
        super().__init__(redis_url)
        # Assumed to expose execute_one(query, params) returning a mapping
        # row or a falsy value -- TODO confirm against the repository class.
        self.db_repo = db_repo

    @staticmethod
    def _configured_limit(client, column: str, default: int) -> int:
        """Return ``client[column]`` as an int, or *default* if absent or NULL.

        ``dict.get(column, default)`` alone is not enough: a NULL column is
        present with value ``None`` and would break the later comparison.
        """
        value = client.get(column)
        return default if value is None else int(value)

    def limit_by_api_key(self):
        """Decorator applying the per-API-key limits stored in ``api_clients``.

        Returns 401 JSON for unknown or inactive keys.
        """
        def decorator(f):
            @wraps(f)
            def decorated_function(*args, **kwargs):
                api_key = request.headers.get('X-API-Key')

                if not api_key:
                    # Use default limits for non-API key requests
                    return self.limit()(f)(*args, **kwargs)

                # Get API key configuration from database
                query = """
                    SELECT rate_limit_per_minute, rate_limit_per_hour
                    FROM api_clients
                    WHERE api_key = %s AND is_active = true
                """
                client = self.db_repo.execute_one(query, (api_key,))

                if not client:
                    return jsonify({"error": "Invalid API key"}), 401

                # Use custom limits, or defaults when the columns are NULL
                rpm = self._configured_limit(
                    client, 'rate_limit_per_minute', self.DEFAULT_REQUESTS_PER_MINUTE
                )
                rph = self._configured_limit(
                    client, 'rate_limit_per_hour', self.DEFAULT_REQUESTS_PER_HOUR
                )

                return self.limit(rpm, rph)(f)(*args, **kwargs)

            return decorated_function

        return decorator