stillerman HF Staff commited on
Commit
5a65482
·
1 Parent(s): 60d35d7

login via fastAPI

Browse files
Dockerfile CHANGED
@@ -42,7 +42,7 @@ RUN curl -L https://huggingface.co/HuggingFaceTB/simplewiki-pruned-text-350k/res
42
 
43
  ENV WIKISPEEDIA_DB_PATH=/home/user/app/wikihop.db
44
 
45
- ENV VITE_API_BASE=""
46
 
47
  CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
48
 
 
42
 
43
  ENV WIKISPEEDIA_DB_PATH=/home/user/app/wikihop.db
44
 
45
+ ENV VITE_ENV=production
46
 
47
  CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
48
 
api.py CHANGED
@@ -1,14 +1,16 @@
 
1
  import sqlite3
2
  import json
3
  import os
4
  from typing import Tuple, List, Optional
5
  from functools import lru_cache
6
- from fastapi import FastAPI, HTTPException
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.staticfiles import StaticFiles
9
  from pydantic import BaseModel
10
  import uvicorn
11
- from fastapi.responses import FileResponse
 
12
 
13
  app = FastAPI(title="WikiSpeedia API")
14
 
@@ -21,14 +23,20 @@ app.add_middleware(
21
  allow_headers=["*"], # Allows all headers
22
  )
23
 
 
 
 
 
24
  class ArticleResponse(BaseModel):
25
  title: str
26
  links: List[str]
27
 
 
28
  class HealthResponse(BaseModel):
29
  status: str
30
  article_count: int
31
 
 
32
  class SQLiteDB:
33
  def __init__(self, db_path: str):
34
  """Initialize the database with path to SQLite database"""
@@ -60,24 +68,25 @@ class SQLiteDB:
60
  self.cursor.execute("SELECT title FROM core_articles")
61
  return [row[0] for row in self.cursor.fetchall()]
62
 
 
63
  # Initialize database connection
64
  db = SQLiteDB(
65
  os.getenv("WIKISPEEDIA_DB_PATH", "/Users/jts/daily/wikihop/db/data/wikihop.db")
66
  )
67
 
 
68
  @app.get("/health", response_model=HealthResponse)
69
  async def health_check():
70
  """Health check endpoint that returns the article count"""
71
- return HealthResponse(
72
- status="healthy",
73
- article_count=db._article_count
74
- )
75
 
76
  @app.get("/get_all_articles", response_model=List[str])
77
  async def get_all_articles():
78
  """Get all articles"""
79
  return db.get_all_articles()
80
 
 
81
  @app.get("/get_article_with_links/{article_title}", response_model=ArticleResponse)
82
  async def get_article(article_title: str):
83
  """Get article and its links by title"""
@@ -87,6 +96,49 @@ async def get_article(article_title: str):
87
  return ArticleResponse(title=title, links=links)
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  # Mount the dist folder for static files
91
  app.mount("/", StaticFiles(directory="dist", html=True), name="static")
92
 
 
1
+ from base64 import b64encode
2
  import sqlite3
3
  import json
4
  import os
5
  from typing import Tuple, List, Optional
6
  from functools import lru_cache
7
+ from fastapi import FastAPI, HTTPException, Request
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from fastapi.staticfiles import StaticFiles
10
  from pydantic import BaseModel
11
  import uvicorn
12
+ from fastapi.responses import FileResponse, RedirectResponse
13
+ import requests
14
 
15
  app = FastAPI(title="WikiSpeedia API")
16
 
 
23
  allow_headers=["*"], # Allows all headers
24
  )
25
 
26
+ CLIENT_SECRET = os.getenv("HUGGINGFACE_CLIENT_SECRET")
27
+ IS_PROD = os.getenv("VITE_ENV") == "production"
28
+ print("CLIENT_SECRET:", CLIENT_SECRET)
29
+ print("IS_PROD:", IS_PROD)
30
  class ArticleResponse(BaseModel):
31
  title: str
32
  links: List[str]
33
 
34
+
35
  class HealthResponse(BaseModel):
36
  status: str
37
  article_count: int
38
 
39
+
40
  class SQLiteDB:
41
  def __init__(self, db_path: str):
42
  """Initialize the database with path to SQLite database"""
 
68
  self.cursor.execute("SELECT title FROM core_articles")
69
  return [row[0] for row in self.cursor.fetchall()]
70
 
71
+
72
  # Initialize database connection
73
  db = SQLiteDB(
74
  os.getenv("WIKISPEEDIA_DB_PATH", "/Users/jts/daily/wikihop/db/data/wikihop.db")
75
  )
76
 
77
+
78
  @app.get("/health", response_model=HealthResponse)
