Spaces:
Sleeping
Sleeping
import os
from datetime import datetime, timedelta, timezone
from typing import Annotated

from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session
from starlette import status

from db.database import get_db
from db.models import User
# Load variables from a .env file (if present) into the process environment
# before SECRET_KEY is read below.
load_dotenv()

# Every route registered on this router is served under the /auth prefix.
router = APIRouter(prefix="/auth", tags=["auth"])

# Pulls the bearer token out of incoming requests for get_current_user.
# NOTE(review): tokenUrl is "login" while this router's prefix is /auth;
# confirm the advertised token URL matches where the login route is mounted.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")

# JWT signing settings. SECRET_KEY is None if the variable is not set.
SECRET_KEY = os.environ.get("SECRET_KEY")
ALGORITHM = "HS256"

# Password hashing/verification context backed by bcrypt.
bcrypt_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Annotation alias: endpoints declaring `db: db_dependency` receive a
# SQLAlchemy Session produced by get_db.
db_dependency = Annotated[Session, Depends(get_db)]
def authenticate_user(email: str, password: str, db):
    """Look up a user by email and check the supplied password.

    Args:
        email: Email address to look the user up by.
        password: Plain-text password submitted by the client.
        db: SQLAlchemy session used to query the ``User`` table.

    Returns:
        The matching ``User`` row if the email exists and the password
        verifies against ``user.hashed_password``; otherwise ``False``.
    """
    user = db.query(User).filter(User.email == email).first()
    if user is None:
        # Burn roughly the same CPU time as a real bcrypt check so that the
        # response time does not reveal whether the email is registered.
        bcrypt_context.dummy_verify()
        return False
    if not bcrypt_context.verify(password, user.hashed_password):
        return False
    return user
def create_access_token(
    username: str, name: str, user_id: int, role_id: int, expires_delta: timedelta, email: str
):
    """Build and sign a JWT describing the given user.

    The payload holds ``sub`` (username), ``name``, ``id``, ``role_id`` and
    ``email``, followed by ``exp`` = current UTC time + ``expires_delta``.
    It is signed with ``SECRET_KEY`` using ``ALGORITHM``.

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    expires_at = datetime.now(timezone.utc) + expires_delta
    # Key order matches the original payload so the emitted token is identical.
    claims = {
        "sub": username,
        "name": name,
        "id": user_id,
        "role_id": role_id,
        "email": email,
        "exp": expires_at,
    }
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    """FastAPI dependency: decode the bearer token into the caller's claims.

    Args:
        token: Raw JWT taken from the request by ``oauth2_scheme``.

    Returns:
        A dict with keys ``username``, ``name``, ``id``, ``role_id`` and
        ``email`` read from the token payload (``name``, ``role_id`` and
        ``email`` may be ``None`` if absent from the token).

    Raises:
        HTTPException: 401 when the token cannot be decoded/verified
            (``JWTError``) or lacks the ``sub`` or ``id`` claim. Raising,
            rather than returning a response object, is what makes FastAPI
            reject the request; a returned object would simply be injected
            into the endpoint as the "user".
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate user.",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        raise credentials_error from err

    username = payload.get("sub")
    user_id = payload.get("id")
    if username is None or user_id is None:
        raise credentials_error

    return {
        "username": username,
        "name": payload.get("name"),
        "id": user_id,
        "role_id": payload.get("role_id"),
        "email": payload.get("email"),
    }


# Annotation alias: endpoints declaring `user: user_dependency` receive the
# claims dict produced by get_current_user.
user_dependency = Annotated[dict, Depends(get_current_user)]
def check_user_authentication(user: user_dependency):
    """Helper to check whether ``user`` represents an authenticated caller.

    Args:
        user: Value injected via ``user_dependency``; a claims dict when the
            token was accepted.

    Returns:
        ``None`` when ``user`` is a claims dict; otherwise a 401
        ``JSONResponse`` with content "Authentication Failed". The caller
        must return that response itself; this helper does not abort the
        request.
    """
    # Anything other than the claims dict (None, or e.g. an error response
    # object returned instead of raised upstream) is unauthenticated. The
    # previous `is None` check let such objects pass as a logged-in user.
    if not isinstance(user, dict):
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content="Authentication Failed",
        )
    return None
def check_admin_authentication(user: user_dependency):
    """Helper to check whether ``user`` is an authenticated admin (role_id 1).

    Args:
        user: Value injected via ``user_dependency``; a claims dict when the
            token was accepted.

    Returns:
        ``None`` when ``user`` is a claims dict whose ``role_id`` is 1;
        otherwise a 401 ``JSONResponse`` with content
        "Authentication Admin Failed". The caller must return that response
        itself; this helper does not abort the request.
    """
    # role_id value this check treats as administrator. Kept local rather than
    # as a parameter: if this helper is used as a FastAPI dependency, an extra
    # parameter would be exposed as a client-controlled query parameter.
    admin_role_id = 1
    # Guard the type first: a non-dict value (e.g. an error response object)
    # has no .get() and previously crashed here with AttributeError.
    # NOTE(review): 403 would be the more precise status for an authenticated
    # non-admin; 401 is kept for compatibility with existing clients.
    if not isinstance(user, dict) or user.get("role_id") != admin_role_id:
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content="Authentication Admin Failed",
        )
    return None