Spaces:
Building
Building
import os
import threading
import time
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from logger import log_info, log_warning
# Shared bearer-token security scheme. verify_token receives the request's
# credentials from it through Depends(security). auto_error=True is FastAPI's
# default and is written out here only to make the behavior explicit.
security = HTTPBearer(auto_error=True)
# ===================== Rate Limiting ===================== | |
class RateLimiter:
    """Simple in-memory sliding-window rate limiter.

    Each key maps to a list of ``(timestamp, count)`` entries, one per accepted
    request. All state is guarded by ``self.lock`` so one instance can be
    shared between threads.
    """

    # Minimum number of seconds between sweeps that drop keys whose entries
    # have all fallen out of their window (prevents unbounded growth when
    # many distinct keys are seen only once).
    SWEEP_INTERVAL_SECONDS = 60.0

    def __init__(self):
        # {key: [(timestamp, count)]}. Timestamps come from time.monotonic(),
        # so wall-clock adjustments cannot shrink or stretch a window.
        self.requests = {}
        # {key: monotonic time after which every entry for the key is stale}
        self._expires_at = {}
        self._last_sweep = time.monotonic()
        self.lock = threading.Lock()

    def is_allowed(self, key: str, max_requests: int, window_seconds: int) -> bool:
        """Check (and record) a request for ``key`` against the limit.

        Args:
            key: Identifier being limited.
            max_requests: Maximum number of accepted requests per window.
            window_seconds: Length of the sliding window in seconds. A key is
                expected to be checked with the same window on every call.

        Returns:
            True if the request is accepted (it is then recorded), False if
            ``key`` already has ``max_requests`` requests inside the window.
            Rejected requests are not recorded.
        """
        with self.lock:
            now = time.monotonic()
            self._sweep_stale_keys(now)

            # Keep only entries that are still inside the window.
            cutoff = now - window_seconds
            entries = [
                (ts, count)
                for ts, count in self.requests.get(key, ())
                if ts > cutoff
            ]
            self.requests[key] = entries
            # Upper bound on when this key's newest entry leaves the window.
            self._expires_at[key] = max(
                self._expires_at.get(key, now), now + window_seconds
            )

            if sum(count for _, count in entries) >= max_requests:
                return False

            entries.append((now, 1))
            return True

    def _sweep_stale_keys(self, now: float) -> None:
        """Drop keys whose entries have all expired; caller must hold the lock."""
        if now - self._last_sweep < self.SWEEP_INTERVAL_SECONDS:
            return
        self._last_sweep = now
        stale = [key for key, expires in self._expires_at.items() if expires <= now]
        for key in stale:
            self.requests.pop(key, None)
            del self._expires_at[key]

    def reset(self, key: str):
        """Forget all recorded requests for ``key``."""
        with self.lock:
            self.requests.pop(key, None)
            self._expires_at.pop(key, None)
# Module-level rate limiter instance shared by every importer of this module.
rate_limiter = RateLimiter()
# ===================== JWT Config ===================== | |
def get_jwt_config():
    """Get JWT configuration based on environment.

    On a HuggingFace Space (detected by a non-empty ``SPACE_ID`` environment
    variable), values are read directly from the environment. Otherwise a
    local ``.env`` file is loaded first with ``load_dotenv()``.

    Returns:
        dict with keys:
            ``secret``: signing key (``JWT_SECRET``, or an insecure built-in
                fallback if it is missing or empty),
            ``algorithm``: ``JWT_ALGORITHM``, default ``"HS256"``,
            ``expiration_hours``: ``int(JWT_EXPIRATION_HOURS)``, default 24.

    Raises:
        ValueError: if ``JWT_EXPIRATION_HOURS`` is not an integer string.
    """
    fallback_secret = "flare-admin-secret-key-change-in-production"

    if not os.getenv("SPACE_ID"):
        # On-premise mode - pull settings from the .env file.
        from dotenv import load_dotenv
        load_dotenv()

    jwt_secret = os.getenv("JWT_SECRET")
    if not jwt_secret:
        # Missing or empty secret is handled the same way in both modes: keep
        # the app running on the well-known fallback, but surface the insecure
        # configuration. This function runs on every token create/verify, so
        # the warning repeats per call.
        log_warning("⚠️ WARNING: JWT_SECRET not found in environment, using fallback")
        jwt_secret = fallback_secret

    return {
        "secret": jwt_secret,
        "algorithm": os.getenv("JWT_ALGORITHM", "HS256"),
        "expiration_hours": int(os.getenv("JWT_EXPIRATION_HOURS", "24")),
    }
