File size: 6,167 Bytes
5a65482
8de92d1
 
 
 
 
5a65482
8de92d1
 
 
 
5a65482
 
8de92d1
 
 
 
 
 
 
 
 
 
 
 
5a65482
 
 
 
8de92d1
 
 
 
5a65482
8de92d1
 
 
 
5a65482
8de92d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27bba9c
 
 
 
5a65482
8de92d1
 
 
 
 
5a65482
8de92d1
 
 
5a65482
 
8de92d1
27bba9c
 
 
 
 
5a65482
8de92d1
 
 
 
 
 
 
 
 
5a65482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3da50d1
 
 
 
 
5a65482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8de92d1
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from base64 import b64encode
import sqlite3
import json
import os
from typing import Tuple, List, Optional
from functools import lru_cache
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import uvicorn
from fastapi.responses import FileResponse, RedirectResponse
import requests

app = FastAPI(title="WikiSpeedia API")

# Add CORS middleware
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- consider restricting allow_origins to the known frontend
# origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# OAuth client secret read from the environment; None when the variable is unset.
CLIENT_SECRET = os.getenv("HUGGINGFACE_CLIENT_SECRET")
# True only when VITE_ENV is exactly "production".
IS_PROD = os.getenv("VITE_ENV") == "production"
# Never write the secret itself to stdout/logs -- only report whether it is set.
print("CLIENT_SECRET:", "<set>" if CLIENT_SECRET else "<not set>")
print("IS_PROD:", IS_PROD)
class ArticleResponse(BaseModel):
    """Response body describing one article and its links."""

    # Title of the article as stored in the database.
    title: str
    # Links of the article, decoded from the stored ``links_json`` value.
    links: List[str]


class HealthResponse(BaseModel):
    """Response body of the health check endpoint."""

    # Health indicator string (the endpoint reports "healthy").
    status: str
    # Number of rows in ``core_articles`` as counted when the database was opened.
    article_count: int


class SQLiteDB:
    """Accessor for the ``core_articles`` table of a SQLite database.

    Article lookups are memoised per instance (up to 8192 titles). Cached
    values are stored immutably and every call returns a fresh list, so a
    caller mutating a result cannot corrupt later lookups.
    """

    def __init__(self, db_path: str):
        """Initialize the database with path to SQLite database.

        Opens the connection, builds the per-instance lookup cache and counts
        the rows of ``core_articles`` once.

        Args:
            db_path: Filesystem path of the SQLite database file.

        Raises:
            sqlite3.Error: If the database cannot be opened or queried
                (for example when ``core_articles`` does not exist).
        """
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.conn.row_factory = sqlite3.Row
        # Kept for backward compatibility with code that uses ``db.cursor``;
        # the methods below run each query on its own cursor so one query's
        # pending results can never be clobbered by another.
        self.cursor = self.conn.cursor()
        # The cache lives on the instance rather than on the class-level
        # function, so it is released together with the instance.
        self._cached_lookup = lru_cache(maxsize=8192)(self._load_article)
        self._article_count = self._get_article_count()
        print(f"Connected to SQLite database with {self._article_count} articles")

    def _get_article_count(self) -> int:
        """Return the number of rows in ``core_articles``."""
        row = self.conn.execute("SELECT COUNT(*) FROM core_articles").fetchone()
        return row[0]

    def _load_article(
        self, article_title: str
    ) -> Tuple[Optional[str], Tuple[str, ...]]:
        """Fetch one article from the database (uncached).

        Returns:
            ``(title, links)`` with the links as an immutable tuple, or
            ``(None, ())`` when no row has that title.
        """
        row = self.conn.execute(
            "SELECT title, links_json FROM core_articles WHERE title = ?",
            (article_title,),
        ).fetchone()
        if row is None:
            return None, ()

        # Assumes links_json holds a JSON array; stored as a tuple so the
        # cached value cannot be mutated through a returned reference.
        return row["title"], tuple(json.loads(row["links_json"]))

    def get_article_with_links(
        self, article_title: str
    ) -> Tuple[Optional[str], List[str]]:
        """Return the stored title and links for ``article_title``.

        Args:
            article_title: Exact title to look up.

        Returns:
            ``(title, links)`` when the article exists, otherwise
            ``(None, [])``. The list is a new object on every call.
        """
        title, links = self._cached_lookup(article_title)
        return title, list(links)

    def get_all_articles(self) -> List[str]:
        """Return the titles of all rows in ``core_articles``."""
        rows = self.conn.execute("SELECT title FROM core_articles")
        return [row[0] for row in rows]


