Spaces:
Sleeping
Sleeping
import os | |
from dotenv import load_dotenv | |
from datetime import timedelta | |
from typing import Annotated | |
from fastapi import APIRouter, Depends, status | |
from fastapi.responses import JSONResponse | |
from fastapi.security import OAuth2PasswordRequestForm | |
from passlib.context import CryptContext | |
from sqlalchemy.orm import Session | |
from db.models import User | |
from db.database import get_db | |
from api.auth import get_current_user, create_access_token | |
from service.dto import CreateUserRequest, UserVerification, Token | |
from collections import Counter | |
from time import time | |
load_dotenv()

router = APIRouter(tags=["User"])
bcrypt_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Reusable FastAPI dependency aliases for route signatures.
db_dependency = Annotated[Session, Depends(get_db)]
user_dependency = Annotated[dict, Depends(get_current_user)]

# Lifetime of issued access tokens: 30 days, expressed in minutes (43200).
ACCESS_TOKEN_EXPIRE_MINUTES = 30 * 24 * 60

# Login rate limiting: once a username reaches FAILED_ATTEMPT_LIMIT failed
# logins it is blocked for BLOCK_TIME_SECONDS (5 minutes).
FAILED_ATTEMPT_LIMIT = 3
BLOCK_TIME_SECONDS = 5 * 60

# In-memory tracking state; module globals, so each process has its own copy.
# username -> number of failed login attempts
failed_attempts: Counter = Counter()
# username -> time() timestamp until which that username is blocked
blocked_users: dict[str, float] = {}
async def login_for_access_token(
    login_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Session = Depends(get_db),
):
    """Authenticate a user and issue a bearer access token.

    Unknown usernames are registered on the fly (role_id 2) using the password
    supplied in this request. Failed attempts are counted per username in the
    module-level ``failed_attempts`` / ``blocked_users`` maps; reaching
    FAILED_ATTEMPT_LIMIT blocks the username for BLOCK_TIME_SECONDS.

    Args:
        login_data: OAuth2 form carrying ``username`` and ``password``.
        db: Database session.

    Returns:
        ``{"access_token": ..., "token_type": "bearer"}`` on success, otherwise
        a ``JSONResponse``: 403 while blocked / when the limit is hit, 401 for a
        wrong password, the error response from ``register_user`` if
        auto-registration fails, or 500 on an unexpected failure.
    """
    username = login_data.username

    blocked = _blocked_response(username)
    if blocked is not None:
        return blocked

    user = db.query(User).filter(User.username == username).first()
    if not user:
        # Auto-register unknown usernames with the password the caller just
        # supplied, so the password check below is meaningful for the new
        # account (a shared password would make every account share one
        # secret).
        create_user_request = CreateUserRequest(
            name=username,
            username=username,
            email=username,
            password=login_data.password,
            role_id=2,
        )
        registration_response = await register_user(db, create_user_request)
        if isinstance(registration_response, JSONResponse):
            return registration_response  # Registration failed; pass the error through.
        # Retrieve the newly created user after successful registration.
        user = db.query(User).filter(User.username == username).first()
        if not user:
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content="User registration failed unexpectedly.",
            )

    # Only the caller-supplied password is checked. Verifying the shared
    # USER_PASSWORD against the stored hash ignored login_data.password and
    # let any password in for accounts created with that shared value; such
    # accounts still authenticate when the caller actually types it.
    if not bcrypt_context.verify(login_data.password, user.password_hash):
        return _record_failed_attempt(username)

    # Successful login resets the failure count.
    failed_attempts.pop(username, None)
    try:
        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        access_token = create_access_token(
            user.username,
            user.name,
            user.id,
            user.role_id,
            access_token_expires,
            user.email,
        )
        return {"access_token": access_token, "token_type": "bearer"}
    except Exception as e:
        print(e)
        return JSONResponse(status_code=500, content="An error occurred during login")


def _blocked_response(username):
    """Return a 403 response if *username* is still blocked, else ``None``.

    An expired block is cleared together with any leftover failure count.
    """
    block_until = blocked_users.get(username)
    if block_until is None:
        return None
    remaining = block_until - time()  # single clock read for check and message
    if remaining > 0:
        return JSONResponse(
            status_code=status.HTTP_403_FORBIDDEN,
            content=f"Too many failed attempts. Try again after {int(remaining)} seconds.",
        )
    # Block period is over: unblock the user.
    blocked_users.pop(username, None)
    failed_attempts.pop(username, None)
    return None