# ===================== Auth Helpers ===================== | |
def create_token(username: str) -> str:
    """Create a signed JWT for ``username``.

    The payload carries ``sub`` (the username), ``exp`` (issue time plus the
    configured ``expiration_hours``) and ``iat`` (issue time, UTC). It is
    signed with the secret and algorithm returned by ``get_jwt_config()``.

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    config = get_jwt_config()
    # Read the clock once so that exp - iat is exactly the configured lifetime.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": username,
        "exp": issued_at + timedelta(hours=config["expiration_hours"]),
        "iat": issued_at,
    }
    return jwt.encode(payload, config["secret"], algorithm=config["algorithm"])
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Verify the bearer JWT and return the username stored in its ``sub`` claim.

    Args:
        credentials: Bearer credentials supplied by the ``security`` dependency.

    Returns:
        The token's ``sub`` claim.

    Raises:
        HTTPException: 401 ``"Token expired"`` if the signature is valid but
            the token has expired; 401 ``"Invalid token"`` if decoding fails
            for any other reason or the token has no usable ``sub`` claim.
    """
    token = credentials.credentials
    config = get_jwt_config()
    try:
        payload = jwt.decode(token, config["secret"], algorithms=[config["algorithm"]])
    except jwt.ExpiredSignatureError as exc:
        # ExpiredSignatureError subclasses InvalidTokenError, so it must be caught first.
        raise HTTPException(status_code=401, detail="Token expired") from exc
    except jwt.InvalidTokenError as exc:
        raise HTTPException(status_code=401, detail="Invalid token") from exc

    username = payload.get("sub")
    if not username:
        # A correctly signed token without a subject cannot identify a user.
        # Reject it as 401 instead of letting a KeyError surface as a 500.
        raise HTTPException(status_code=401, detail="Invalid token")
    return username
# ===================== Utility Functions ===================== | |
def truncate_string(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """Truncate ``text`` so that the result is at most ``max_length`` characters.

    If ``text`` already fits, it is returned unchanged. Otherwise it is cut and
    ``suffix`` is appended, with the cut chosen so text plus suffix fits exactly.
    When ``max_length`` is too small to hold the suffix, the text is cut hard
    to ``max_length`` characters (empty for ``max_length <= 0``) without a suffix.

    Args:
        text: String to shorten.
        max_length: Maximum length of the returned string.
        suffix: Marker appended to indicate truncation.

    Returns:
        The possibly truncated string, never longer than ``max(max_length, 0)``.
    """
    if len(text) <= max_length:
        return text
    if max_length < len(suffix):
        # No room for the suffix: a negative slice bound here would otherwise
        # keep almost the whole string and exceed max_length.
        return text[:max(max_length, 0)]
    return text[:max_length - len(suffix)] + suffix
def format_file_size(size_bytes: int) -> str:
    """Render a byte count as a human-readable string with two decimals.

    The value is divided by 1024 while it is not below 1024, stepping through
    B, KB, MB, GB and TB. Anything still at or above 1024 TB is reported in PB.
    """
    units = ("B", "KB", "MB", "GB", "TB")
    value = size_bytes
    idx = 0
    while idx < len(units) and not value < 1024.0:
        value /= 1024.0
        idx += 1
    unit = units[idx] if idx < len(units) else "PB"
    return f"{value:.2f} {unit}"
def is_safe_path(path: str, base_path: str) -> bool:
    """Check if path is safe (no directory traversal).

    ``path`` is joined onto ``base_path``, both are made absolute and
    normalized (``..`` collapsed), and the result must lie inside the base
    directory or be the base itself. Absolute ``path`` values outside the
    base are rejected. Symlinks are not resolved.

    Returns:
        True if the resolved target is within ``base_path``, else False.
    """
    base = os.path.abspath(base_path)
    target = os.path.abspath(os.path.join(base, path))
    try:
        # Compare whole path components. A plain string-prefix test would
        # accept siblings such as "/data/app2" for base "/data/app".
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # Raised e.g. for paths on different drives (Windows): not under base.
        return False
def get_current_timestamp() -> str:
    """
    Get current UTC timestamp in ISO 8601 format with millisecond precision and a Z suffix.

    The fractional part is always present and always three digits, unlike
    plain ``isoformat()``, which emits microseconds and omits the fraction
    entirely when the microsecond value is zero.

    Returns: e.g. "2025-01-10T12:00:00.123Z"
    """
    return datetime.now(timezone.utc).isoformat(timespec="milliseconds").replace('+00:00', 'Z')
def normalize_timestamp(timestamp: Optional[str]) -> str:
    """
    Canonicalize a timestamp string so equivalent UTC spellings compare equal.

    Steps, in order:
      1. every space becomes "T" ("2025-01-10 12:00:00" -> "2025-01-10T12:00:00");
      2. every "+00:00" becomes "Z";
      3. if the string now ends with "Z" and contains ".", everything from the
         first "." on is dropped and "Z" is re-appended (fractional seconds removed).

    Handles, for example:
      - "2025-01-10T12:00:00Z"
      - "2025-01-10T12:00:00.000Z"
      - "2025-01-10T12:00:00+00:00"
      - "2025-01-10 12:00:00+00:00"

    Other offsets (e.g. "+03:00") and strings without a "Z" are left as-is.
    None and "" both yield "".
    """
    if not timestamp:
        return ""

    result = timestamp.replace(' ', 'T').replace('+00:00', 'Z')

    has_fraction = '.' in result
    if has_fraction and result.endswith('Z'):
        whole_seconds, _, _ = result.partition('.')
        result = whole_seconds + 'Z'

    return result
def timestamps_equal(ts1: Optional[str], ts2: Optional[str]) -> bool:
    """
    Return True when two timestamp strings are equal after normalize_timestamp.

    This ignores differences in the date/time separator (space vs "T"), the
    UTC marker ("+00:00" vs "Z") and fractional seconds on UTC values.
    None and "" both normalize to "" and therefore compare equal.
    """
    left = normalize_timestamp(ts1)
    right = normalize_timestamp(ts2)
    return left == right