Spaces:
Build error
Build error
import asyncio
import os
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from urllib.parse import urlparse

from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response
from starlette.types import ASGIApp
class LocalhostCORSMiddleware(CORSMiddleware):
    """
    CORS middleware that falls back to trusting localhost origins.

    Allowed origins come from the comma-separated ``PERMITTED_CORS_ORIGINS``
    environment variable. While that variable yields no origins (and no origin
    regex is configured), any origin whose host is ``localhost`` or
    ``127.0.0.1`` is accepted regardless of port. Once explicit origins are
    configured, only the parent class's standard CORS rules apply.
    """

    # Hosts trusted by the fallback rule in is_allowed_origin.
    _LOCALHOST_NAMES = frozenset({'localhost', '127.0.0.1'})

    def __init__(self, app: ASGIApp) -> None:
        """
        Args:
            app: The ASGI application to wrap.
        """
        super().__init__(
            app,
            allow_origins=self._origins_from_env(),
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )

    @staticmethod
    def _origins_from_env() -> tuple[str, ...]:
        """
        Parse ``PERMITTED_CORS_ORIGINS`` into a tuple of non-empty origins.

        Blank entries (from a trailing comma, ``a,,b``, or a whitespace-only
        value) are dropped. Otherwise they would count as configured origins
        and disable the localhost fallback in is_allowed_origin.
        """
        raw = os.getenv('PERMITTED_CORS_ORIGINS') or ''
        stripped = (part.strip() for part in raw.split(','))
        return tuple(origin for origin in stripped if origin)

    def is_allowed_origin(self, origin: str) -> bool:
        """
        Return whether ``origin`` may make cross-origin requests.

        Localhost origins are accepted only while no explicit origins and no
        origin regex are configured. Everything else, including a missing
        origin, is decided by the parent class.
        """
        if origin and not self.allow_origins and not self.allow_origin_regex:
            try:
                hostname = urlparse(origin).hostname or ''
            except ValueError:
                # Malformed Origin header (e.g. unbalanced IPv6 bracket):
                # not a localhost origin, so let the parent class decide
                # instead of raising inside the middleware.
                hostname = ''
            # Allow any localhost/127.0.0.1 origin regardless of port
            if hostname in self._LOCALHOST_NAMES:
                return True
        result: bool = super().is_allowed_origin(origin)
        return result
class CacheControlMiddleware(BaseHTTPMiddleware):
    """
    Stamp HTTP caching headers onto every response.

    Paths beginning with ``/assets`` are marked publicly cacheable for 30 days
    (2592000 s) and immutable. Every other response is marked non-cacheable.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        response = await call_next(request)
        if not request.url.path.startswith('/assets'):
            # Dynamic content: forbid caching, including by legacy caches that
            # only honour Pragma/Expires.
            no_cache_headers = {
                'Cache-Control': 'no-cache, no-store, must-revalidate, max-age=0',
                'Pragma': 'no-cache',
                'Expires': '0',
            }
            for header_name, header_value in no_cache_headers.items():
                response.headers[header_name] = header_value
            return response
        # Asset file names are fingerprinted, so content under a given URL never
        # changes and can be cached aggressively.
        response.headers['Cache-Control'] = 'public, max-age=2592000, immutable'
        return response
class InMemoryRateLimiter:
    """
    Sliding-window rate limiter keyed by client host, held in process memory.

    Each call records a timestamp for the requesting client and counts the
    client's timestamps within the last ``seconds`` seconds (including this
    one):

    * up to ``requests`` hits: allowed immediately;
    * up to ``2 * requests`` hits: delayed by ``sleep_seconds`` and then
      allowed, or rejected outright when ``sleep_seconds`` is 0;
    * more than ``2 * requests`` hits: rejected.

    State is a plain dict on this instance, so separate processes each keep
    their own counts.
    """

    history: dict[str, list[datetime]]
    requests: int
    seconds: int
    sleep_seconds: int

    def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
        """
        Args:
            requests: Requests per window allowed without any delay.
            seconds: Length of the sliding window, in seconds.
            sleep_seconds: Delay applied in the soft over-limit band; 0 means
                reject those requests instead of delaying them.
        """
        self.requests = requests
        self.seconds = seconds
        self.sleep_seconds = sleep_seconds
        self.history = defaultdict(list)
        # When stale client keys were last purged (see _sweep_stale_keys).
        self._last_sweep = self._now()

    @staticmethod
    def _now() -> datetime:
        # Timezone-aware UTC: naive local time jumps back an hour at DST
        # fall-back, which would make old timestamps look recent.
        return datetime.now(timezone.utc)

    def _clean_old_requests(self, key: str) -> None:
        """Drop timestamps for ``key`` that fall outside the current window."""
        cutoff = self._now() - timedelta(seconds=self.seconds)
        self.history[key] = [ts for ts in self.history[key] if ts > cutoff]

    def _sweep_stale_keys(self, now: datetime) -> None:
        """
        Forget clients with no timestamps inside the current window.

        Runs at most once per window so its cost is amortized. Without it,
        ``history`` keeps one entry for every client host ever seen.
        """
        window = timedelta(seconds=self.seconds)
        if now - self._last_sweep < window:
            return
        self._last_sweep = now
        cutoff = now - window
        stale_keys = [
            key
            for key, stamps in self.history.items()
            if not any(ts > cutoff for ts in stamps)
        ]
        for key in stale_keys:
            del self.history[key]

    async def __call__(self, request: Request) -> bool:
        """
        Record ``request`` and decide whether it may proceed.

        Returns:
            True if the request is allowed (possibly after sleeping),
            False if the caller should reject it.
        """
        # Starlette types request.client as optional; bucket clientless
        # requests together instead of raising AttributeError.
        client = request.client
        key = client.host if client is not None else 'unknown'
        now = self._now()
        self._sweep_stale_keys(now)
        self._clean_old_requests(key)
        self.history[key].append(now)
        hits = len(self.history[key])
        if hits > self.requests * 2:
            return False
        if hits > self.requests:
            # Soft over-limit band: slow the client down rather than reject,
            # unless delaying is disabled.
            if self.sleep_seconds > 0:
                await asyncio.sleep(self.sleep_seconds)
                return True
            return False
        return True
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Throttle or reject requests using an InMemoryRateLimiter.

    Requests exempted by is_rate_limited_request bypass the limiter entirely.
    Rejected requests receive a 429 JSON response with a Retry-After header.
    """

    def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
        """
        Args:
            app: The ASGI application to wrap.
            rate_limiter: Limiter consulted for every non-exempt request.
        """
        super().__init__(app)
        self.rate_limiter = rate_limiter

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        # Short-circuit: exempt requests never touch the limiter's history.
        if self.is_rate_limited_request(request) and not await self.rate_limiter(
            request
        ):
            return JSONResponse(
                status_code=429,
                content={'message': 'Too many requests'},
                headers={'Retry-After': self._retry_after()},
            )
        return await call_next(request)

    def _retry_after(self) -> str:
        """Return the Retry-After value: one limiter window, at least 1 second."""
        # After one full window every recorded timestamp for the client has
        # expired, so that is how long a rejected client must wait. A fixed
        # '1' under-reports this for limiters with longer windows.
        window = getattr(self.rate_limiter, 'seconds', 1)
        return str(max(1, int(window)))

    def is_rate_limited_request(self, request: StarletteRequest) -> bool:
        """Return False for requests exempt from rate limiting (static assets)."""
        if request.url.path.startswith('/assets'):
            return False
        # Put other non-rate-limited checks here
        return True