flare / utils /utils.py
ciyidogan's picture
Upload 7 files
1e4a027 verified
raw
history blame
5.73 kB
import os
from typing import Optional
from fastapi import HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from datetime import datetime, timedelta, timezone
import jwt
from logger import log_info, log_warning
# Shared HTTP Bearer scheme instance; verify_token receives it through
# Depends(security) and reads the token from the resulting credentials.
security = HTTPBearer()
# ===================== Rate Limiting =====================
class RateLimiter:
    """In-memory sliding-window request counter, keyed by an arbitrary string.

    Each accepted request is stored as a ``(timestamp, count)`` pair under its
    key. Entries that have fallen out of the window are discarded the next
    time that key is checked. Both methods hold ``self.lock`` while touching
    the shared mapping.
    """

    def __init__(self):
        self.requests = {}  # {key: [(timestamp, count)]}
        self.lock = threading.Lock()

    def is_allowed(self, key: str, max_requests: int, window_seconds: int) -> bool:
        """Record a request for ``key`` and report whether it fits the limit.

        Args:
            key: Identifier the limit applies to.
            max_requests: Number of requests permitted inside the window.
            window_seconds: Length of the window, in seconds.

        Returns:
            True if the request was accepted (and recorded); False if the key
            already has ``max_requests`` or more requests inside the window.
        """
        with self.lock:
            current = datetime.now(timezone.utc)
            oldest_allowed = current.timestamp() - window_seconds

            # Keep only the entries that are still inside the window.
            recent = [
                (stamp, hits)
                for stamp, hits in self.requests.get(key, [])
                if stamp > oldest_allowed
            ]
            self.requests[key] = recent

            used = sum(hits for _, hits in recent)
            if used < max_requests:
                # Room left: register this request.
                recent.append((current.timestamp(), 1))
                return True
            return False

    def reset(self, key: str):
        """Forget every recorded request for ``key``; unknown keys are ignored."""
        with self.lock:
            self.requests.pop(key, None)
# Create global rate limiter instance
# NOTE: threading is imported here rather than with the other imports at the
# top of the file. This still works because RateLimiter.__init__ only looks
# the name up when the instance below is created, after this import has run.
import threading
rate_limiter = RateLimiter()
# ===================== JWT Config =====================
def get_jwt_config():
    """Build the JWT settings from environment variables.

    When ``SPACE_ID`` is set (HuggingFace Space) the process environment is
    used as-is. Otherwise a local ``.env`` file is loaded first through
    python-dotenv (on-premise mode).

    In both modes a missing *or empty* ``JWT_SECRET`` falls back to the
    built-in development secret and a warning is logged. The fallback value
    is hard-coded in this file, so ``JWT_SECRET`` must be set in production.

    Returns:
        dict with the keys:
            "secret": signing key (str).
            "algorithm": ``JWT_ALGORITHM`` or "HS256".
            "expiration_hours": ``JWT_EXPIRATION_HOURS`` as int, default 24.

    Raises:
        ValueError: if ``JWT_EXPIRATION_HOURS`` is set to a non-integer value.
    """
    if not os.getenv("SPACE_ID"):
        # On-premise mode - pull variables from the .env file into os.environ
        from dotenv import load_dotenv
        load_dotenv()

    jwt_secret = os.getenv("JWT_SECRET")
    if not jwt_secret:
        # Covers both "unset" and "set to an empty string": an empty value
        # must never be used as the signing key.
        log_warning("⚠️ WARNING: JWT_SECRET not found in environment, using fallback")
        jwt_secret = "flare-admin-secret-key-change-in-production"  # Fallback

    return {
        "secret": jwt_secret,
        "algorithm": os.getenv("JWT_ALGORITHM", "HS256"),
        "expiration_hours": int(os.getenv("JWT_EXPIRATION_HOURS", "24")),
    }
