import json
import logging
import time
import uuid
from datetime import timedelta
from typing import Any, Dict, List, Optional

import redis

logger = logging.getLogger(__name__)


class CacheRepository:
    """Redis cache repository.

    Every operation is best-effort: when the connection could not be
    established (``self.redis`` is ``None``) or a Redis command raises, the
    error is logged and a neutral value (``None`` / ``False`` / ``0`` /
    "allowed") is returned instead of propagating the exception.
    """

    # Number of keys passed to a single DEL when deleting by pattern.
    _DELETE_BATCH_SIZE = 500

    def __init__(self, redis_url: str):
        self.redis_url = redis_url
        self._connect()

    def _connect(self) -> None:
        """Connect to Redis; leave ``self.redis`` as ``None`` on any failure."""
        try:
            self.redis = redis.from_url(self.redis_url, decode_responses=True)
            self.redis.ping()
            logger.info("Connected to Redis")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            self.redis = None

    def _make_key(self, prefix: str, *args) -> str:
        """Build a cache key by joining ``prefix`` and ``args`` with ``:``."""
        parts = [prefix] + [str(arg) for arg in args]
        return ":".join(parts)

    @staticmethod
    def _escape_pattern(value: Any) -> str:
        """Escape Redis glob metacharacters so ``value`` matches literally."""
        return "".join(
            "\\" + ch if ch in "\\*?[]" else ch for ch in str(value)
        )

    def get(self, key: str) -> Optional[Any]:
        """Return the JSON-decoded value stored at ``key``, or ``None``.

        ``None`` is returned when Redis is unavailable, the key is missing or
        empty, or decoding/reading fails (the failure is logged).
        """
        if self.redis is None:
            return None
        try:
            value = self.redis.get(key)
            if not value:
                return None
            return json.loads(value)
        except Exception as e:
            logger.error(f"Cache get error: {e}")
            return None

    def set(self, key: str, value: Any, ttl: int = 300) -> bool:
        """JSON-encode ``value`` and store it at ``key`` for ``ttl`` seconds.

        Returns ``True`` on success and ``False`` when Redis is unavailable or
        the value cannot be encoded/stored (e.g. a non-positive ``ttl``, which
        SETEX rejects); failures are logged.
        """
        if self.redis is None:
            return False
        try:
            json_value = json.dumps(value)
            return self.redis.setex(key, ttl, json_value)
        except Exception as e:
            logger.error(f"Cache set error: {e}")
            return False

    def delete(self, key: str) -> bool:
        """Delete ``key``; return ``True`` if a key was actually removed."""
        if self.redis is None:
            return False
        try:
            return bool(self.redis.delete(key))
        except Exception as e:
            logger.error(f"Cache delete error: {e}")
            return False

    def delete_pattern(self, pattern: str) -> int:
        """Delete all keys matching the glob ``pattern``.

        Uses incremental SCAN rather than KEYS so a large keyspace does not
        block the Redis server, and deletes matches in fixed-size batches.

        Returns:
            Number of keys deleted, or ``0`` when Redis is unavailable or an
            error occurs (the error is logged).
        """
        if self.redis is None:
            return 0
        try:
            deleted = 0
            batch: List[str] = []
            for key in self.redis.scan_iter(match=pattern, count=self._DELETE_BATCH_SIZE):
                batch.append(key)
                if len(batch) >= self._DELETE_BATCH_SIZE:
                    deleted += self.redis.delete(*batch)
                    batch.clear()
            if batch:
                deleted += self.redis.delete(*batch)
            return deleted
        except Exception as e:
            logger.error(f"Cache delete pattern error: {e}")
            return 0

    # License-specific cache methods

    def get_license_validation(self, license_key: str, hardware_id: str) -> Optional[Dict[str, Any]]:
        """Get cached license validation result for a key/hardware pair."""
        key = self._make_key("license:validation", license_key, hardware_id)
        return self.get(key)

    def set_license_validation(self, license_key: str, hardware_id: str, result: Dict[str, Any], ttl: int = 300) -> bool:
        """Cache a license validation result for a key/hardware pair."""
        key = self._make_key("license:validation", license_key, hardware_id)
        return self.set(key, result, ttl)

    def get_license_status(self, license_id: str) -> Optional[Dict[str, Any]]:
        """Get cached license status."""
        key = self._make_key("license:status", license_id)
        return self.get(key)

    def set_license_status(self, license_id: str, status: Dict[str, Any], ttl: int = 60) -> bool:
        """Cache license status."""
        key = self._make_key("license:status", license_id)
        return self.set(key, status, ttl)

    def get_device_list(self, license_id: str) -> Optional[List[Dict[str, Any]]]:
        """Get cached device list."""
        key = self._make_key("license:devices", license_id)
        return self.get(key)

    def set_device_list(self, license_id: str, devices: List[Dict[str, Any]], ttl: int = 300) -> bool:
        """Cache device list."""
        key = self._make_key("license:devices", license_id)
        return self.set(key, devices, ttl)

    def invalidate_license_cache(self, license_id: str, license_key: Optional[str] = None) -> None:
        """Invalidate all cache entries for a license.

        Args:
            license_id: Id used for the ``license:status`` and
                ``license:devices`` keys.
            license_key: Key used for ``license:validation:<key>:<hardware>``
                entries. Validation entries are keyed by license *key*, not
                id; when omitted, ``license_id`` is assumed to double as the
                license key (NOTE(review): confirm with callers).
        """
        # Exact keys: delete directly, no keyspace scan needed.
        self.delete(self._make_key("license:status", license_id))
        self.delete(self._make_key("license:devices", license_id))

        # Validation entries exist once per hardware id, so match every
        # hardware suffix. The owner segment is escaped so glob characters in
        # it cannot widen the match to other licenses.
        owner = license_key if license_key is not None else license_id
        self.delete_pattern(
            self._make_key("license:validation", self._escape_pattern(owner), "*")
        )

    # Rate limiting methods

    def check_rate_limit(self, key: str, limit: int, window: int) -> tuple[bool, int]:
        """Sliding-window rate limit check backed by a sorted set.

        Every call records the current request, including calls that end up
        rejected.

        Args:
            key: Redis key holding the request log for the limited subject.
            limit: Maximum number of requests allowed within ``window``.
            window: Window length in seconds.

        Returns:
            ``(is_allowed, current_count)`` where ``current_count`` includes
            this request. ``(True, 0)`` when Redis is unavailable or errors.
        """
        if self.redis is None:
            return True, 0
        try:
            now = time.time()
            window_start = now - window
            # Each request needs a distinct member; a bare timestamp would make
            # concurrent/same-instant requests overwrite one another.
            member = f"{now}:{uuid.uuid4().hex}"

            pipe = self.redis.pipeline()
            # Remove old entries
            pipe.zremrangebyscore(key, 0, window_start)
            # Count requests in current window (before adding this one)
            pipe.zcard(key)
            # Add current request
            pipe.zadd(key, {member: now})
            # Let the whole log expire once it falls out of the window
            pipe.expire(key, window + 1)
            results = pipe.execute()

            current_count = results[1]
            return current_count < limit, current_count + 1
        except Exception as e:
            logger.error(f"Rate limit check error: {e}")
            return True, 0

    def increment_counter(self, key: str, window: int = 3600) -> int:
        """Increment a counter that expires ``window`` seconds after creation.

        The expiry is set only when the key has no TTL yet. Refreshing it on
        every increment would keep a steadily-used counter alive forever, so
        the window would never close.

        Returns:
            The counter value after incrementing, or ``0`` when Redis is
            unavailable or errors (logged).
        """
        if self.redis is None:
            return 0
        try:
            pipe = self.redis.pipeline()
            pipe.incr(key)
            pipe.ttl(key)
            count, ttl = pipe.execute()
            # TTL is negative when the key carries no expiry (a new counter, or
            # one left without TTL by an earlier failure).
            if ttl < 0:
                self.redis.expire(key, window)
            return count
        except Exception as e:
            logger.error(f"Counter increment error: {e}")
            return 0