"""FastAPI app that gates a Hugging Face dataset's file listing behind a password.

Serves the static frontend from ./static and exposes two API routes:
``POST /api/verify-password`` (returns the sorted dataset file list when the
password matches) and ``GET /api/download/{item_name}`` (returns one file
from the dataset).
"""

import hmac
import math
import os
import time

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import HfApi
from huggingface_hub.utils import EntryNotFoundError
from pydantic import BaseModel

load_dotenv()

api = HfApi(token=os.getenv("HF_TOKEN"))
PASSWORD = os.getenv("PASSWORD")

app = FastAPI()

# Hostname label derived from the Space id ("owner/name" -> "owner-name").
# Browsers send the Host header in lowercase and TrustedHostMiddleware compares
# it literally, so the derived label is lowercased as well.
# NOTE(review): HF may normalise other characters in the subdomain too
# (e.g. "_" or ".") -- confirm against the deployed URL.
repo_url = os.environ["HF_SPACE_ID"].replace("/", "-").lower()

app.add_middleware(
    TrustedHostMiddleware,
    # 127.0.0.1 is included so local runs work when opened by IP, not only "localhost".
    allowed_hosts=["localhost", "127.0.0.1", f"{repo_url}.hf.space"],
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:7860", f"https://{repo_url}.hf.space"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class RateLimiter:
    """In-memory sliding-window rate limiter keyed by client IP.

    State lives in this process only: it is not shared between workers and is
    lost on restart.
    """

    def __init__(self, max_attempts: int = 5, window_seconds: int = 300):
        self.max_attempts = max_attempts
        self.window_seconds = window_seconds
        # ip -> time.monotonic() stamps of accepted requests, oldest first.
        self.attempts: dict = {}
        self._last_sweep = time.monotonic()

    async def check_rate_limit(self, ip: str) -> bool:
        """Record a request from *ip*, or reject it if its window is full.

        Returns:
            True when the request is allowed (it has then been recorded).

        Raises:
            HTTPException: 429, with a ``Retry-After`` header, when *ip* has
                already made ``max_attempts`` accepted requests within the last
                ``window_seconds``. Rejected requests are not recorded.
        """
        # monotonic() is unaffected by wall-clock jumps (NTP sync, manual changes).
        now = time.monotonic()
        self._evict_stale(now)
        recent = [t for t in self.attempts.get(ip, ()) if now - t < self.window_seconds]
        if len(recent) >= self.max_attempts:
            self.attempts[ip] = recent
            retry_after = self._seconds_until_slot(recent, now)
            raise HTTPException(
                status_code=429,
                detail=f"Too many attempts. Try again in {retry_after} seconds",
                headers={"Retry-After": str(retry_after)},
            )
        recent.append(now)
        self.attempts[ip] = recent
        return True

    def _seconds_until_slot(self, recent: list, now: float) -> int:
        """Return whole seconds until enough of *recent* expires to admit one more request."""
        if self.max_attempts <= 0 or not recent:
            return self.window_seconds
        # A slot frees up once the max_attempts-th newest timestamp leaves the window.
        blocking = recent[-self.max_attempts]
        return max(1, math.ceil(self.window_seconds - (now - blocking)))

    def _evict_stale(self, now: float) -> None:
        """Drop IPs with no attempt inside the window, at most once per window.

        Without this, every IP ever seen would stay in ``attempts`` forever.
        """
        if now - self._last_sweep < self.window_seconds:
            return
        self._last_sweep = now
        stale = [
            ip
            for ip, stamps in self.attempts.items()
            if not stamps or now - stamps[-1] >= self.window_seconds
        ]
        for ip in stale:
            del self.attempts[ip]


# Budget for password attempts.
rate_limiter = RateLimiter()
# Separate budget for downloads, so fetching files does not use up the
# password-attempt budget and lock a legitimate user out of logging in.
download_rate_limiter = RateLimiter()


def _client_ip(request: Request) -> str:
    """Return the peer address used as the rate-limit key.

    ``request.client`` may be None (e.g. some ASGI servers or test clients).
    NOTE(review): behind a reverse proxy (as on HF Spaces) this may be the
    proxy's address, so all users could share one bucket -- confirm the
    proxy/forwarded-header setup of the deployment.
    """
    return request.client.host if request.client is not None else "unknown"


class PasswordCheck(BaseModel):
    """Request body for ``/api/verify-password``."""

    password: str


@app.post("/api/verify-password")
async def verify_password(password_check: PasswordCheck, request: Request):
    """Return the sorted file list of the HF_DATASET_ID dataset if the password matches.

    Raises:
        HTTPException: 429 when rate-limited; 401 when the password is wrong or
            no PASSWORD is configured (fails closed).
    """
    await rate_limiter.check_rate_limit(_client_ip(request))
    # Constant-time comparison avoids leaking the password through response
    # timing. Comparing UTF-8 bytes also works for non-ASCII passwords.
    if PASSWORD and hmac.compare_digest(
        password_check.password.encode("utf-8"), PASSWORD.encode("utf-8")
    ):
        # list_repo_files performs blocking network I/O; keep it off the event loop.
        items = await run_in_threadpool(
            api.list_repo_files,
            repo_id=os.environ["HF_DATASET_ID"],
            repo_type="dataset",
        )
        return sorted(items)
    raise HTTPException(status_code=401, detail="Invalid password")


# NOTE(review): this route performs no password check, so anyone who knows (or
# guesses) a filename can download it; the password only guards the listing.
# Requiring a credential here needs a matching change in the static frontend.
@app.get("/api/download/{item_name:path}")
async def download_item(item_name: str, request: Request):
    """Download *item_name* from the HF_DATASET_ID dataset and return it as a file.

    The ``:path`` convertor lets nested repo paths (containing "/"), as
    returned by ``list_repo_files``, reach this handler.

    Raises:
        HTTPException: 429 when rate-limited; 404 when the file does not exist
            in the dataset.
    """
    await download_rate_limiter.check_rate_limit(_client_ip(request))
    try:
        # hf_hub_download performs blocking network/disk I/O; keep it off the event loop.
        filepath = await run_in_threadpool(
            api.hf_hub_download,
            repo_id=os.environ["HF_DATASET_ID"],
            filename=item_name,
            repo_type="dataset",
        )
    except EntryNotFoundError:
        raise HTTPException(status_code=404, detail="Item not found") from None
    # Offer only the final path component as the download name.
    return FileResponse(filepath, filename=os.path.basename(item_name))


# Mounted last so the /api routes above take precedence over the catch-all.
app.mount("/", StaticFiles(directory="static", html=True), name="static")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)