sachin
add ft
3fb3087
raw
history blame
8.04 kB
import jwt
from datetime import datetime, timedelta
from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from config.logging_config import logger
from sqlalchemy import create_engine, Column, String, Boolean
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from passlib.context import CryptContext
# SQLite database setup
# "sqlite:///users.db" is a relative path, so the database file is resolved
# against the working directory of the running process.
DATABASE_URL = "sqlite:///users.db"
# check_same_thread=False is forwarded to the sqlite3 driver so that a
# connection may be used from a thread other than the one that created it.
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
# Session factory bound to the engine; sessions neither autocommit nor autoflush.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"

    # Login name; also the table's primary key.
    username = Column(String, primary_key=True, index=True)
    # Stores hashed passwords
    password = Column(String)
    # New admin flag; rows are non-admin unless stated otherwise.
    is_admin = Column(Boolean, default=False)
# Runs at import time: issues CREATE TABLE for the models registered on Base.
# Must stay below the model definitions so their tables are known here.
Base.metadata.create_all(bind=engine)
# Password hashing: bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class Settings(BaseSettings):
    """Application configuration loaded from the environment and ``.env``.

    pydantic-settings matches every field to the environment variable with the
    same name, case-insensitively (``api_key_secret`` <- ``API_KEY_SECRET``).
    Fields declared without a default are required: ``Settings()`` raises a
    validation error when any of them is missing.

    NOTE: do not declare fields as ``Field(..., env="NAME")``. pydantic v2
    ignores the ``env`` keyword and emits a deprecation warning for it; use
    ``validation_alias`` if a variable name ever has to differ from the field
    name.
    """

    # Secret used to sign and verify the JWT access tokens (required).
    api_key_secret: str
    # Lifetime of an access token, in minutes.
    token_expiration_minutes: int = 30
    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"
    # External service endpoints (all required).
    external_tts_url: str
    external_asr_url: str
    external_text_gen_url: str
    external_audio_proc_url: str
    # Credentials of the admin account that is seeded at startup.
    # SECURITY: both values fall back to the well-known defaults below when
    # DEFAULT_ADMIN_USERNAME / DEFAULT_ADMIN_PASSWORD are not set, so every
    # real deployment must override them.
    # TODO(review): make the password a required setting once all
    # deployments provide it.
    default_admin_username: str = "admin"
    default_admin_password: str = "admin54321"

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
# Built once at import time; raises a validation error if a required setting
# is missing from the environment / .env file.
settings = Settings()
# The secret value itself must never be written to the logs.
logger.info("Loaded API_KEY_SECRET at startup")
# Seed initial data (optional)
def seed_initial_data() -> None:
    """Insert the built-in test user and the configured admin user if absent.

    Each account is only created when no row with that username exists yet,
    so calling this repeatedly does not overwrite existing passwords.

    Raises:
        Any database error raised while querying or committing; the session
        is closed before the error propagates.
    """
    db = SessionLocal()
    try:
        # Seed test user (non-admin)
        # SECURITY NOTE(review): this account has a fixed, publicly known
        # password and is created in every environment - confirm it is wanted
        # outside of development.
        if not db.query(User).filter_by(username="testuser").first():
            test_hash = pwd_context.hash("password123")
            db.add(User(username="testuser", password=test_hash, is_admin=False))
            db.commit()

        # Seed admin user using the configured credentials
        admin_username = settings.default_admin_username
        if not db.query(User).filter_by(username=admin_username).first():
            admin_hash = pwd_context.hash(settings.default_admin_password)
            db.add(User(username=admin_username, password=admin_hash, is_admin=True))
            db.commit()
    finally:
        # Always release the connection, even when a query or commit fails.
        db.close()
    logger.info(f"Seeded initial data: admin user '{admin_username}'")
# Runs at import time, after the tables have been created above.
seed_initial_data()
# Use HTTPBearer
# Shared security scheme instance passed to Depends() by the auth dependencies.
bearer_scheme = HTTPBearer()
# Claims that a decoded access token must carry.
class TokenPayload(BaseModel):
    # Subject: the user id the token was issued for.
    sub: str
    # Expiry as a timestamp in seconds (see create_access_token).
    exp: float
# Response body returned by the endpoints that issue a token.
class TokenResponse(BaseModel):
    # The encoded JWT.
    access_token: str
    # Scheme of the token; this module always sets it to "bearer".
    token_type: str
# Request body for login.
class LoginRequest(BaseModel):
    username: str
    # Plain-text password; it is verified against the stored hash.
    password: str
# Request body for registering a new user.
class RegisterRequest(BaseModel):
    username: str
    # Plain-text password; it is hashed before being stored.
    password: str