# Open the shared database handle; the WIKISPEEDIA_DB_PATH environment
# variable overrides the built-in default location.
_db_path = os.environ.get(
    "WIKISPEEDIA_DB_PATH", "/Users/jts/daily/wikihop/db/data/wikihop.db"
)
db = SQLiteDB(_db_path)


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report that the service is up, together with the article count held by ``db``."""
    article_total = db._article_count
    return HealthResponse(article_count=article_total, status="healthy")


@app.get("/get_all_articles", response_model=List[str])
async def get_all_articles():
    """Return the title of every article in the database."""
    titles = db.get_all_articles()
    return titles


@app.get("/get_article_with_links/{article_title}", response_model=ArticleResponse)
async def get_article(article_title: str):
    """Look up one article by title and return it with its links.

    Raises:
        HTTPException: 404 when the database has no article with that title.
    """
    found_title, outgoing_links = db.get_article_with_links(article_title)
    if found_title is not None:
        return ArticleResponse(title=found_title, links=outgoing_links)
    raise HTTPException(status_code=404, detail="Article not found")


@app.get("/auth/callback")
async def auth_callback(request: Request):
    """Auth callback endpoint.

    Exchanges the ``code`` query parameter for tokens at the Hugging Face
    OAuth token endpoint, then redirects to ``/`` with ``access_token`` and
    ``id_token`` in the query string.

    Raises:
        HTTPException: 400 when no ``code`` is provided; 500 when the client
            secret is not configured; 502 when the token endpoint cannot be
            reached, answers with an error status, or omits either token.
    """
    # Function-scope import so this handler does not depend on a change to the
    # module's import block.
    from urllib.parse import urlencode

    OAUTH_API_BASE = "https://huggingface.co/oauth/token"
    CLIENT_ID = "a67ef241-fb7e-4300-a6bd-8430a7565c9a"

    code = request.query_params.get("code")
    if not code:
        raise HTTPException(status_code=400, detail="No code provided")
    if not CLIENT_SECRET:
        # Without this guard the literal string "None" would be sent as the secret.
        raise HTTPException(
            status_code=500, detail="OAuth client secret is not configured"
        )

    # Must match the redirect URI used when the authorization code was requested.
    redirect_uri = (
        "http://localhost:8000/auth/callback"
        if not IS_PROD
        else "https://huggingfacetb-wikiracing-llms.hf.space/auth/callback"
    )
    basic_credentials = b64encode(f"{CLIENT_ID}:{CLIENT_SECRET}".encode()).decode()

    # NOTE(review): requests.post is a blocking call inside an async handler;
    # the timeout bounds how long a slow token endpoint can stall it.
    try:
        response = requests.post(
            OAUTH_API_BASE,
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
                "Authorization": f"Basic {basic_credentials}",
            },
            data={
                "client_id": CLIENT_ID,
                "code": code,
                "grant_type": "authorization_code",
                "redirect_uri": redirect_uri,
            },
            timeout=10,
        )
    except requests.RequestException as exc:
        raise HTTPException(
            status_code=502, detail="Token exchange request failed"
        ) from exc

    # A successful response is a JSON object with the keys
    # access_token, token_type, expires_in, id_token, scope and refresh_token.
    # The payload is deliberately never printed or logged: it contains live
    # credentials.
    try:
        token_payload = response.json()
    except ValueError:
        token_payload = None
    if not isinstance(token_payload, dict):
        token_payload = {}

    access_token = token_payload.get("access_token")
    id_token = token_payload.get("id_token")
    if not response.ok or not access_token or not id_token:
        raise HTTPException(status_code=502, detail="Token exchange failed")

    # redirect to the home page with access token and id token in the url
    query = urlencode({"access_token": access_token, "id_token": id_token})
    return RedirectResponse(url=f"/?{query}")


# Mount the dist folder for static files
# NOTE(review): this mount at "/" sits after the route definitions above,
# presumably so the catch-all does not shadow them -- keep it last; confirm.
app.mount("/", StaticFiles(directory="dist", html=True), name="static")


if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on 0.0.0.0:8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)