"""Auth service: issues, refreshes and verifies JWTs and manages API clients."""
import logging
import os
import secrets
import sys
from datetime import datetime, timedelta, timezone
from functools import wraps

import jwt
from flask import Flask, request, jsonify
from flask_cors import CORS
from prometheus_flask_exporter import PrometheusMetrics

# Put the directory two levels above this file's folder on sys.path so the
# shared `config` and `repositories` packages below can be imported.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from config import get_config  # noqa: E402  (requires the sys.path tweak above)
from repositories.base import BaseRepository  # noqa: E402

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Flask app
app = Flask(__name__)
config = get_config()
app.config.from_object(config)
CORS(app)

# Initialize Prometheus metrics
metrics = PrometheusMetrics(app)
metrics.info('auth_service_info', 'Auth Service Information', version='1.0.0')

# Initialize repository
db_repo = BaseRepository(config.DATABASE_URL)

# Values of the custom "type" claim that distinguish the two token kinds.
TOKEN_TYPE_ACCESS = "access"
TOKEN_TYPE_REFRESH = "refresh"


def create_token(payload: dict, expires_delta: timedelta) -> str:
    """Sign ``payload`` as a JWT expiring ``expires_delta`` from now.

    ``exp`` and ``iat`` claims (UTC) are added to a copy of ``payload``;
    the caller's dict is left untouched.
    """
    now = datetime.now(timezone.utc)
    to_encode = {**payload, "exp": now + expires_delta, "iat": now}
    return jwt.encode(
        to_encode,
        config.JWT_SECRET,
        algorithm=config.JWT_ALGORITHM
    )


def decode_token(token: str) -> dict:
    """Decode and validate a JWT signed by this service.

    Raises:
        ValueError: "Token has expired" or "Invalid token"; the underlying
            PyJWT exception is chained as the cause.
    """
    try:
        return jwt.decode(
            token,
            config.JWT_SECRET,
            algorithms=[config.JWT_ALGORITHM]
        )
    except jwt.ExpiredSignatureError as err:
        raise ValueError("Token has expired") from err
    except jwt.InvalidTokenError as err:
        raise ValueError("Invalid token") from err


def _get_json_object():
    """Return the request body as a JSON object (dict), or None.

    ``silent=True`` turns malformed JSON or a wrong Content-Type into None
    instead of a framework error page, so endpoints can answer with their
    own JSON 400. JSON that is not an object (list, scalar) is also None,
    which avoids a TypeError (500) when indexing it by key.
    """
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else None


def _build_access_payload(license_id, client_id, hardware_id, features) -> dict:
    """Build the claim set shared by every access token this service issues."""
    return {
        "sub": license_id,
        "hwid": hardware_id,
        "client_id": client_id,
        "type": TOKEN_TYPE_ACCESS,
        "features": features,
        "limits": {
            "api_calls": config.DEFAULT_RATE_LIMIT_PER_HOUR,
            "concurrent_sessions": config.MAX_CONCURRENT_SESSIONS
        },
    }


def _admin_authorized() -> bool:
    """Return True if the X-Admin-Secret header matches ADMIN_SECRET.

    There is deliberately no fallback secret: if ADMIN_SECRET is unset,
    every admin request is rejected. The comparison is constant-time.
    """
    expected = os.getenv('ADMIN_SECRET')
    if not expected:
        logger.warning("ADMIN_SECRET is not configured; rejecting admin request")
        return False
    provided = request.headers.get('X-Admin-Secret', '')
    # Compare bytes: compare_digest rejects non-ASCII str arguments.
    return secrets.compare_digest(provided.encode('utf-8'), expected.encode('utf-8'))


def require_api_key(f):
    """Decorator: require a valid, active X-API-Key for the wrapped view.

    Returns 401 when the key is missing or unknown/inactive. Returns 403
    when the client has a non-empty ``allowed_endpoints`` list that does not
    contain the current Flask endpoint name. On success the client row is
    stored as ``request.api_client`` before the view runs.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        api_key = request.headers.get('X-API-Key')
        if not api_key:
            return jsonify({"error": "Missing API key"}), 401

        # Parameterized lookup of an active client for this key.
        query = """
            SELECT id, client_name, allowed_endpoints
            FROM api_clients
            WHERE api_key = %s AND is_active = true
        """
        client = db_repo.execute_one(query, (api_key,))
        if not client:
            return jsonify({"error": "Invalid API key"}), 401

        # An empty/missing allow-list means "all endpoints allowed".
        allowed = client.get('allowed_endpoints', [])
        if allowed and request.endpoint not in allowed:
            return jsonify({"error": "Endpoint not allowed"}), 403

        # Add client info to request
        request.api_client = client
        return f(*args, **kwargs)

    return decorated_function


@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint."""
    return jsonify({
        "status": "healthy",
        "service": "auth",
        "timestamp": datetime.now(timezone.utc).isoformat()
    })