# ===================== Auth Helpers =====================
def create_token(username: str) -> str:
    """Create a signed JWT for ``username``.

    The token carries three claims: ``sub`` (the username), ``iat`` (issue
    time, UTC) and ``exp`` (``iat`` plus the configured lifetime in hours).
    Secret, algorithm and lifetime come from ``get_jwt_config()``.

    Args:
        username: Value stored in the ``sub`` claim.

    Returns:
        The encoded token as produced by ``jwt.encode``.
    """
    config = get_jwt_config()
    # Read the clock once so that exp - iat is exactly the configured lifetime.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": username,
        "exp": issued_at + timedelta(hours=config["expiration_hours"]),
        "iat": issued_at,
    }
    return jwt.encode(payload, config["secret"], algorithm=config["algorithm"])
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Verify a bearer JWT and return the username stored in its ``sub`` claim.

    Args:
        credentials: Bearer credentials supplied through ``Depends(security)``.

    Returns:
        The value of the token's ``sub`` claim.

    Raises:
        HTTPException: 401 "Token expired" when the token's ``exp`` has passed;
            401 "Invalid token" when the token cannot be decoded/validated or
            does not contain a ``sub`` claim.
    """
    config = get_jwt_config()
    try:
        payload = jwt.decode(
            credentials.credentials,
            config["secret"],
            algorithms=[config["algorithm"]],
        )
    except jwt.ExpiredSignatureError as exc:
        raise HTTPException(status_code=401, detail="Token expired") from exc
    except jwt.InvalidTokenError as exc:
        raise HTTPException(status_code=401, detail="Invalid token") from exc

    # A correctly signed token without "sub" is still unusable here; answer
    # with 401 instead of letting a KeyError surface as a server error.
    if "sub" not in payload:
        raise HTTPException(status_code=401, detail="Invalid token")
    return payload["sub"]
# ===================== Utility Functions =====================
def truncate_string(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """Shorten ``text`` so the result is at most ``max_length`` characters.

    Args:
        text: String to shorten.
        max_length: Maximum length of the returned string. Values below zero
            are treated as zero.
        suffix: Marker appended when ``text`` is cut; it counts towards
            ``max_length``.

    Returns:
        ``text`` unchanged when it already fits. Otherwise the leading part of
        ``text`` followed by ``suffix``, or a plain cut of ``text`` without the
        suffix when ``max_length`` is too small to hold the suffix.
    """
    if len(text) <= max_length:
        return text
    if max_length < len(suffix):
        # No room for the suffix: a negative slice bound would cut from the
        # wrong end and return more than max_length characters.
        return text[:max(max_length, 0)]
    return text[:max_length - len(suffix)] + suffix
def format_file_size(size_bytes: int) -> str:
    """Format a byte count as a human-readable string using 1024-based units.

    Args:
        size_bytes: Size in bytes. Negative values (e.g. size differences)
            are scaled by magnitude and keep their sign.

    Returns:
        The value with two decimals and a unit from B, KB, MB, GB, TB, PB,
        e.g. "1.50 KB". Anything of 1024 TB or more is expressed in PB.
    """
    size = float(size_bytes)  # work on a copy; the parameter stays untouched
    for unit in ('B', 'KB', 'MB', 'GB', 'TB'):
        # Compare on magnitude so negative sizes pick the right unit too.
        if abs(size) < 1024.0:
            return f"{size:.2f} {unit}"
        size /= 1024.0
    return f"{size:.2f} PB"
def is_safe_path(path: str, base_path: str) -> bool:
    """Check that ``path`` stays inside ``base_path`` (no directory traversal).

    ``path`` is joined onto ``base_path`` and both are made absolute and
    normalised. The check is purely lexical: symbolic links are not resolved.

    Args:
        path: Path to validate, interpreted relative to ``base_path``. An
            absolute ``path`` replaces the base when joined and is therefore
            only accepted if it already lies under ``base_path``.
        base_path: Directory that must contain the result.

    Returns:
        True when the resolved path is ``base_path`` itself or lies below it,
        False otherwise.
    """
    # Resolve to absolute paths
    base = os.path.abspath(base_path)
    target = os.path.abspath(os.path.join(base, path))
    # Compare whole path components. A plain string prefix test would accept
    # sibling directories such as "/srv/data_evil" for base "/srv/data".
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # Raised when the two paths cannot share a prefix (e.g. different
        # drives on Windows) - the target is outside the base then.
        return False
def get_current_timestamp() -> str:
    """
    Return the current UTC time as an ISO 8601 string with a Z suffix

    The value is datetime.isoformat() output with the "+00:00" offset
    rewritten to "Z", so the fractional part has microsecond precision
    (and is left out by isoformat() when the microseconds are zero).
    Returns: e.g. "2025-01-10T12:00:00.123456Z"
    """
    utc_now = datetime.now(timezone.utc)
    iso_text = utc_now.isoformat()
    return iso_text.replace('+00:00', 'Z')
def normalize_timestamp(timestamp: Optional[str]) -> str:
    """
    Bring a timestamp string into one canonical spelling for comparison

    Accepted spellings, all of which map to "2025-01-10T12:00:00Z":
    - "2025-01-10T12:00:00Z"
    - "2025-01-10T12:00:00.000Z"
    - "2025-01-10T12:00:00+00:00"
    - "2025-01-10 12:00:00+00:00"

    None or an empty string yields "". The fractional part is only dropped
    when the (rewritten) value ends in "Z"; other strings keep it.
    """
    if not timestamp:
        return ""

    # Space separator -> "T", explicit UTC offset -> "Z"
    canonical = timestamp.replace(' ', 'T').replace('+00:00', 'Z')

    # Drop everything from the first "." onward, then restore the "Z"
    has_fraction = '.' in canonical
    if has_fraction and canonical.endswith('Z'):
        whole_seconds, _, _ = canonical.partition('.')
        return whole_seconds + 'Z'
    return canonical
def timestamps_equal(ts1: Optional[str], ts2: Optional[str]) -> bool:
    """
    Tell whether two timestamp strings are the same after both have been
    passed through normalize_timestamp()
    """
    first = normalize_timestamp(ts1)
    second = normalize_timestamp(ts2)
    return first == second