flare / utils.py
ciyidogan's picture
Update utils.py
a1c5ec0 verified
raw
history blame
4.22 kB
import os
import threading
import time
from datetime import datetime, timedelta, timezone

import jwt
# ===================== Rate Limiting =====================
class RateLimiter:
    """Simple in-memory sliding-window rate limiter.

    For every key, the timestamps of allowed requests are kept in memory and
    entries older than the window are discarded the next time that key is
    checked. All access to the shared dict happens under ``self.lock``.

    Args:
        clock: Zero-argument callable returning the current time in seconds.
            Defaults to ``time.monotonic`` so that wall-clock adjustments
            (NTP corrections, manual changes) cannot shorten or stretch a
            window. Mainly injectable for tests.
    """

    def __init__(self, clock=time.monotonic):
        self.requests = {}  # {key: [(timestamp, count)]}, timestamps from clock
        self.lock = threading.Lock()
        self._clock = clock

    def is_allowed(self, key: str, max_requests: int, window_seconds: int) -> bool:
        """Check whether a request for ``key`` is allowed, and count it if so.

        Args:
            key: Identifier whose requests are being limited.
            max_requests: Maximum number of allowed requests per window.
            window_seconds: Length of the sliding window in seconds.

        Returns:
            True if the request is allowed (it is then recorded); False if the
            limit is already reached (rejected requests are not recorded).
        """
        with self.lock:
            now = self._clock()
            cutoff = now - window_seconds

            # Keep only entries that are still inside the window.
            recent = [
                (ts, count)
                for ts, count in self.requests.get(key, ())
                if ts > cutoff
            ]

            if sum(count for _, count in recent) >= max_requests:
                if recent:
                    self.requests[key] = recent
                else:
                    # Nothing left to track; don't keep an empty bucket around.
                    self.requests.pop(key, None)
                return False

            recent.append((now, 1))
            self.requests[key] = recent
            return True

    def reset(self, key: str):
        """Forget all recorded requests for ``key`` (no-op if the key is unknown)."""
        with self.lock:
            self.requests.pop(key, None)
# Create global (module-level) rate limiter instance.
# `threading` is imported at the top of the module, before RateLimiter needs it.
rate_limiter = RateLimiter()
# ===================== JWT Config =====================
def get_jwt_config():
    """Return the JWT signing settings for the current environment.

    When ``SPACE_ID`` is set (HuggingFace Space), values are read directly
    from the process environment. Otherwise a local ``.env`` file is loaded
    first via ``python-dotenv`` (on-premise mode).

    A missing *or empty* ``JWT_SECRET`` falls back to a hard-coded default in
    both modes and logs a warning. That default is a constant in this source
    file, so tokens signed with it are only as secret as the code itself;
    always set ``JWT_SECRET`` in production.

    Returns:
        dict with ``secret`` (str), ``algorithm`` (str, from ``JWT_ALGORITHM``,
        default ``"HS256"``) and ``expiration_hours`` (int, from
        ``JWT_EXPIRATION_HOURS``, default 24).

    Raises:
        ValueError: if ``JWT_EXPIRATION_HOURS`` is not an integer string.
    """
    fallback_secret = "flare-admin-secret-key-change-in-production"

    if not os.getenv("SPACE_ID"):
        # On-premise mode - populate the environment from a .env file
        from dotenv import load_dotenv
        load_dotenv()

    jwt_secret = os.getenv("JWT_SECRET")
    if not jwt_secret:
        # Treat an empty value like a missing one: an empty HMAC key would
        # make every token trivially forgeable.
        log("⚠️ WARNING: JWT_SECRET not found in environment, using fallback")
        jwt_secret = fallback_secret

    return {
        "secret": jwt_secret,
        "algorithm": os.getenv("JWT_ALGORITHM", "HS256"),
        "expiration_hours": int(os.getenv("JWT_EXPIRATION_HOURS", "24")),
    }