async def create_access_token(user_id: str) -> str:
    """Create an HS256-signed JWT for ``user_id``.

    Args:
        user_id: Value stored in the token's ``sub`` claim.

    Returns:
        The encoded token. Its ``exp`` claim lies
        ``settings.token_expiration_minutes`` minutes in the future.
    """
    # NOTE(review): utcnow() returns a naive datetime, so .timestamp()
    # interprets it in the host's local timezone. get_current_user computes
    # "now" the same way, which keeps issuing and checking consistent with
    # each other; if this is switched to a real UTC epoch, change both
    # functions together.
    expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
    payload = {"sub": user_id, "exp": expire.timestamp()}
    # The signing secret is deliberately not logged.
    token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
    logger.info(f"Generated access token for user: {user_id}")
    return token
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
    """Validate a bearer token and return the user id it belongs to.

    The token must be signed with ``settings.api_key_secret`` (HS256), carry
    the ``sub`` and ``exp`` claims, name an existing user and not be expired.

    Args:
        credentials: Bearer credentials extracted by ``bearer_scheme``.

    Returns:
        The ``sub`` claim of the token, i.e. the username.

    Raises:
        HTTPException: 401 "Token has expired" for an expired token; 401
            "Invalid authentication credentials" for every other failure,
            including unexpected errors while validating.
    """
    token = credentials.credentials
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # Neither the raw token nor the signing secret may be logged: both
        # are credentials.
        # Expiry is checked manually further down, so PyJWT's own check is off.
        payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"], options={"verify_exp": False})
        logger.info(f"Decoded payload: {payload}")
        token_data = TokenPayload(**payload)
        # TokenPayload guarantees that sub is a string, so no None check is needed.
        user_id = token_data.sub

        db = SessionLocal()
        try:
            user = db.query(User).filter_by(username=user_id).first()
        finally:
            # Release the connection even if the lookup fails.
            db.close()
        if not user:
            logger.warning(f"Invalid or unknown user: {user_id}")
            raise credentials_exception

        # Computed exactly like the exp claim in create_access_token so that
        # both sides of the comparison use the same clock convention.
        current_time = datetime.utcnow().timestamp()
        logger.info(f"Current time: {current_time}, Token exp: {token_data.exp}")
        if current_time > token_data.exp:
            logger.warning(f"Token expired: current_time={current_time}, exp={token_data.exp}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            )

        logger.info(f"Validated token for user: {user_id}")
        return user_id
    except HTTPException:
        # Deliberate 401 responses raised above must reach the client as-is;
        # without this clause the generic handler below would replace the
        # "Token has expired" detail and log it as an unexpected error.
        raise
    except jwt.InvalidSignatureError as e:
        logger.error(f"Invalid signature error: {str(e)}")
        raise credentials_exception
    except jwt.InvalidTokenError as e:
        logger.error(f"Other token error: {str(e)}")
        raise credentials_exception
    except Exception as e:
        logger.error(f"Unexpected token validation error: {str(e)}")
        raise credentials_exception
async def get_current_user_with_admin(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
    """Validate a bearer token and require that its user is an admin.

    Args:
        credentials: Bearer credentials extracted by ``bearer_scheme``.

    Returns:
        The authenticated admin's user id.

    Raises:
        HTTPException: 401 from ``get_current_user`` when the token is not
            valid; 403 when the user is missing or is not an admin.
    """
    user_id = await get_current_user(credentials)
    db = SessionLocal()
    try:
        user = db.query(User).filter_by(username=user_id).first()
    finally:
        # Release the connection even if the lookup fails.
        db.close()
    if not user or not user.is_admin:
        logger.warning(f"User {user_id} is not authorized as admin")
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")
    return user_id
async def login(login_request: LoginRequest) -> TokenResponse:
    """Check a username/password pair and issue an access token.

    Args:
        login_request: Username and plain-text password to verify.

    Returns:
        A bearer ``TokenResponse`` for the authenticated user.

    Raises:
        HTTPException: 401 when the user does not exist or the password does
            not match the stored hash. The same message is used for both
            cases.
    """
    db = SessionLocal()
    try:
        user = db.query(User).filter_by(username=login_request.username).first()
    finally:
        # Release the connection even if the lookup fails.
        db.close()
    if not user or not pwd_context.verify(login_request.password, user.password):
        logger.warning(f"Login failed for user: {login_request.username}")
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
    token = await create_access_token(user_id=user.username)
    return TokenResponse(access_token=token, token_type="bearer")
async def register(register_request: RegisterRequest, current_user: str = Depends(get_current_user_with_admin)) -> TokenResponse:
    """Create a new non-admin user and return an access token for it.

    Args:
        register_request: Username and plain-text password of the new user.
        current_user: User id of the calling admin, resolved by
            ``get_current_user_with_admin``.

    Returns:
        A bearer ``TokenResponse`` for the newly created user.

    Raises:
        HTTPException: 400 when the username is already taken; 401/403 from
            the admin dependency.
    """
    db = SessionLocal()
    try:
        existing_user = db.query(User).filter_by(username=register_request.username).first()
        if existing_user:
            logger.warning(f"Registration failed: Username {register_request.username} already exists")
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists")
        hashed_password = pwd_context.hash(register_request.password)
        new_user = User(username=register_request.username, password=hashed_password, is_admin=False)
        db.add(new_user)
        db.commit()
    finally:
        # One exit point for the session: runs on success, on the 400 above
        # and when hashing, add or commit fails.
        db.close()
    token = await create_access_token(user_id=register_request.username)
    logger.info(f"Registered and generated token for user: {register_request.username} by admin {current_user}")
    return TokenResponse(access_token=token, token_type="bearer")
async def refresh_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> TokenResponse:
    """Exchange a still-valid bearer token for a freshly issued one.

    The presented token is validated by ``get_current_user``; whatever it
    raises propagates to the caller unchanged.
    """
    subject = await get_current_user(credentials)
    return TokenResponse(
        access_token=await create_access_token(user_id=subject),
        token_type="bearer",
    )