Spaces:
Sleeping
Sleeping
File size: 2,757 Bytes
7e1e2b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import math
import os
import secrets
import time

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import HfApi
from pydantic import BaseModel
# Load variables from a local .env file (if one exists) into os.environ, so the
# getenv/environ lookups below also work when running outside the Space.
load_dotenv()
api = HfApi(token=os.getenv("HF_TOKEN"))
# None when unset; verify_password then rejects every submitted password.
PASSWORD = os.getenv("PASSWORD")
app = FastAPI()

# Turn the "owner/Space_Name" id into its *.hf.space subdomain ("owner-space-name").
# Browsers send the Host header in lowercase and TrustedHostMiddleware compares it
# exactly, so an un-normalised id (capitals, underscores) would never match and
# every request to the deployed Space would be rejected.
repo_url = (
    os.environ["HF_SPACE_ID"]
    .lower()
    .replace("/", "-")
    .replace("_", "-")
    .replace(".", "-")
)

app.add_middleware(
    TrustedHostMiddleware,
    # Only accept requests addressed to the local dev server or this Space's URL.
    allowed_hosts=["localhost", f"{repo_url}.hf.space"],
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:7860", f"https://{repo_url}.hf.space"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Rate limiting
class RateLimiter:
    """Sliding-window limiter keyed by client, kept in this process's memory.

    Each key may make at most ``max_attempts`` calls within any
    ``window_seconds``-long window. State lives in the ``attempts`` dict on the
    instance, so it is per-process and is lost on restart.
    """

    def __init__(self, max_attempts: int = 5, window_seconds: int = 300):
        self.max_attempts = max_attempts
        self.window_seconds = window_seconds
        # key -> ascending monotonic timestamps of attempts still inside the window
        self.attempts = {}
        # Monotonic time of the last full sweep of stale keys (see _sweep).
        self._last_sweep = time.monotonic()

    async def check_rate_limit(self, ip: str) -> bool:
        """Record an attempt for ``ip``, or reject it if the limit is reached.

        Returns:
            True when the attempt is allowed; the attempt is then recorded.

        Raises:
            HTTPException: status 429 with a ``Retry-After`` header when ``ip``
                already has ``max_attempts`` attempts inside the window. A
                rejected attempt is not recorded, so it does not extend the
                lockout.
        """
        # Monotonic clock: wall-clock jumps must not shrink or stretch the window.
        now = time.monotonic()
        self._sweep(now)
        recent = [t for t in self.attempts.get(ip, ()) if now - t < self.window_seconds]
        if len(recent) >= self.max_attempts:
            self.attempts[ip] = recent
            # The oldest attempt is the first to leave the window, freeing a slot.
            if recent:
                retry_after = max(1, math.ceil(self.window_seconds - (now - recent[0])))
            else:
                retry_after = max(1, math.ceil(self.window_seconds))
            raise HTTPException(
                status_code=429,
                detail=f"Too many attempts. Try again in {retry_after} seconds",
                headers={"Retry-After": str(retry_after)},
            )
        recent.append(now)
        self.attempts[ip] = recent
        return True

    def _sweep(self, now: float) -> None:
        """Forget keys whose newest attempt has already left the window.

        Runs at most once per window, so the per-request cost stays amortised.
        Without it, every distinct client ever seen would stay in ``attempts``
        forever, because a key is otherwise only pruned when it returns.
        """
        if now - self._last_sweep < self.window_seconds:
            return
        self._last_sweep = now
        self.attempts = {
            key: stamps
            for key, stamps in self.attempts.items()
            if stamps and now - stamps[-1] < self.window_seconds
        }
# Shared limiter for both API endpoints: 5 attempts per client per 5 minutes.
rate_limiter = RateLimiter(max_attempts=5, window_seconds=300)
class PasswordCheck(BaseModel):
    """JSON request body of ``POST /api/verify-password``: ``{"password": "..."}``."""

    password: str
@app.post("/api/verify-password")
async def verify_password(password_check: PasswordCheck, request: Request):
    """Check the submitted password and, if correct, list the dataset's files.

    Every call, successful or not, counts against the caller's rate limit.

    Returns:
        The sorted file paths of the ``HF_DATASET_ID`` dataset repository.

    Raises:
        HTTPException: 401 on a wrong password (always, when ``PASSWORD`` is
            unset); 429 from the rate limiter.
    """
    # request.client is None when the server does not know the peer address.
    # NOTE(review): behind a reverse proxy this may be the proxy's address, making
    # the limit shared by all users — confirm how the Space forwards client IPs.
    client_key = request.client.host if request.client else "unknown"
    await rate_limiter.check_rate_limit(client_key)
    # Constant-time comparison so response timing does not leak how much of the
    # guess matched; bytes avoid compare_digest's ASCII-only restriction on str.
    if PASSWORD is not None and secrets.compare_digest(
        password_check.password.encode("utf-8", "surrogatepass"),
        PASSWORD.encode("utf-8", "surrogatepass"),
    ):
        # list_repo_files does blocking network I/O; keep it off the event loop.
        items = await run_in_threadpool(
            api.list_repo_files,
            repo_id=os.environ["HF_DATASET_ID"],
            repo_type="dataset",
        )
        return sorted(items)
    raise HTTPException(status_code=401, detail="Invalid password")
@app.get("/api/download/{item_name}")
async def download_item(item_name: str, request: Request):
    """Send one file from the ``HF_DATASET_ID`` dataset to the client.

    The file is fetched with ``hf_hub_download``, which returns a local path,
    and is returned as an attachment named ``item_name``. Each call counts
    against the same per-client rate limit as password attempts.

    NOTE(review): unlike ``/api/verify-password``, this endpoint does not check
    the password, so anyone who knows a filename can download it — confirm that
    this is intended before relying on the password for access control.
    """
    # request.client is None when the server does not know the peer address.
    client_key = request.client.host if request.client else "unknown"
    await rate_limiter.check_rate_limit(client_key)
    # hf_hub_download does blocking network/disk I/O; running it in the thread
    # pool keeps one download from stalling every other request.
    filepath = await run_in_threadpool(
        api.hf_hub_download,
        repo_id=os.environ["HF_DATASET_ID"],
        filename=item_name,
        repo_type="dataset",
    )
    return FileResponse(filepath, filename=item_name)
# Serve the frontend from ./static; html=True serves index.html for directory
# paths. This mount matches every path, so it stays after the /api routes above,
# which are matched first.
app.mount(
    "/",
    StaticFiles(directory="static", html=True),
    name="static",
)
if __name__ == "__main__":
    # Direct-run entry point: listen on every interface on port 7860, the port
    # the CORS allow-list above expects for local development.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")