ecemumutlu commited on
Commit
03cbb60
·
1 Parent(s): a608c19

Add authentication to endpoints

Browse files
Files changed (4) hide show
  1. auth/authentication.py +35 -0
  2. requirements.txt +2 -1
  3. svc/router.py +20 -2
  4. svc/schemas.py +13 -0
auth/authentication.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi.security import OAuth2PasswordBearer
2
+ from fastapi import HTTPException, Depends
3
+ from jose import JWTError, jwt
4
+ from datetime import datetime, timedelta
5
+
6
+ from svc.schemas import User
7
+
8
+
9
+ SECRET_KEY = "llmbenchmark_tr" # your secret key
10
+ ALGORITHM = "HS256"
11
+ ACCESS_TOKEN_EXPIRE_MINUTES = 30
12
+
13
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
14
+
15
+ def create_access_token(data: dict):
16
+ to_encode = data.copy()
17
+ expire = datetime.now() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
18
+ to_encode.update({"exp": expire})
19
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
20
+ return encoded_jwt
21
+
22
+ def get_current_user(token: str = Depends(oauth2_scheme)):
23
+ credentials_exception = HTTPException(
24
+ status_code=401,
25
+ detail="Could not validate credentials",
26
+ headers={"WWW-Authenticate": "Bearer"},
27
+ )
28
+ try:
29
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
30
+ username: str = payload.get("sub")
31
+ if username is None:
32
+ raise credentials_exception
33
+ return username
34
+ except JWTError:
35
+ raise credentials_exception
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  fastapi
2
  uvicorn[standard]
3
- lm_eval
 
 
1
  fastapi
2
  uvicorn[standard]
3
+ lm_eval
4
+ jose
svc/router.py CHANGED
@@ -2,17 +2,35 @@ from fastapi import APIRouter, HTTPException
2
  import logging
3
 
4
  from lm_eval import evaluator, utils
5
- from svc.schemas import LMHarnessTaskRequest, LMHarnessTaskResponse
 
 
 
6
 
7
  router = APIRouter()
8
 
9
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
10
  logger = logging.getLogger(__name__)
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  @router.post("/chat", response_model=LMHarnessTaskResponse)
15
- def inference_model(request: LMHarnessTaskRequest):
16
  try:
17
  results = evaluator.simple_evaluate(
18
  model=request.model,
 
2
  import logging
3
 
4
  from lm_eval import evaluator, utils
5
+ from svc.schemas import LMHarnessTaskRequest, LMHarnessTaskResponse, OAuth2PasswordRequestForm, User
6
+ from fastapi import FastAPI, Depends, HTTPException
7
+ from auth.authentication import oauth2_scheme, get_current_user, create_access_token
8
+ from data.users import users_db
9
 
10
  router = APIRouter()
11
 
12
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
13
  logger = logging.getLogger(__name__)
14
 
15
+ from dotenv import load_dotenv
16
+ import os
17
+
18
+ load_dotenv()
19
+ @router.post("/token")
20
+ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
21
+ if os.getenv("HF_TOKEN") != form_data.hf_token:
22
+ raise HTTPException(status_code=400, detail="Incorrect username or password")
23
+ access_token = create_access_token(data={"sub": form_data.username})
24
+ return {"access_token": access_token, "token_type": "bearer"}
25
+
26
+
27
+ @router.get("/protected")
28
+ async def protected_route(username: str = Depends(get_current_user)):
29
+ return {"message": f"Hello, {username}! This is a protected resource."}
30
 
31
 
32
  @router.post("/chat", response_model=LMHarnessTaskResponse)
33
+ def inference_model(request: LMHarnessTaskRequest = Depends(get_current_user)):
34
  try:
35
  results = evaluator.simple_evaluate(
36
  model=request.model,
svc/schemas.py CHANGED
@@ -1,6 +1,19 @@
1
  from pydantic import BaseModel
2
  from typing import List, Optional, Union, Any
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  class LMHarnessTaskRequest(BaseModel):
5
  model: str
6
  model_args: Optional[Union[str, dict]] = None
 
1
  from pydantic import BaseModel
2
  from typing import List, Optional, Union, Any
3
 
4
+
5
+ from pydantic import BaseModel
6
+
7
+
8
+ class OAuth2PasswordRequestForm(BaseModel):
9
+ username: str
10
+ hf_token: str
11
+
12
+ class User(BaseModel):
13
+ username: str
14
+ hf_token: str
15
+
16
+
17
  class LMHarnessTaskRequest(BaseModel):
18
  model: str
19
  model_args: Optional[Union[str, dict]] = None