def _record_failed_attempt(username):
    """Count one failed login for *username* and build the error response.

    Returns 403 and blocks the username for BLOCK_TIME_SECONDS once
    FAILED_ATTEMPT_LIMIT is reached (the counter is reset at that point);
    otherwise returns 401.

    NOTE(review): blocking is keyed on the username alone, so anyone can lock
    an account out by sending bad passwords for it; the state also lives in
    module globals and is not shared between worker processes.
    """
    failed_attempts[username] += 1  # Counter yields 0 for unseen keys
    if failed_attempts[username] >= FAILED_ATTEMPT_LIMIT:
        blocked_users[username] = time() + BLOCK_TIME_SECONDS
        failed_attempts.pop(username, None)  # Reset after blocking
        return JSONResponse(
            status_code=status.HTTP_403_FORBIDDEN,
            content="Too many failed attempts. You are temporarily blocked.",
        )
    return JSONResponse(
        status_code=status.HTTP_401_UNAUTHORIZED,
        content="Invalid credentials.",
    )
async def get_user(user: user_dependency):
    """Return the public profile fields of the authenticated caller.

    Args:
        user: Claims dict produced by ``get_current_user`` (may be ``None``).

    Returns:
        A dict with ``username``, ``name``, ``id``, ``email`` and ``role``
        (the latter taken from the ``role_id`` claim), or a 401 JSONResponse
        when no user is present.
    """
    if user is not None:
        # (response key, claim key) pairs, in response order.
        field_map = (
            ("username", "username"),
            ("name", "name"),
            ("id", "id"),
            ("email", "email"),
            ("role", "role_id"),
        )
        return {out_key: user.get(claim) for out_key, claim in field_map}
    return JSONResponse(status_code=401, content="Authentication Failed")
async def get_all_users(user: user_dependency, db: Session = Depends(get_db)):
    """List every user account; restricted to admins (role_id 1).

    Args:
        user: Claims dict of the authenticated caller, from ``get_current_user``.
        db: Database session.

    Returns:
        A list of ``{"id", "username", "name", "email", "role"}`` dicts, or a
        401 JSONResponse when the caller is missing or is not an admin.
    """
    # Check for a missing user first (as get_user does) instead of failing
    # with an AttributeError on None.get(...).
    if user is None or user.get("role_id") != 1:
        # NOTE(review): 403 would be the more precise status for an
        # authenticated non-admin; 401 is kept for client compatibility.
        return JSONResponse(status_code=401, content="Authentication Failed")
    # `account` (not `user`) so the caller's claims are not shadowed.
    return [
        {
            "id": account.id,
            "username": account.username,
            "name": account.name,
            "email": account.email,
            "role": account.role_id,
        }
        for account in db.query(User).all()
    ]
async def register_user(db: db_dependency, create_user_request: CreateUserRequest):
    """Create a user account with a bcrypt-hashed password.

    Args:
        db: Database session.
        create_user_request: name, username, email, password and role_id.

    Returns:
        ``{"message": ..., "user_id": ...}`` on success; a 400 JSONResponse when
        the email or the username is already taken; a 500 JSONResponse if
        hashing or persisting the user fails.
    """
    existing_user = (
        db.query(User).filter(User.email == create_user_request.email).first()
    )
    if existing_user:
        return JSONResponse(status_code=400, content="Email is already registered")
    # Logins resolve accounts by username with .first(), so a duplicate
    # username would make that lookup ambiguous; reject it up front.
    if db.query(User).filter(User.username == create_user_request.username).first():
        return JSONResponse(status_code=400, content="Username is already registered")
    try:
        password_hash = bcrypt_context.hash(create_user_request.password)
        create_user_model = User(
            name=create_user_request.name,
            username=create_user_request.username,
            email=create_user_request.email,
            role_id=create_user_request.role_id,
            password_hash=password_hash,
        )
        db.add(create_user_model)
        db.commit()
        db.refresh(create_user_model)
    except Exception as e:
        # Roll back so the session is usable again after a failed flush/commit.
        db.rollback()
        print(e)
        return JSONResponse(
            status_code=500, content="An error occuring when register user"
        )
    return {"message": "User created successfully", "user_id": create_user_model.id}
# @router.post("/forgot_password") | |
# async def forget_password(): | |
# pass | |
# @router.post("/change_password") | |
# async def change_password( | |
# user: user_dependency, db: db_dependency, user_verification: UserVerification | |
# ): | |
# if user is None: | |
# return JSONResponse(status_code=401, content="Authentication Failed") | |
# user_model = db.query(User).filter(User.id == user.get("id")).first() | |
# if not bcrypt_context.verify(
# user_verification.password, user_model.password_hash
# ):
# return JSONResponse(status_code=401, content="Error on password change")
# user_model.password_hash = bcrypt_context.hash(user_verification.new_password)
# db.add(user_model) | |
# db.commit() | |
# db.refresh(user_model) | |
# return {"message": "User's password successfully changed", "user_id": user_model.id} | |