import os
import threading
from typing import Optional
from fastapi import HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from datetime import datetime, timedelta, timezone
import jwt
from .logger import log_info, log_warning

security = HTTPBearer()


# ===================== Rate Limiting =====================
class RateLimiter:
    """Simple in-memory sliding-window rate limiter.

    State is held in this object only (a plain dict guarded by a
    ``threading.Lock``), so limits are per-process and not shared across
    workers.
    """

    def __init__(self):
        # {key: [(unix_timestamp, count), ...]} - one entry per allowed request
        self.requests = {}
        self.lock = threading.Lock()

    def is_allowed(self, key: str, max_requests: int, window_seconds: int) -> bool:
        """Check whether a request for ``key`` is allowed, recording it if so.

        Args:
            key: Identifier being limited.
            max_requests: Maximum number of requests permitted inside the window.
            window_seconds: Length of the sliding window, in seconds.

        Returns:
            True if the request is allowed (it is then recorded against ``key``);
            False if ``max_requests`` has already been reached within the window
            (nothing is recorded).
        """
        with self.lock:
            now = datetime.now(timezone.utc).timestamp()
            cutoff = now - window_seconds

            # Keep only entries still inside the sliding window.
            recent = [
                (ts, count)
                for ts, count in self.requests.get(key, [])
                if ts > cutoff
            ]
            self.requests[key] = recent

            if sum(count for _, count in recent) >= max_requests:
                return False

            # `recent` is the list stored in self.requests, so this records it.
            recent.append((now, 1))
            return True

    def reset(self, key: str):
        """Forget all recorded requests for ``key`` (no-op if unknown)."""
        with self.lock:
            self.requests.pop(key, None)


# Create global rate limiter instance
rate_limiter = RateLimiter()


# ===================== JWT Config =====================
# SECURITY: this fallback secret is public (it is in the source), so anyone can
# forge tokens for a deployment that relies on it. Always set JWT_SECRET in
# production.
_DEFAULT_JWT_SECRET = "flare-admin-secret-key-change-in-production"


def get_jwt_config():
    """Get JWT configuration based on environment.

    When ``SPACE_ID`` is set (HuggingFace Space / cloud mode), settings are read
    from the process environment only. Otherwise (on-premise mode) a ``.env``
    file is loaded first via ``python-dotenv``.

    A missing *or empty* ``JWT_SECRET`` falls back to ``_DEFAULT_JWT_SECRET``
    and logs a warning in both modes.

    Returns:
        dict with keys ``secret`` (str), ``algorithm`` (str, default "HS256")
        and ``expiration_hours`` (int, default 24).

    Raises:
        ValueError: if ``JWT_EXPIRATION_HOURS`` is not an integer string.
    """
    if not os.getenv("SPACE_ID"):
        # On-premise mode - use .env file
        from dotenv import load_dotenv
        load_dotenv()

    jwt_secret = os.getenv("JWT_SECRET")
    # Treat an empty value as missing: an empty HMAC key must never be used.
    if not jwt_secret:
        log_warning("⚠️ WARNING: JWT_SECRET not found in environment, using fallback")
        jwt_secret = _DEFAULT_JWT_SECRET

    return {
        "secret": jwt_secret,
        "algorithm": os.getenv("JWT_ALGORITHM", "HS256"),
        "expiration_hours": int(os.getenv("JWT_EXPIRATION_HOURS", "24")),
    }


# ===================== Auth Helpers =====================
def create_token(username: str) -> str:
    """Create a signed JWT for ``username``.

    The token carries ``sub`` (the username), ``iat`` (issue time) and ``exp``
    (issue time + configured expiration hours), all in UTC.
    """
    config = get_jwt_config()
    # Use a single clock reading so exp - iat is exactly the configured lifetime.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": username,
        "exp": issued_at + timedelta(hours=config["expiration_hours"]),
        "iat": issued_at,
    }
    return jwt.encode(payload, config["secret"], algorithm=config["algorithm"])