79
  async def health_check():
80
  """Health check endpoint that returns the article count"""
81
+ return HealthResponse(status="healthy", article_count=db._article_count)
82
+
 
 
83
 
84
  @app.get("/get_all_articles", response_model=List[str])
85
  async def get_all_articles():
86
  """Get all articles"""
87
  return db.get_all_articles()
88
 
89
+
90
  @app.get("/get_article_with_links/{article_title}", response_model=ArticleResponse)
91
  async def get_article(article_title: str):
92
  """Get article and its links by title"""
 
96
  return ArticleResponse(title=title, links=links)
97
 
98
 
99
+ @app.get("/auth/callback")
100
+ async def auth_callback(request: Request):
101
+
102
+ OAUTH_API_BASE = "https://huggingface.co/oauth/token"
103
+ CLIENT_ID = "a67ef241-fb7e-4300-a6bd-8430a7565c9a"
104
+
105
+ code = request.query_params.get("code")
106
+ if not code:
107
+ raise HTTPException(status_code=400, detail="No code provided")
108
+
109
+ response = requests.post(
110
+ OAUTH_API_BASE,
111
+ headers={
112
+ "Content-Type": "application/x-www-form-urlencoded",
113
+ "Authorization": f"Basic {b64encode(f'{CLIENT_ID}:{CLIENT_SECRET}'.encode()).decode()}",
114
+ },
115
+ data={
116
+ "client_id": CLIENT_ID,
117
+ "code": code,
118
+ "grant_type": "authorization_code",
119
+ "redirect_uri": "http://localhost:8000/auth/callback" if not IS_PROD else "https://huggingfacetb-wikispeedia.hf.space/auth/callback",
120
+ },
121
+ )
122
+
123
+ # response.json() =
124
+ # {
125
+ # "access_token": "hf_oauth_eyJhbGciOiJFZERTQSJ9.eyJzY29wZSI6WyJvcGVuaWQiLCJwcm9maWxlIiwiZW1haWwiLCJpbmZlcmVuY2UtYXBpIl0sImF1ZCI6Imh0dHBzOi8vaHVnZ2luZ2ZhY2UuY28iLCJvYXV0aEFwcCI6ImE2N2VmMjQxLWZiN2UtNDMwMC1hNmJkLTg0MzBhNzU2NWM5YSIsInNlc3Npb25JZCI6IjY3YTBkYjk3OWNmZDQ3ZGFkOGNmNDMwNyIsImlhdCI6MTc0NjIxOTEwOCwic3ViIjoiNjE3OGQ4NDIyNjczMjBhYmI5OWRmNzc2IiwiZXhwIjoxNzQ2MjQ3OTA4LCJpc3MiOiJodHRwczovL2h1Z2dpbmdmYWNlLmNvIn0.TNK7Nb2X22LHlFqleo6rzJjBngjTWpVIksE1Mw7m8vVxgr7CBbK_a1J4cW488n02391qqopcaNlZKFP8noZSAA",
126
+ # "token_type": "bearer",
127
+ # "expires_in": 28799,
128
+ # "id_token": "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiI2MTc4ZDg0MjI2NzMyMGFiYjk5ZGY3NzYiLCJuYW1lIjoiSmFzb24gU3RpbGxlcm1hbiIsInByZWZlcnJlZF91c2VybmFtZSI6InN0aWxsZXJtYW4iLCJwcm9maWxlIjoiaHR0cHM6Ly9odWdnaW5nZmFjZS5jby9zdGlsbGVybWFuIiwicGljdHVyZSI6Imh0dHBzOi8vaHVnZ2luZ2ZhY2UuY28vYXZhdGFycy84NzM5NzA1ZWY3ZWFiYzk0NWExZWYzYzA3MTk2YWYxMy5zdmciLCJlbWFpbCI6Imphc29uLnQuc3RpbGxlcm1hbkBnbWFpbC5jb20iLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXVkIjoiYTY3ZWYyNDEtZmI3ZS00MzAwLWE2YmQtODQzMGE3NTY1YzlhIiwiYXV0aF90aW1lIjoxNzQ2MjE5MTA4LCJpYXQiOjE3NDYyMTkxMDgsImV4cCI6MTc0NjIyMjcwOCwiaXNzIjoiaHR0cHM6Ly9odWdnaW5nZmFjZS5jbyJ9.pB7j-jkxxMG3GJNzipMNCsKQimk8_R0TcPrwi-Kln6qXcSccwGcWJvyMZvFRHjKB779UkMTzgCO-eY1CINX75KaRALLS_Eu0w448F_5LMixwpBXA6dntXBEdP69VLXakpXaPHjFY2HuvUN7fbE8e2_v4a-s7RRwHTDJIcxyH2Bd_OUpebFy1N6RNB_9MIL3jxXhsXyLNL2uDry0WIB52BJKBXB4EzE12HDGgNaWR6lrqr4nvjAExsGcTwarPhFSA5ndcbgh82vJxB3rVFhSU4iZ5AmMV1mDX6SgRVdPmWZPgTBwGeGlVN-OAHvLlNJ9FZ_i0qjrtA5IRU0o6ctKrfw",
129
+ # "scope": "openid profile email inference-api",
130
+ # "refresh_token": "hf_oauth__refresh_RiVshOppmioFVoxvYMXSPkMdyyzbyIqadj",
131
+ # }
132
+
133
+ print(response.json())
134
+
135
+ # redirect to the home page with access token and id token in the url
136
+ return RedirectResponse(url=f"/?access_token={response.json()['access_token']}&id_token={response.json()['id_token']}")
137
+
138
+ """Auth callback endpoint"""
139
+ return {"message": "Auth callback received"}
140
+
141
+
142
  # Mount the dist folder for static files
