Spaces:
Running
Running
Commit
·
5a65482
1
Parent(s):
60d35d7
login via fastAPI
Browse files- Dockerfile +1 -1
- api.py +58 -6
- src/components/sign-in-with-hf-button.tsx +26 -68
- src/lib/constants.ts +2 -1
Dockerfile
CHANGED
@@ -42,7 +42,7 @@ RUN curl -L https://huggingface.co/HuggingFaceTB/simplewiki-pruned-text-350k/res
|
|
42 |
|
43 |
ENV WIKISPEEDIA_DB_PATH=/home/user/app/wikihop.db
|
44 |
|
45 |
-
ENV
|
46 |
|
47 |
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
|
48 |
|
|
|
42 |
|
43 |
ENV WIKISPEEDIA_DB_PATH=/home/user/app/wikihop.db
|
44 |
|
45 |
+
ENV VITE_ENV=production
|
46 |
|
47 |
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
|
48 |
|
api.py
CHANGED
@@ -1,14 +1,16 @@
|
|
|
|
1 |
import sqlite3
|
2 |
import json
|
3 |
import os
|
4 |
from typing import Tuple, List, Optional
|
5 |
from functools import lru_cache
|
6 |
-
from fastapi import FastAPI, HTTPException
|
7 |
from fastapi.middleware.cors import CORSMiddleware
|
8 |
from fastapi.staticfiles import StaticFiles
|
9 |
from pydantic import BaseModel
|
10 |
import uvicorn
|
11 |
-
from fastapi.responses import FileResponse
|
|
|
12 |
|
13 |
app = FastAPI(title="WikiSpeedia API")
|
14 |
|
@@ -21,14 +23,20 @@ app.add_middleware(
|
|
21 |
allow_headers=["*"], # Allows all headers
|
22 |
)
|
23 |
|
|
|
|
|
|
|
|
|
24 |
class ArticleResponse(BaseModel):
|
25 |
title: str
|
26 |
links: List[str]
|
27 |
|
|
|
28 |
class HealthResponse(BaseModel):
|
29 |
status: str
|
30 |
article_count: int
|
31 |
|
|
|
32 |
class SQLiteDB:
|
33 |
def __init__(self, db_path: str):
|
34 |
"""Initialize the database with path to SQLite database"""
|
@@ -60,24 +68,25 @@ class SQLiteDB:
|
|
60 |
self.cursor.execute("SELECT title FROM core_articles")
|
61 |
return [row[0] for row in self.cursor.fetchall()]
|
62 |
|
|
|
63 |
# Initialize database connection
|
64 |
db = SQLiteDB(
|
65 |
os.getenv("WIKISPEEDIA_DB_PATH", "/Users/jts/daily/wikihop/db/data/wikihop.db")
|
66 |
)
|
67 |
|
|
|
68 |
@app.get("/health", response_model=HealthResponse)
|
69 |
async def health_check():
|
70 |
"""Health check endpoint that returns the article count"""
|
71 |
-
return HealthResponse(
|
72 |
-
|
73 |
-
article_count=db._article_count
|
74 |
-
)
|
75 |
|
76 |
@app.get("/get_all_articles", response_model=List[str])
|
77 |
async def get_all_articles():
|
78 |
"""Get all articles"""
|
79 |
return db.get_all_articles()
|
80 |
|
|
|
81 |
@app.get("/get_article_with_links/{article_title}", response_model=ArticleResponse)
|
82 |
async def get_article(article_title: str):
|
83 |
"""Get article and its links by title"""
|
@@ -87,6 +96,49 @@ async def get_article(article_title: str):
|
|
87 |
return ArticleResponse(title=title, links=links)
|
88 |
|
89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
# Mount the dist folder for static files
|
91 |
app.mount("/", StaticFiles(directory="dist", html=True), name="static")
|
92 |
|
|
|
1 |
+
from base64 import b64encode
|
2 |
import sqlite3
|
3 |
import json
|
4 |
import os
|
5 |
from typing import Tuple, List, Optional
|
6 |
from functools import lru_cache
|
7 |
+
from fastapi import FastAPI, HTTPException, Request
|
8 |
from fastapi.middleware.cors import CORSMiddleware
|
9 |
from fastapi.staticfiles import StaticFiles
|
10 |
from pydantic import BaseModel
|
11 |
import uvicorn
|
12 |
+
from fastapi.responses import FileResponse, RedirectResponse
|
13 |
+
import requests
|
14 |
|
15 |
app = FastAPI(title="WikiSpeedia API")
|
16 |
|
|
|
23 |
allow_headers=["*"], # Allows all headers
|
24 |
)
|
25 |
|
26 |
+
CLIENT_SECRET = os.getenv("HUGGINGFACE_CLIENT_SECRET")
|
27 |
+
IS_PROD = os.getenv("VITE_ENV") == "production"
|
28 |
+
print("CLIENT_SECRET:", CLIENT_SECRET)
|
29 |
+
print("IS_PROD:", IS_PROD)
|
30 |
class ArticleResponse(BaseModel):
|
31 |
title: str
|
32 |
links: List[str]
|
33 |
|
34 |
+
|
35 |
class HealthResponse(BaseModel):
|
36 |
status: str
|
37 |
article_count: int
|
38 |
|
39 |
+
|
40 |
class SQLiteDB:
|
41 |
def __init__(self, db_path: str):
|
42 |
"""Initialize the database with path to SQLite database"""
|
|
|
68 |
self.cursor.execute("SELECT title FROM core_articles")
|
69 |
return [row[0] for row in self.cursor.fetchall()]
|
70 |
|
71 |
+
|
72 |
# Initialize database connection
|
73 |
db = SQLiteDB(
|
74 |
os.getenv("WIKISPEEDIA_DB_PATH", "/Users/jts/daily/wikihop/db/data/wikihop.db")
|
75 |
)
|
76 |
|
77 |
+
|
78 |
@app.get("/health", response_model=HealthResponse)
|
79 |
async def health_check():
|
80 |
"""Health check endpoint that returns the article count"""
|
81 |
+
return HealthResponse(status="healthy", article_count=db._article_count)
|
82 |
+
|
|
|
|
|
83 |
|
84 |
@app.get("/get_all_articles", response_model=List[str])
|
85 |
async def get_all_articles():
|
86 |
"""Get all articles"""
|
87 |
return db.get_all_articles()
|
88 |
|
89 |
+
|
90 |
@app.get("/get_article_with_links/{article_title}", response_model=ArticleResponse)
|
91 |
async def get_article(article_title: str):
|
92 |
"""Get article and its links by title"""
|
|
|
96 |
return ArticleResponse(title=title, links=links)
|
97 |
|
98 |
|
99 |
+
@app.get("/auth/callback")
|
100 |
+
async def auth_callback(request: Request):
|
101 |
+
|
102 |
+
OAUTH_API_BASE = "https://huggingface.co/oauth/token"
|
103 |
+
CLIENT_ID = "a67ef241-fb7e-4300-a6bd-8430a7565c9a"
|
104 |
+
|
105 |
+
code = request.query_params.get("code")
|
106 |
+
if not code:
|
107 |
+
raise HTTPException(status_code=400, detail="No code provided")
|
108 |
+
|
109 |
+
response = requests.post(
|
110 |
+
OAUTH_API_BASE,
|
111 |
+
headers={
|
112 |
+
"Content-Type": "application/x-www-form-urlencoded",
|
113 |
+
"Authorization": f"Basic {b64encode(f'{CLIENT_ID}:{CLIENT_SECRET}'.encode()).decode()}",
|
114 |
+
},
|
115 |
+
data={
|
116 |
+
"client_id": CLIENT_ID,
|
117 |
+
"code": code,
|
118 |
+
"grant_type": "authorization_code",
|
119 |
+
"redirect_uri": "http://localhost:8000/auth/callback" if not IS_PROD else "https://huggingfacetb-wikispeedia.hf.space/auth/callback",
|
120 |
+
},
|
121 |
+
)
|
122 |
+
|
123 |
+
# response.json() =
|
124 |
+
# {
|
125 |
+
# "access_token": "hf_oauth_eyJhbGciOiJFZERTQSJ9.eyJzY29wZSI6WyJvcGVuaWQiLCJwcm9maWxlIiwiZW1haWwiLCJpbmZlcmVuY2UtYXBpIl0sImF1ZCI6Imh0dHBzOi8vaHVnZ2luZ2ZhY2UuY28iLCJvYXV0aEFwcCI6ImE2N2VmMjQxLWZiN2UtNDMwMC1hNmJkLTg0MzBhNzU2NWM5YSIsInNlc3Npb25JZCI6IjY3YTBkYjk3OWNmZDQ3ZGFkOGNmNDMwNyIsImlhdCI6MTc0NjIxOTEwOCwic3ViIjoiNjE3OGQ4NDIyNjczMjBhYmI5OWRmNzc2IiwiZXhwIjoxNzQ2MjQ3OTA4LCJpc3MiOiJodHRwczovL2h1Z2dpbmdmYWNlLmNvIn0.TNK7Nb2X22LHlFqleo6rzJjBngjTWpVIksE1Mw7m8vVxgr7CBbK_a1J4cW488n02391qqopcaNlZKFP8noZSAA",
|
126 |
+
# "token_type": "bearer",
|
127 |
+
# "expires_in": 28799,
|
128 |
+
# "id_token": "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiI2MTc4ZDg0MjI2NzMyMGFiYjk5ZGY3NzYiLCJuYW1lIjoiSmFzb24gU3RpbGxlcm1hbiIsInByZWZlcnJlZF91c2VybmFtZSI6InN0aWxsZXJtYW4iLCJwcm9maWxlIjoiaHR0cHM6Ly9odWdnaW5nZmFjZS5jby9zdGlsbGVybWFuIiwicGljdHVyZSI6Imh0dHBzOi8vaHVnZ2luZ2ZhY2UuY28vYXZhdGFycy84NzM5NzA1ZWY3ZWFiYzk0NWExZWYzYzA3MTk2YWYxMy5zdmciLCJlbWFpbCI6Imphc29uLnQuc3RpbGxlcm1hbkBnbWFpbC5jb20iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXVkIjoiYTY3ZWYyNDEtZmI3ZS00MzAwLWE2YmQtODQzMGE3NTY1YzlhIiwiYXV0aF90aW1lIjoxNzQ2MjE5MTA4LCJpYXQiOjE3NDYyMTkxMDgsImV4cCI6MTc0NjIyMjcwOCwiaXNzIjoiaHR0cHM6Ly9odWdnaW5nZmFjZS5jbyJ9.pB7j-jkxxMG3GJNzipMNCsKQimk8_R0TcPrwi-Kln6qXcSccwGcWJvyMZvFRHjKB779UkMTzgCO-eY1CINX75KaRALLS_Eu0w448F_5LMixwpBXA6dntXBEdP69VLXakpXaPHjFY2HuvUN7fbE8e2_v4a-s7RRwHTDJIcxyH2Bd_OUpebFy1N6RNB_9MIL3jxXhsXyLNL2uDry0WIB52BJKBXB4EzE12HDGgNaWR6lrqr4nvjAExsGcTwarPhFSA5ndcbgh82vJxB3rVFhSU4iZ5AmMV1mDX6SgRVdPmWZPgTBwGeGlVN-OAHvLlNJ9FZ_i0qjrtA5IRU0o6ctKrfw",
|
129 |
+
# "scope": "openid profile email inference-api",
|
130 |
+
# "refresh_token": "hf_oauth__refresh_RiVshOppmioFVoxvYMXSPkMdyyzbyIqadj",
|
131 |
+
# }
|
132 |
+
|
133 |
+
print(response.json())
|
134 |
+
|
135 |
+
# redirect to the home page with access token and id token in the url
|
136 |
+
return RedirectResponse(url=f"/?access_token={response.json()['access_token']}&id_token={response.json()['id_token']}")
|
137 |
+
|
138 |
+
"""Auth callback endpoint"""
|
139 |
+
return {"message": "Auth callback received"}
|
140 |
+
|
141 |
+
|
142 |
# Mount the dist folder for static files
|
143 |
app.mount("/", StaticFiles(directory="dist", html=True), name="static")
|
144 |
|
src/components/sign-in-with-hf-button.tsx
CHANGED
@@ -4,94 +4,52 @@ import { useEffect } from "react";
|
|
4 |
import { jwtDecode } from "jwt-decode";
|
5 |
|
6 |
const CLIENT_ID = "a67ef241-fb7e-4300-a6bd-8430a7565c9a";
|
7 |
-
const REDIRECT_URI = "https://huggingfacetb-wikispeedia.hf.space";
|
8 |
// const REDIRECT_URI = "http://localhost:5173/auth/callback";
|
|
|
|
|
9 |
const SCOPE = "openid%20profile%20email%20inference-api";
|
10 |
const STATE = "1234567890";
|
11 |
const SSO_URL = `https://huggingface.co/oauth/authorize?client_id=${CLIENT_ID}&redirect_uri=${REDIRECT_URI}&response_type=code&scope=${SCOPE}&prompt=consent&state=${STATE}`;
|
12 |
-
const OAUTH_API_BASE = "https://huggingface.co/oauth/token";
|
13 |
-
const CLIENT_SECRET = import.meta.env.VITE_HUGGINGFACE_CLIENT_SECRET; // THIS IS UNSAFE, must fix before real deploy
|
14 |
|
15 |
-
import {
|
16 |
|
17 |
export const SignInWithHuggingFaceButton = () => {
|
18 |
const [isSignedIn, setIsSignedIn] = useState(false);
|
19 |
const [isLoading, setIsLoading] = useState(false);
|
20 |
const [name, setName] = useState<string | null>(null);
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
// if (idToken && accessToken) {
|
27 |
-
// const idTokenObject = JSON.parse(idToken);
|
28 |
-
// if (idTokenObject.exp > Date.now() / 1000) {
|
29 |
-
// setIsSignedIn(true);
|
30 |
-
// setName(idTokenObject.name);
|
31 |
|
32 |
-
|
33 |
-
// }
|
34 |
-
// }
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
// // remove the code from the url
|
40 |
-
// window.history.replaceState({}, "", window.location.pathname);
|
41 |
-
// setIsLoading(true);
|
42 |
-
// const response = await fetch(`${OAUTH_API_BASE}`, {
|
43 |
-
// method: "POST",
|
44 |
-
// headers: {
|
45 |
-
// "Content-Type": "application/x-www-form-urlencoded",
|
46 |
-
// Authorization: `Basic ${btoa(`${CLIENT_ID}:${CLIENT_SECRET}`)}`,
|
47 |
-
// },
|
48 |
-
// body: new URLSearchParams({
|
49 |
-
// client_id: CLIENT_ID,
|
50 |
-
// code,
|
51 |
-
// grant_type: "authorization_code",
|
52 |
-
// redirect_uri: REDIRECT_URI
|
53 |
-
// }).toString(),
|
54 |
-
// });
|
55 |
-
// const data = await response.json();
|
56 |
-
// window.localStorage.setItem("huggingface_access_token", data.access_token);
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
// window.localStorage.setItem("huggingface_id_token", JSON.stringify(idToken));
|
62 |
-
// setName(idToken.name);
|
63 |
-
// setIsSignedIn(true);
|
64 |
-
// setIsLoading(false);
|
65 |
-
// }
|
66 |
-
// }
|
67 |
|
68 |
-
|
69 |
-
|
|
|
70 |
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
75 |
|
76 |
-
|
77 |
-
// If the user is not logged in, redirect to the login page
|
78 |
-
window.location.href = await oauthLoginUrl();
|
79 |
}
|
80 |
-
|
81 |
-
// You can use oauthResult.accessToken, oauthResult.accessTokenExpiresAt and oauthResult.userInfo
|
82 |
-
console.log(oauthResult);
|
83 |
}
|
84 |
-
|
85 |
-
fetchToken();
|
86 |
}, []);
|
87 |
|
88 |
-
const handleLogin = async () => {
|
89 |
-
window.location.href =
|
90 |
-
(await oauthLoginUrl({
|
91 |
-
scopes: "inference-api,email",
|
92 |
-
})) + "&prompt=consent";
|
93 |
-
};
|
94 |
-
|
95 |
if (isLoading) {
|
96 |
return <div>Loading...</div>;
|
97 |
}
|
@@ -101,7 +59,7 @@ export const SignInWithHuggingFaceButton = () => {
|
|
101 |
}
|
102 |
|
103 |
return (
|
104 |
-
<a
|
105 |
<img
|
106 |
src="https://huggingface.co/datasets/huggingface/badges/resolve/main/sign-in-with-huggingface-xl.svg"
|
107 |
alt="Sign in with Hugging Face"
|
|
|
4 |
import { jwtDecode } from "jwt-decode";
|
5 |
|
6 |
const CLIENT_ID = "a67ef241-fb7e-4300-a6bd-8430a7565c9a";
|
7 |
+
// const REDIRECT_URI = "https://huggingfacetb-wikispeedia.hf.space";
|
8 |
// const REDIRECT_URI = "http://localhost:5173/auth/callback";
|
9 |
+
const REDIRECT_URI = isProd ? "https://huggingfacetb-wikispeedia.hf.space/auth/callback" : "http://localhost:8000/auth/callback";
|
10 |
+
|
11 |
const SCOPE = "openid%20profile%20email%20inference-api";
|
12 |
const STATE = "1234567890";
|
13 |
const SSO_URL = `https://huggingface.co/oauth/authorize?client_id=${CLIENT_ID}&redirect_uri=${REDIRECT_URI}&response_type=code&scope=${SCOPE}&prompt=consent&state=${STATE}`;
|
|
|
|
|
14 |
|
15 |
+
import { isProd } from "@/lib/constants";
|
16 |
|
17 |
export const SignInWithHuggingFaceButton = () => {
|
18 |
const [isSignedIn, setIsSignedIn] = useState(false);
|
19 |
const [isLoading, setIsLoading] = useState(false);
|
20 |
const [name, setName] = useState<string | null>(null);
|
21 |
|
22 |
+
useEffect(() => {
|
23 |
+
// check if access_token and id_token are in the url
|
24 |
+
const accessTokenURL = new URLSearchParams(window.location.search).get("access_token");
|
25 |
+
const idTokenURL = new URLSearchParams(window.location.search).get("id_token");
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
console.log(accessTokenURL, idTokenURL);
|
|
|
|
|
28 |
|
29 |
+
if (accessTokenURL && idTokenURL) {
|
30 |
+
// remove the access_token and id_token from the url
|
31 |
+
window.history.replaceState({}, "", window.location.pathname);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
+
window.localStorage.setItem("huggingface_access_token", accessTokenURL);
|
34 |
+
window.localStorage.setItem("huggingface_id_token", JSON.stringify(jwtDecode(idTokenURL)));
|
35 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
+
// check if the user is already logged in
|
38 |
+
const idToken = window.localStorage.getItem("huggingface_id_token");
|
39 |
+
const accessToken = window.localStorage.getItem("huggingface_access_token");
|
40 |
|
41 |
+
if (idToken && accessToken) {
|
42 |
+
const idTokenObject = JSON.parse(idToken);
|
43 |
+
if (idTokenObject.exp > Date.now() / 1000) {
|
44 |
+
setIsSignedIn(true);
|
45 |
+
setName(idTokenObject.name);
|
46 |
|
47 |
+
return;
|
|
|
|
|
48 |
}
|
|
|
|
|
|
|
49 |
}
|
50 |
+
|
|
|
51 |
}, []);
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
if (isLoading) {
|
54 |
return <div>Loading...</div>;
|
55 |
}
|
|
|
59 |
}
|
60 |
|
61 |
return (
|
62 |
+
<a href={SSO_URL} rel="nofollow">
|
63 |
<img
|
64 |
src="https://huggingface.co/datasets/huggingface/badges/resolve/main/sign-in-with-huggingface-xl.svg"
|
65 |
alt="Sign in with Hugging Face"
|
src/lib/constants.ts
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
-
export const
|
|
|
2 |
|
3 |
console.log("API_BASE", API_BASE);
|
|
|
1 |
+
export const isProd = import.meta.env.VITE_ENV === "production";
|
2 |
+
export const API_BASE = isProd ? "" : "http://localhost:8000"; // we want this blank in production
|
3 |
|
4 |
console.log("API_BASE", API_BASE);
|