@app.route('/api/v1/auth/token', methods=['POST'])
@require_api_key
def create_access_token():
    """Issue an access/refresh token pair for an active license.

    Body: ``{"license_id": ..., "hardware_id"?: ..., "features"?: [...]}``.
    Responds 400 if license_id is missing, 404 if the license does not
    exist, and 403 if it is inactive.
    """
    data = _get_json_object()
    if not data or 'license_id' not in data:
        return jsonify({"error": "Missing license_id"}), 400

    license_id = data['license_id']
    hardware_id = data.get('hardware_id')

    # Verify license exists and is active
    query = """
        SELECT id, is_active, max_devices
        FROM licenses
        WHERE id = %s
    """
    license_row = db_repo.execute_one(query, (license_id,))
    if not license_row:
        return jsonify({"error": "License not found"}), 404
    if not license_row['is_active']:
        return jsonify({"error": "License is not active"}), 403

    # NOTE(review): max_devices is fetched but never enforced, and `features`
    # comes straight from the caller's request body rather than from the
    # license, so a client can claim any feature. Confirm this is intended.
    # NOTE(review): PyJWT >= 2.10 rejects a non-string "sub" on decode; if
    # license ids are integers, tokens will fail verification there.
    client_id = request.api_client['id']
    features = data.get('features', [])

    access_token = create_token(
        _build_access_payload(license_id, client_id, hardware_id, features),
        config.JWT_ACCESS_TOKEN_EXPIRES
    )

    refresh_payload = {
        "sub": license_id,
        "client_id": client_id,
        "type": TOKEN_TYPE_REFRESH,
        # Carried so refreshed access tokens keep the same claims as this one.
        "hwid": hardware_id,
        "features": features,
    }
    refresh_token = create_token(refresh_payload, config.JWT_REFRESH_TOKEN_EXPIRES)

    return jsonify({
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "Bearer",
        "expires_in": int(config.JWT_ACCESS_TOKEN_EXPIRES.total_seconds())
    })


@app.route('/api/v1/auth/refresh', methods=['POST'])
def refresh_access_token():
    """Exchange a refresh token for a new access token.

    Responds 400 for a missing token or wrong token type, 401 for an
    expired/invalid token, and 403 if the license is no longer active.
    """
    data = _get_json_object()
    if not data or 'refresh_token' not in data:
        return jsonify({"error": "Missing refresh_token"}), 400

    try:
        payload = decode_token(data['refresh_token'])
    except ValueError as e:
        return jsonify({"error": str(e)}), 401

    if payload.get('type') != TOKEN_TYPE_REFRESH:
        return jsonify({"error": "Invalid token type"}), 400

    license_id = payload.get('sub')

    # Verify license still active
    query = "SELECT is_active FROM licenses WHERE id = %s"
    license_row = db_repo.execute_one(query, (license_id,))
    if not license_row or not license_row['is_active']:
        return jsonify({"error": "License is not active"}), 403

    # Refresh tokens issued before hwid/features were embedded fall back to
    # None / [] for those claims.
    access_payload = _build_access_payload(
        license_id,
        payload.get('client_id'),
        payload.get('hwid'),
        payload.get('features', [])
    )
    access_token = create_token(access_payload, config.JWT_ACCESS_TOKEN_EXPIRES)

    return jsonify({
        "access_token": access_token,
        "token_type": "Bearer",
        "expires_in": int(config.JWT_ACCESS_TOKEN_EXPIRES.total_seconds())
    })


@app.route('/api/v1/auth/verify', methods=['POST'])
def verify_token():
    """Verify the Bearer access token in the Authorization header.

    Only access tokens are accepted; refresh tokens are reported as invalid.
    ``expires_at`` is an ISO-8601 timestamp in UTC with an explicit offset.
    """
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid authorization header"}), 401

    token = auth_header[len('Bearer '):].strip()

    try:
        payload = decode_token(token)
    except ValueError as e:
        return jsonify({
            "valid": False,
            "error": str(e)
        }), 401

    # A validly signed refresh token must not pass as an access token.
    if payload.get('type') != TOKEN_TYPE_ACCESS:
        return jsonify({
            "valid": False,
            "error": "Invalid token type"
        }), 401

    return jsonify({
        "valid": True,
        "license_id": payload.get('sub'),
        "expires_at": datetime.fromtimestamp(payload['exp'], tz=timezone.utc).isoformat()
    })


@app.route('/api/v1/auth/api-key', methods=['POST'])
def create_api_key():
    """Create a new API client (admin only).

    Guarded by the X-Admin-Secret header (see ``_admin_authorized``).
    Body: ``{"client_name": ..., "allowed_endpoints"?: [...]}``.
    The generated api_key and secret_key are returned only in this response.
    """
    # TODO: replace the shared-secret header with real admin authentication.
    if not _admin_authorized():
        return jsonify({"error": "Unauthorized"}), 401

    data = _get_json_object()
    if not data or 'client_name' not in data:
        return jsonify({"error": "Missing client_name"}), 400

    api_key = f"sk_{secrets.token_urlsafe(32)}"
    secret_key = secrets.token_urlsafe(64)

    query = """
        INSERT INTO api_clients (client_name, api_key, secret_key, allowed_endpoints)
        VALUES (%s, %s, %s, %s)
        RETURNING id
    """
    allowed_endpoints = data.get('allowed_endpoints', [])
    client_id = db_repo.execute_insert(
        query,
        (data['client_name'], api_key, secret_key, allowed_endpoints)
    )
    if not client_id:
        return jsonify({"error": "Failed to create API key"}), 500

    return jsonify({
        "client_id": client_id,
        "api_key": api_key,
        "secret_key": secret_key,
        "client_name": data['client_name']
    }), 201


@app.errorhandler(404)
def not_found(error):
    """Return a JSON 404 body instead of Flask's default HTML page."""
    return jsonify({"error": "Not found"}), 404


@app.errorhandler(500)
def internal_error(error):
    """Log the error and return a generic JSON 500 body."""
    logger.error("Internal error: %s", error)
    return jsonify({"error": "Internal server error"}), 500


if __name__ == '__main__':
    # Debug mode follows the loaded config (Flask defaults DEBUG to False).
    # Hard-coding debug=True while binding 0.0.0.0 would expose the Werkzeug
    # interactive debugger, i.e. remote code execution, to the network.
    app.run(host='0.0.0.0', port=5001, debug=bool(app.config.get('DEBUG', False)))