143
  app.mount("/", StaticFiles(directory="dist", html=True), name="static")
144
 
src/components/sign-in-with-hf-button.tsx CHANGED
@@ -4,94 +4,52 @@ import { useEffect } from "react";
4
  import { jwtDecode } from "jwt-decode";
5
 
6
  const CLIENT_ID = "a67ef241-fb7e-4300-a6bd-8430a7565c9a";
7
- const REDIRECT_URI = "https://huggingfacetb-wikispeedia.hf.space";
8
  // const REDIRECT_URI = "http://localhost:5173/auth/callback";
 
 
9
  const SCOPE = "openid%20profile%20email%20inference-api";
10
  const STATE = "1234567890";
11
  const SSO_URL = `https://huggingface.co/oauth/authorize?client_id=${CLIENT_ID}&redirect_uri=${REDIRECT_URI}&response_type=code&scope=${SCOPE}&prompt=consent&state=${STATE}`;
12
- const OAUTH_API_BASE = "https://huggingface.co/oauth/token";
13
- const CLIENT_SECRET = import.meta.env.VITE_HUGGINGFACE_CLIENT_SECRET; // THIS IS UNSAFE, must fix before real deploy
14
 
15
- import { oauthLoginUrl, oauthHandleRedirectIfPresent } from "@huggingface/hub";
16
 
