import asyncio
import os
from collections import defaultdict
from datetime import datetime, timedelta
from urllib.parse import urlparse

from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response
from starlette.types import ASGIApp


class LocalhostCORSMiddleware(CORSMiddleware):
    """
    Custom CORS middleware that allows any request from localhost/127.0.0.1
    domains, while using standard CORS rules for other origins.

    The allowed origins are read from the comma-separated
    ``PERMITTED_CORS_ORIGINS`` environment variable. The localhost shortcut
    only applies when no explicit origins and no origin regex are configured.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Configure the parent CORS middleware from ``PERMITTED_CORS_ORIGINS``.

        Args:
            app: The ASGI application to wrap.
        """
        allow_origins_str = os.getenv('PERMITTED_CORS_ORIGINS')
        allow_origins: tuple[str, ...]
        if allow_origins_str:
            stripped = (origin.strip() for origin in allow_origins_str.split(','))
            # Drop blank entries (e.g. from a trailing comma): an empty origin
            # matches nothing useful, yet would make the tuple truthy and
            # silently disable the localhost fallback below.
            allow_origins = tuple(origin for origin in stripped if origin)
        else:
            allow_origins = ()
        super().__init__(
            app,
            allow_origins=allow_origins,
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )

    def is_allowed_origin(self, origin: str) -> bool:
        """Return whether ``origin`` may make cross-origin requests.

        Args:
            origin: The value of the request's ``Origin`` header.

        Returns:
            True for any localhost/127.0.0.1 origin when no explicit origins
            or origin regex are configured; otherwise the parent class's
            decision.
        """
        if origin and not self.allow_origins and not self.allow_origin_regex:
            parsed = urlparse(origin)
            hostname = parsed.hostname or ''

            # Allow any localhost/127.0.0.1 origin regardless of port
            if hostname in ('localhost', '127.0.0.1'):
                return True

        # For missing origin or other origins, use the parent class's logic
        result: bool = super().is_allowed_origin(origin)
        return result


class CacheControlMiddleware(BaseHTTPMiddleware):
    """
    Middleware that sets caching headers on every response: paths under
    ``/assets`` are cached aggressively, all other routes have caching
    disabled.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Run the downstream handler and attach caching headers.

        Args:
            request: The incoming request.
            call_next: Callable that produces the downstream response.

        Returns:
            The downstream response with ``Cache-Control`` set (plus
            ``Pragma`` and ``Expires`` for non-asset routes).
        """
        response = await call_next(request)
        if request.url.path.startswith('/assets'):
            # The content of the assets directory has fingerprinted file names
            # so we cache aggressively (2592000 seconds = 30 days).
            response.headers['Cache-Control'] = 'public, max-age=2592000, immutable'
        else:
            response.headers['Cache-Control'] = (
                'no-cache, no-store, must-revalidate, max-age=0'
            )
            response.headers['Pragma'] = 'no-cache'
            response.headers['Expires'] = '0'
        return response


class InMemoryRateLimiter:
    """
    Sliding-window rate limiter that keeps per-client request timestamps in
    process memory.

    Within a window of ``seconds``:
      * up to ``requests`` requests pass immediately;
      * up to ``2 * requests`` requests are delayed by ``sleep_seconds`` and
        then allowed (or rejected when ``sleep_seconds`` is not positive);
      * anything beyond ``2 * requests`` is rejected.

    Every request, including rejected ones, is recorded in the window.
    """

    # Bucket shared by all requests whose client address is not available.
    _UNKNOWN_CLIENT_KEY = 'unknown'

    history: dict[str, list[datetime]]
    requests: int
    seconds: int
    sleep_seconds: int

    def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
        """Create a limiter.

        Args:
            requests: Number of requests allowed per window without delay.
            seconds: Length of the sliding window, in seconds.
            sleep_seconds: Delay applied to requests over the soft limit;
                a non-positive value rejects them instead.
        """
        self.requests = requests
        self.seconds = seconds
        self.sleep_seconds = sleep_seconds
        self.history = defaultdict(list)

    def _clean_old_requests(self, key: str) -> None:
        """Drop timestamps for ``key`` that fall outside the current window."""
        now = datetime.now()
        cutoff = now - timedelta(seconds=self.seconds)
        self.history[key] = [ts for ts in self.history[key] if ts > cutoff]

    async def __call__(self, request: Request) -> bool:
        """Record ``request`` and decide whether it may proceed.

        Args:
            request: The incoming request; its client host is the bucket key.

        Returns:
            True if the request is allowed (possibly after sleeping),
            False if it should be rejected.
        """
        # ``request.client`` is None when the server cannot supply a peer
        # address; fall back to a shared bucket instead of raising
        # AttributeError.
        client = request.client
        key = client.host if client is not None else self._UNKNOWN_CLIENT_KEY
        now = datetime.now()

        self._clean_old_requests(key)
        self.history[key].append(now)

        count = len(self.history[key])
        if count > self.requests * 2:
            # Hard limit exceeded: reject outright.
            return False
        if count > self.requests:
            # Soft limit exceeded: slow the client down rather than reject,
            # unless delaying is disabled.
            if self.sleep_seconds > 0:
                await asyncio.sleep(self.sleep_seconds)
                return True
            return False

        return True


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware that answers with HTTP 429 when the rate limiter rejects a request."""

    def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
        """Wrap ``app`` and consult ``rate_limiter`` on rate-limited requests.

        Args:
            app: The ASGI application to wrap.
            rate_limiter: Awaitable callable returning True when a request
                may proceed.
        """
        super().__init__(app)
        self.rate_limiter = rate_limiter

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Forward the request, or return 429 if the limiter rejects it.

        Args:
            request: The incoming request.
            call_next: Callable that produces the downstream response.

        Returns:
            The downstream response, or a JSON 429 response with a
            ``Retry-After`` header when the request is rejected.
        """
        if not self.is_rate_limited_request(request):
            return await call_next(request)
        ok = await self.rate_limiter(request)
        if not ok:
            return JSONResponse(
                status_code=429,
                content={'message': 'Too many requests'},
                headers={'Retry-After': '1'},
            )
        return await call_next(request)

    def is_rate_limited_request(self, request: StarletteRequest) -> bool:
        """Return whether ``request`` is subject to rate limiting.

        Requests for paths under ``/assets`` are exempt.
        """
        if request.url.path.startswith('/assets'):
            return False
        # Put other non-rate-limited checks here
        return True