def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Verify the bearer JWT and return its subject (username).

    Raises:
        HTTPException(401, "Token expired"): the token's ``exp`` has passed.
        HTTPException(401, "Invalid token"): bad signature/format, or the token
            has no usable ``sub`` claim.
    """
    token = credentials.credentials
    config = get_jwt_config()
    try:
        payload = jwt.decode(token, config["secret"], algorithms=[config["algorithm"]])
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired") from None
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token") from None

    # A correctly signed token without a subject is still not a valid session;
    # indexing payload["sub"] here would surface as a 500 instead of a 401.
    username = payload.get("sub")
    if not username:
        raise HTTPException(status_code=401, detail="Invalid token")
    return username


# ===================== Utility Functions =====================
def truncate_string(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """Truncate ``text`` to at most ``max_length`` characters.

    When truncation happens, ``suffix`` is appended and counts toward the
    limit. If ``max_length`` is too small to hold the suffix, the text is
    hard-cut without a suffix so the result never exceeds ``max_length``.
    """
    if len(text) <= max_length:
        return text
    keep = max_length - len(suffix)
    if keep < 0:
        # A negative slice index would keep almost all of the text; hard-cut instead.
        return text[:max(max_length, 0)]
    return text[:keep] + suffix


def format_file_size(size_bytes: int) -> str:
    """Format a byte count as a human-readable string using 1024-based units.

    Example: ``format_file_size(1536) == "1.50 KB"``. Values of 1024 TB or
    more are expressed in PB.
    """
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if size_bytes < 1024.0:
            return f"{size_bytes:.2f} {unit}"
        size_bytes /= 1024.0
    return f"{size_bytes:.2f} PB"


def is_safe_path(path: str, base_path: str) -> bool:
    """Return True if ``path`` (relative to ``base_path``) stays inside it.

    Both paths are made absolute and normalized (``..`` is collapsed) before
    comparison. Containment is checked per path component, so a sibling such
    as ``/srv/app-evil`` is NOT considered inside ``/srv/app``.

    NOTE(review): symlinks are not resolved (``os.path.abspath`` is lexical);
    a symlink under ``base_path`` pointing elsewhere is still accepted.
    """
    base = os.path.abspath(base_path)
    target = os.path.abspath(os.path.join(base, path))
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # Raised when the paths cannot share a root (e.g. different Windows drives).
        return False


def get_current_timestamp() -> str:
    """
    Get current UTC timestamp in ISO format with Z suffix

    Always has millisecond precision, even when the sub-second part is zero.

    Returns: "2025-01-10T12:00:00.123Z"
    """
    return datetime.now(timezone.utc).isoformat(timespec="milliseconds").replace('+00:00', 'Z')


def normalize_timestamp(timestamp: Optional[str]) -> str:
    """
    Normalize timestamp string for consistent comparison

    Handles various formats:
    - "2025-01-10T12:00:00Z"
    - "2025-01-10T12:00:00.000Z"
    - "2025-01-10T12:00:00+00:00"
    - "2025-01-10 12:00:00+00:00"

    All of the above normalize to "2025-01-10T12:00:00Z". Fractional seconds
    are dropped only for UTC ("Z") timestamps; other offsets are left as-is.
    Returns "" for None or an empty string.
    """
    if not timestamp:
        return ""

    # Normalize various formats
    normalized = timestamp.replace(' ', 'T')  # Space to T
    normalized = normalized.replace('+00:00', 'Z')  # UTC timezone

    # Remove sub-second precision so values from different sources compare equal
    if '.' in normalized and normalized.endswith('Z'):
        normalized = normalized.split('.')[0] + 'Z'

    return normalized


def timestamps_equal(ts1: Optional[str], ts2: Optional[str]) -> bool:
    """
    Compare two timestamps regardless of format differences

    Equality is judged on the output of normalize_timestamp, so sub-second
    differences between UTC timestamps are ignored.
    """
    return normalize_timestamp(ts1) == normalize_timestamp(ts2)