17
  export const SignInWithHuggingFaceButton = () => {
18
  const [isSignedIn, setIsSignedIn] = useState(false);
19
  const [isLoading, setIsLoading] = useState(false);
20
  const [name, setName] = useState<string | null>(null);
21
 
22
- // useEffect(() => {
23
- // const idToken = window.localStorage.getItem("huggingface_id_token");
24
- // const accessToken = window.localStorage.getItem("huggingface_access_token");
25
-
26
- // if (idToken && accessToken) {
27
- // const idTokenObject = JSON.parse(idToken);
28
- // if (idTokenObject.exp > Date.now() / 1000) {
29
- // setIsSignedIn(true);
30
- // setName(idTokenObject.name);
31
 
32
- // return;
33
- // }
34
- // }
35
 
36
- // async function fetchToken() {
37
- // const code = new URLSearchParams(window.location.search).get("code");
38
- // if (code) {
39
- // // remove the code from the url
40
- // window.history.replaceState({}, "", window.location.pathname);
41
- // setIsLoading(true);
42
- // const response = await fetch(`${OAUTH_API_BASE}`, {
43
- // method: "POST",
44
- // headers: {
45
- // "Content-Type": "application/x-www-form-urlencoded",
46
- // Authorization: `Basic ${btoa(`${CLIENT_ID}:${CLIENT_SECRET}`)}`,
47
- // },
48
- // body: new URLSearchParams({
49
- // client_id: CLIENT_ID,
50
- // code,
51
- // grant_type: "authorization_code",
52
- // redirect_uri: REDIRECT_URI
53
- // }).toString(),
54
- // });
55
- // const data = await response.json();
56
- // window.localStorage.setItem("huggingface_access_token", data.access_token);
57
 
58
- // // parse the id_token
59
- // const idToken = jwtDecode(data.id_token);
60
- // console.log(idToken);
61
- // window.localStorage.setItem("huggingface_id_token", JSON.stringify(idToken));
62
- // setName(idToken.name);
63
- // setIsSignedIn(true);
64
- // setIsLoading(false);
65
- // }
66
- // }
67
 
68
- // fetchToken();
69
- // }, []);
 
70
 
71
- useEffect(() => {
72
- async function fetchToken() {
73
- console.log("fetching token", window.location.href);
74
- const oauthResult = await oauthHandleRedirectIfPresent();
 
75
 
76
- if (!oauthResult) {
77
- // If the user is not logged in, redirect to the login page
78
- window.location.href = await oauthLoginUrl();
79
  }
80
-
81
- // You can use oauthResult.accessToken, oauthResult.accessTokenExpiresAt and oauthResult.userInfo
82
- console.log(oauthResult);
83
  }
84
-
85
- fetchToken();
86
  }, []);
87
 
88
- const handleLogin = async () => {
89
- window.location.href =
90
- (await oauthLoginUrl({
91
- scopes: "inference-api,email",
92
- })) + "&prompt=consent";
93
- };
94
-
95
  if (isLoading) {
96
  return <div>Loading...</div>;
97
  }
@@ -101,7 +59,7 @@ export const SignInWithHuggingFaceButton = () => {
101
  }
102
 
103
  return (
104
- <a onClick={handleLogin} href="#" rel="nofollow">
105
  <img
106
  src="https://huggingface.co/datasets/huggingface/badges/resolve/main/sign-in-with-huggingface-xl.svg"
107
  alt="Sign in with Hugging Face"
 
4
  import { jwtDecode } from "jwt-decode";
5
 
6
  const CLIENT_ID = "a67ef241-fb7e-4300-a6bd-8430a7565c9a";
7
+ // const REDIRECT_URI = "https://huggingfacetb-wikispeedia.hf.space";
8
  // const REDIRECT_URI = "http://localhost:5173/auth/callback";
9
+ const REDIRECT_URI = isProd ? "https://huggingfacetb-wikispeedia.hf.space/auth/callback" : "http://localhost:8000/auth/callback";
10
+
11
  const SCOPE = "openid%20profile%20email%20inference-api";
12
  const STATE = "1234567890";
13
  const SSO_URL = `https://huggingface.co/oauth/authorize?client_id=${CLIENT_ID}&redirect_uri=${REDIRECT_URI}&response_type=code&scope=${SCOPE}&prompt=consent&state=${STATE}`;
 
 
14
 
15
+ import { isProd } from "@/lib/constants";
16
 
17
  export const SignInWithHuggingFaceButton = () => {
18
  const [isSignedIn, setIsSignedIn] = useState(false);
19
  const [isLoading, setIsLoading] = useState(false);
20
  const [name, setName] = useState<string | null>(null);
21
 
22
+ useEffect(() => {
23
+ // check if access_token and id_token are in the url
24
+ const accessTokenURL = new URLSearchParams(window.location.search).get("access_token");
25
+ const idTokenURL = new URLSearchParams(window.location.search).get("id_token");
 
 
 
 
 
26
 
27
+ console.log(accessTokenURL, idTokenURL);
 
 
28
 
29
+ if (accessTokenURL && idTokenURL) {
30
+ // remove the access_token and id_token from the url
31
+ window.history.replaceState({}, "", window.location.pathname);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ window.localStorage.setItem("huggingface_access_token", accessTokenURL);
34
+ window.localStorage.setItem("huggingface_id_token", JSON.stringify(jwtDecode(idTokenURL)));
35
+ }
 
 
 
 
 
 
36
 
37
+ // check if the user is already logged in
38
+ const idToken = window.localStorage.getItem("huggingface_id_token");
39
+ const accessToken = window.localStorage.getItem("huggingface_access_token");
40
 
41
+ if (idToken && accessToken) {
42
+ const idTokenObject = JSON.parse(idToken);
43
+ if (idTokenObject.exp > Date.now() / 1000) {
44
+ setIsSignedIn(true);
45
+ setName(idTokenObject.name);
46
 
47
+ return;
 
 
48
  }
 
 
 
49
  }
50
+
 
51
  }, []);
52
 
 
 
 
 
 
 
 
53
  if (isLoading) {
54
  return <div>Loading...</div>;
55
  }
 
59
  }
60
 
61
  return (
62
+ <a href={SSO_URL} rel="nofollow">
63
  <img
64
  src="https://huggingface.co/datasets/huggingface/badges/resolve/main/sign-in-with-huggingface-xl.svg"
65
  alt="Sign in with Hugging Face"
src/lib/constants.ts CHANGED
@@ -1,3 +1,4 @@
1
- export const API_BASE = import.meta.env.VITE_API_BASE || "http://localhost:8000";
 
2
 
3
  console.log("API_BASE", API_BASE);
 
1
+ export const isProd = import.meta.env.VITE_ENV === "production";
2
+ export const API_BASE = isProd ? "" : "http://localhost:8000"; // we want this blank in production
3
 
4
  console.log("API_BASE", API_BASE);