# ===================== Auth Helpers =====================
def create_token(username: str) -> str:
    """Create a signed JWT for ``username``.

    The payload carries ``sub`` (the username), ``exp`` (issue time plus
    ``expiration_hours`` from ``get_jwt_config``) and ``iat`` (issue time),
    both timestamps in UTC. The token is signed with the configured secret
    and algorithm.

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    config = get_jwt_config()
    # Read the clock once so exp - iat is exactly the configured lifetime.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": username,
        "exp": issued_at + timedelta(hours=config["expiration_hours"]),
        "iat": issued_at,
    }
    return jwt.encode(payload, config["secret"], algorithm=config["algorithm"])
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Validate the bearer token in ``credentials`` and return its username.

    The token must verify against the secret and algorithm from
    ``get_jwt_config``, must not be expired, and must carry both ``exp`` and
    ``sub`` claims.

    Returns:
        The ``sub`` claim (the username that ``create_token`` put there).

    Raises:
        HTTPException: 401 "Token expired" for an expired token; 401
            "Invalid token" for any other decoding/validation failure,
            including a missing ``exp``/``sub`` claim or a non-string ``sub``.
    """
    token = credentials.credentials
    config = get_jwt_config()
    try:
        payload = jwt.decode(
            token,
            config["secret"],
            algorithms=[config["algorithm"]],
            # Without this, a validly signed token lacking "exp" would never
            # expire, and one lacking "sub" would fail below with KeyError
            # (surfacing as HTTP 500 instead of 401).
            options={"require": ["exp", "sub"]},
        )
    except jwt.ExpiredSignatureError as exc:
        raise HTTPException(status_code=401, detail="Token expired") from exc
    except jwt.InvalidTokenError as exc:
        raise HTTPException(status_code=401, detail="Invalid token") from exc

    username = payload.get("sub")
    # Reject a missing or non-string subject explicitly instead of relying
    # solely on the PyJWT option above.
    if not isinstance(username, str):
        raise HTTPException(status_code=401, detail="Invalid token")
    return username
# ===================== Utility Functions =====================
def truncate_string(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """Shorten ``text`` to at most ``max_length`` characters.

    Text that already fits is returned unchanged. Longer text is cut and
    ``suffix`` is appended so that the result is exactly ``max_length`` long.
    If ``max_length`` is too small to hold ``suffix``, the text is hard-cut
    without a suffix (a negative ``max_length`` is treated as 0).
    """
    if len(text) <= max_length:
        return text
    keep = max_length - len(suffix)
    if keep < 0:
        # A negative slice stop counts from the end and would return a string
        # *longer* than max_length (e.g. ("hello world", 2) -> "hello worl...").
        return text[:max(max_length, 0)]
    return text[:keep] + suffix
def format_file_size(size_bytes: int) -> str:
    """Render a byte count as a human-readable size string.

    The value is divided by 1024 until it drops below 1024, stepping through
    B, KB, MB, GB and TB; anything still at or above 1024 TB is reported in
    PB. The number is always printed with two decimal places.
    """
    remaining = size_bytes
    for label in ("B", "KB", "MB", "GB", "TB"):
        if remaining < 1024.0:
            break
        remaining /= 1024.0
    else:
        # Exhausted every unit up to TB without fitting: report in petabytes.
        label = "PB"
    return f"{remaining:.2f} {label}"
def is_safe_path(path: str, base_path: str) -> bool:
    """Return True if ``path``, taken relative to ``base_path``, stays inside it.

    Both paths are made absolute and symlinks are resolved (``realpath``)
    before comparing whole path components, so ``..`` segments, absolute
    ``path`` values, sibling directories sharing a name prefix, and symlinks
    leading outside ``base_path`` are rejected. ``base_path`` itself counts as
    inside.
    """
    try:
        base = os.path.realpath(base_path)
        target = os.path.realpath(os.path.join(base, path))
        # Component-wise comparison: a plain startswith() would accept
        # "/data/app-secrets" as being under "/data/app".
        common = os.path.commonpath([base, target])
    except ValueError:
        # e.g. paths on different drives (Windows) or invalid characters.
        return False
    return os.path.normcase(common) == os.path.normcase(base)