File size: 2,757 Bytes
7e1e2b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from pydantic import BaseModel
import os
from huggingface_hub import HfApi
import time

from dotenv import load_dotenv

# Load environment variables from a local .env file (if one exists) before
# any of the os.getenv / os.environ lookups below.
load_dotenv()

# Hugging Face Hub client; the token is read from HF_TOKEN (None when unset).
api = HfApi(token=os.getenv("HF_TOKEN"))

# Shared secret checked by /api/verify-password. None when the variable is
# unset, in which case no submitted password can match.
PASSWORD = os.getenv("PASSWORD")

app = FastAPI()

# Host label derived from HF_SPACE_ID by replacing "/" with "-".
# Raises KeyError at import time if HF_SPACE_ID is not set.
repo_url = os.environ["HF_SPACE_ID"].replace("/", "-")

# Only accept requests whose Host header is localhost or the Space's own
# "<repo_url>.hf.space" hostname.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["localhost", f"{repo_url}.hf.space"]  # Replace with your actual HF space URL
)


# Allow cross-origin requests (with credentials, any method, any header) only
# from the local dev origin on port 7860 and from the Space's https origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:7860", f"https://{repo_url}.hf.space"],  # Replace with your actual HF space URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Rate limiting
class RateLimiter:
    """In-memory sliding-window rate limiter keyed by client IP.

    Each IP may make at most ``max_attempts`` calls within any
    ``window_seconds`` interval. State lives in this process only.

    Timestamps come from ``time.monotonic()`` so that wall-clock adjustments
    cannot shorten or lengthen a window, and entries for IPs that have gone
    idle are evicted periodically so ``attempts`` does not grow without bound.
    """

    def __init__(self, max_attempts: int = 5, window_seconds: int = 300):
        self.max_attempts = max_attempts
        self.window_seconds = window_seconds
        # ip -> list of monotonic timestamps (oldest first) of counted attempts.
        self.attempts = {}
        # Monotonic time of the last sweep that removed idle IPs.
        self._last_sweep = time.monotonic()

    def _evict_idle(self, now: float) -> None:
        """Remove every IP whose newest attempt is already outside the window."""
        idle = [
            ip
            for ip, stamps in self.attempts.items()
            if not stamps or now - stamps[-1] >= self.window_seconds
        ]
        for ip in idle:
            del self.attempts[ip]
        self._last_sweep = now

    async def check_rate_limit(self, ip: str) -> bool:
        """Record an attempt for ``ip`` or reject it if the limit is reached.

        Args:
            ip: Identifier of the caller (the client host).

        Returns:
            True when the attempt is allowed (and has been recorded).

        Raises:
            HTTPException: 429 when ``ip`` already has ``max_attempts``
                attempts inside the current window. Rejected attempts are
                not recorded.
        """
        now = time.monotonic()

        # Sweep at most once per window so the cost stays amortised.
        if now - self._last_sweep >= self.window_seconds:
            self._evict_idle(now)

        recent = [
            t for t in self.attempts.get(ip, ()) if now - t < self.window_seconds
        ]
        self.attempts[ip] = recent
        if len(recent) >= self.max_attempts:
            raise HTTPException(
                status_code=429,
                detail=f"Too many attempts. Try again in {self.window_seconds} seconds"
            )
        recent.append(now)
        return True

# Single process-wide limiter built with the class defaults
# (5 attempts per 300 seconds); the API endpoints below call it per request.
rate_limiter = RateLimiter()

# Request body model for POST /api/verify-password: one string field,
# `password`, holding the password submitted by the client.
class PasswordCheck(BaseModel):
    password: str

@app.post("/api/verify-password")
async def verify_password(password_check: PasswordCheck, request: Request):
    """Check the submitted password and, on success, list the dataset files.

    Args:
        password_check: Request body carrying the submitted password.
        request: Incoming request; its client host is the rate-limit key.

    Returns:
        The sorted list of file paths in the dataset repo named by the
        HF_DATASET_ID environment variable.

    Raises:
        HTTPException: 429 from the rate limiter, or 401 when the password
            does not match (including when no PASSWORD is configured).
    """
    import hmac

    from fastapi.concurrency import run_in_threadpool

    # request.client can be None (e.g. some test/ASGI transports); fall back
    # to a shared bucket instead of failing with AttributeError.
    client_ip = request.client.host if request.client else "unknown"
    await rate_limiter.check_rate_limit(client_ip)

    # Constant-time comparison so response timing does not leak how much of
    # the password matched. Bytes are used because compare_digest only accepts
    # ASCII-only str. An unset PASSWORD fails closed.
    password_ok = PASSWORD is not None and hmac.compare_digest(
        password_check.password.encode("utf-8"),
        PASSWORD.encode("utf-8"),
    )
    if not password_ok:
        raise HTTPException(status_code=401, detail="Invalid password")

    # list_repo_files is a blocking network call; run it in the threadpool so
    # the event loop keeps serving other requests.
    items = await run_in_threadpool(
        api.list_repo_files,
        repo_id=os.environ["HF_DATASET_ID"],
        repo_type="dataset",
    )
    return sorted(items)

@app.get("/api/download/{item_name}")
async def download_item(item_name: str, request: Request):
    """Download one file from the dataset repo and send it to the client.

    Args:
        item_name: Name of the file inside the dataset repo named by the
            HF_DATASET_ID environment variable.
        request: Incoming request; its client host is the rate-limit key.

    Returns:
        A FileResponse serving the locally downloaded file under
        ``item_name``.

    Raises:
        HTTPException: 429 from the rate limiter.
    """
    from fastapi.concurrency import run_in_threadpool

    # NOTE(review): this endpoint performs no password check of its own; only
    # the rate limiter guards it. Confirm that this is intended.

    # request.client can be None (e.g. some test/ASGI transports); fall back
    # to a shared bucket instead of failing with AttributeError.
    client_ip = request.client.host if request.client else "unknown"
    await rate_limiter.check_rate_limit(client_ip)

    # hf_hub_download blocks on network and disk I/O; run it in the threadpool
    # so the event loop is not stalled for the duration of the download.
    filepath = await run_in_threadpool(
        api.hf_hub_download,
        repo_id=os.environ["HF_DATASET_ID"],
        filename=item_name,
        repo_type="dataset",
    )
    return FileResponse(filepath, filename=item_name)

# Serve the local "static" directory at the site root (html=True).
# NOTE(review): presumably mounted after the /api routes so those are matched
# before this catch-all mount — confirm before reordering.
app.mount("/", StaticFiles(directory="static", html=True), name="static")

if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces,
    # port 7860 (the same port as the localhost CORS origin above).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)