import base64
import os
from datetime import datetime, timedelta, timezone

import jwt
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext
from pydantic import BaseModel, Field, ValidationError
from pydantic_settings import BaseSettings
from sqlalchemy import Boolean, Column, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from config.logging_config import logger

# SQLite database setup with Hugging Face persistent storage
DATABASE_PATH = "/data/users.db"
DATABASE_URL = f"sqlite:///{DATABASE_PATH}"

# check_same_thread=False lets the SQLite connection be used from threads other
# than the one that created it.
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
Base = declarative_base()
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


class User(Base):
    """Admin-managed account stored in the ``users`` table."""

    __tablename__ = "users"
    username = Column(String, primary_key=True, index=True)
    password = Column(String)  # bcrypt hash, never the plaintext
    is_admin = Column(Boolean, default=False)
    session_key = Column(String, nullable=True)  # base64-encoded AES session key


class AppUser(Base):
    """Self-registered app account stored in the ``app_users`` table."""

    __tablename__ = "app_users"
    username = Column(String, primary_key=True, index=True)
    password = Column(String)  # bcrypt hash, never the plaintext
    session_key = Column(String, nullable=True)  # base64-encoded AES session key


# Ensure the /data directory exists before the tables are created.
os.makedirs(os.path.dirname(DATABASE_PATH), exist_ok=True)
Base.metadata.create_all(bind=engine)

# Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Built-in fallback admin password; seeding warns when it is actually used.
_DEFAULT_ADMIN_PASSWORD = "admin54321"


class Settings(BaseSettings):
    """Service configuration, loaded from the environment and an optional ``.env`` file."""

    api_key_secret: str = Field(..., env="API_KEY_SECRET")
    token_expiration_minutes: int = Field(1440, env="TOKEN_EXPIRATION_MINUTES")
    refresh_token_expiration_days: int = Field(7, env="REFRESH_TOKEN_EXPIRATION_DAYS")
    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"
    external_tts_url: str = Field(..., env="EXTERNAL_TTS_URL")
    external_asr_url: str = Field(..., env="EXTERNAL_ASR_URL")
    external_text_gen_url: str = Field(..., env="EXTERNAL_TEXT_GEN_URL")
    external_audio_proc_url: str = Field(..., env="EXTERNAL_AUDIO_PROC_URL")
    default_admin_username: str = Field("admin", env="DEFAULT_ADMIN_USERNAME")
    default_admin_password: str = Field(_DEFAULT_ADMIN_PASSWORD, env="DEFAULT_ADMIN_PASSWORD")
    database_path: str = DATABASE_PATH

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"


settings = Settings()


def _find_any_user(db, username: str):
    """Return the matching ``User`` row, else the matching ``AppUser`` row, else None."""
    return (
        db.query(User).filter_by(username=username).first()
        or db.query(AppUser).filter_by(username=username).first()
    )


def seed_initial_data():
    """Create the test user and the default admin in ``users`` if they are missing.

    Errors are logged and the transaction is rolled back; startup is not aborted.
    """
    db = SessionLocal()
    try:
        # NOTE(review): this seeds a fixed, publicly visible credential on every
        # deployment -- confirm it is intended outside development.
        test_username = "testuser@example.com"
        if not db.query(User).filter_by(username=test_username).first():
            test_device_token = "550e8400-e29b-41d4-a716-446655440000"
            hashed_password = pwd_context.hash(test_device_token)
            session_key = base64.b64encode(get_random_bytes(16)).decode('utf-8')
            db.add(User(username=test_username, password=hashed_password, is_admin=False, session_key=session_key))
            db.commit()

        admin_username = settings.default_admin_username
        admin_password = settings.default_admin_password
        if not db.query(User).filter_by(username=admin_username).first():
            if admin_password == _DEFAULT_ADMIN_PASSWORD:
                logger.warning(
                    "Seeding admin account with the built-in default password; "
                    "set DEFAULT_ADMIN_PASSWORD to override"
                )
            hashed_password = pwd_context.hash(admin_password)
            session_key = base64.b64encode(get_random_bytes(16)).decode('utf-8')
            db.add(User(username=admin_username, password=hashed_password, is_admin=True, session_key=session_key))
            db.commit()
        logger.info(f"Seeded initial data: test user '{test_username}', admin user '{admin_username}'")
    except Exception as e:
        logger.error(f"Error seeding initial data: {str(e)}")
        db.rollback()
    finally:
        db.close()


seed_initial_data()

bearer_scheme = HTTPBearer()


class TokenPayload(BaseModel):
    """Claims carried by the access and refresh JWTs issued here."""

    sub: str  # username
    exp: float  # expiry as a POSIX timestamp (UTC)
    type: str  # "access" or "refresh"


class TokenResponse(BaseModel):
    """Token pair returned by the login, registration and refresh flows."""

    access_token: str
    refresh_token: str
    token_type: str


class LoginRequest(BaseModel):
    """Login credentials; the login flow expects both fields AES-GCM encrypted."""

    username: str
    password: str


class RegisterRequest(BaseModel):
    """Registration credentials (encrypted for ``app_register``, plaintext for ``register``)."""

    username: str
    password: str


def decrypt_data(encrypted_data: str, key: bytes) -> str:
    """Decrypt a base64 blob laid out as ``nonce(12) || ciphertext || tag(16)`` with AES-GCM.

    Raises:
        HTTPException: 400 "Invalid encrypted data" if decoding, authentication or
            UTF-8 decoding fails.
    """
    try:
        data = base64.b64decode(encrypted_data)
        nonce, ciphertext = data[:12], data[12:]
        cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
        plaintext = cipher.decrypt_and_verify(ciphertext[:-16], ciphertext[-16:])
        return plaintext.decode('utf-8')
    except Exception as e:
        logger.error(f"Decryption failed: {str(e)}")
        raise HTTPException(status_code=400, detail="Invalid encrypted data") from e


def _decrypt_credentials(encrypted_username: str, encrypted_password: str, session_key_b64: str):
    """Decode the base64 session key and return the decrypted ``(username, password)`` pair.

    Raises:
        HTTPException: 400 "Invalid encrypted data" for a malformed key or either
            field failing to decrypt.
    """
    try:
        session_key = base64.b64decode(session_key_b64)
    except ValueError as e:
        # binascii.Error (malformed base64) subclasses ValueError; previously this escaped as a 500.
        logger.error(f"Invalid session key encoding: {str(e)}")
        raise HTTPException(status_code=400, detail="Invalid encrypted data") from e
    return decrypt_data(encrypted_username, session_key), decrypt_data(encrypted_password, session_key)


async def create_access_token(user_id: str) -> dict:
    """Issue an HS256 access token and refresh token for *user_id*.

    Returns:
        ``{"access_token": ..., "refresh_token": ...}``.
    """
    # Aware UTC datetimes: .timestamp() on a naive utcnow() value is interpreted as
    # local time, which shifts `exp` by the host's UTC offset.
    now = datetime.now(timezone.utc)
    expire = now + timedelta(minutes=settings.token_expiration_minutes)
    payload = {"sub": user_id, "exp": expire.timestamp(), "type": "access"}
    token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")

    refresh_expire = now + timedelta(days=settings.refresh_token_expiration_days)
    refresh_payload = {"sub": user_id, "exp": refresh_expire.timestamp(), "type": "refresh"}
    refresh_token = jwt.encode(refresh_payload, settings.api_key_secret, algorithm="HS256")

    logger.info(f"Generated tokens for user: {user_id}")
    return {"access_token": token, "refresh_token": refresh_token}


async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
    """Validate a bearer access token and return its username (``users`` or ``app_users``).

    Raises:
        HTTPException: 401 for bad/forged tokens, non-access tokens, unknown users,
            or expired tokens (the latter with detail "Token has expired").
    """
    token = credentials.credentials
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # Expiry is checked manually below so an expired token gets a distinct error detail.
        payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"], options={"verify_exp": False})
        token_data = TokenPayload(**payload)
        user_id = token_data.sub

        # Refresh tokens share the signing key and live longer; never accept them here.
        if token_data.type != "access":
            logger.warning(f"Rejected non-access token of type '{token_data.type}' for user: {user_id}")
            raise credentials_exception

        db = SessionLocal()
        try:
            known_user = _find_any_user(db, user_id) is not None
        finally:
            db.close()
        if not known_user:
            logger.warning(f"Invalid or unknown user: {user_id}")
            raise credentials_exception

        current_time = datetime.now(timezone.utc).timestamp()
        if current_time > token_data.exp:
            logger.warning(f"Token expired: current_time={current_time}, exp={token_data.exp}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            )
        logger.info(f"Validated token for user: {user_id}")
        return user_id
    except HTTPException:
        # Deliberate auth errors raised above must reach the client unchanged.
        raise
    except jwt.InvalidSignatureError as e:
        logger.error(f"Invalid signature error: {str(e)}")
        raise credentials_exception
    except jwt.InvalidTokenError as e:
        logger.error(f"Other token error: {str(e)}")
        raise credentials_exception
    except Exception as e:
        logger.error(f"Unexpected token validation error: {str(e)}")
        raise credentials_exception


async def get_current_user_with_admin(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
    """Validate a bearer access token that belongs to an admin in ``users``; return the username.

    Raises:
        HTTPException: 401 for invalid/expired/non-access tokens or unknown users;
            403 if the user exists but is not an admin.
    """
    token = credentials.credentials
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"])
        token_data = TokenPayload(**payload)
        user_id = token_data.sub

        if token_data.type != "access":
            logger.warning(f"Rejected non-access token of type '{token_data.type}' for user: {user_id}")
            raise credentials_exception

        db = SessionLocal()
        try:
            user = db.query(User).filter_by(username=user_id).first()
            is_admin = bool(user is not None and user.is_admin)
        finally:
            db.close()

        if user is None:
            logger.warning(f"User not found in users table: {user_id}")
            raise credentials_exception
        if not is_admin:
            logger.warning(f"User {user_id} is not authorized as admin")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Admin access required; only admin accounts can perform this action"
            )
        logger.info(f"Validated admin user: {user_id}")
        return user_id
    except HTTPException:
        # Keep the 403 (and explicit 401s) intact instead of collapsing them below.
        raise
    except jwt.InvalidSignatureError as e:
        logger.error(f"Invalid signature error: {str(e)}")
        raise credentials_exception
    except jwt.InvalidTokenError as e:
        logger.error(f"Other token error: {str(e)}")
        raise credentials_exception
    except Exception as e:
        logger.error(f"Unexpected admin validation error: {str(e)}")
        raise credentials_exception


async def login(login_request: LoginRequest, session_key_b64: str) -> TokenResponse:
    """Authenticate encrypted credentials against ``users`` then ``app_users`` and issue tokens.

    Also stores *session_key_b64* on the account when it differs from the saved one.

    Raises:
        HTTPException: 400 for undecryptable input; 401 for unknown user or wrong password.
    """
    username, password = _decrypt_credentials(login_request.username, login_request.password, session_key_b64)

    db = SessionLocal()
    try:
        target_user = _find_any_user(db, username)
        if target_user is None or not pwd_context.verify(password, target_user.password):
            logger.warning(f"Login failed for user: {username}")
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or device token")

        if target_user.session_key != session_key_b64:
            target_user.session_key = session_key_b64
            db.commit()
    finally:
        db.close()

    tokens = await create_access_token(user_id=username)
    return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")


async def register(register_request: RegisterRequest, current_user: str = Depends(get_current_user_with_admin)) -> TokenResponse:
    """Admin-only: create a non-admin ``User`` from plaintext credentials and issue its tokens.

    Raises:
        HTTPException: 400 if the username exists in either table; 500 on other failures.
    """
    db = SessionLocal()
    try:
        # Check both tables: a duplicate here would shadow an app user at login.
        if _find_any_user(db, register_request.username) is not None:
            logger.warning(f"Registration failed: Username {register_request.username} already exists")
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists")

        hashed_password = pwd_context.hash(register_request.password)
        db.add(User(username=register_request.username, password=hashed_password, is_admin=False))
        db.commit()
        logger.info(f"Admin {current_user} successfully registered new user: {register_request.username}")

        tokens = await create_access_token(user_id=register_request.username)
        return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")
    except HTTPException:
        # The 400 above must not be re-wrapped as a 500.
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Registration error by admin {current_user}: {str(e)}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Registration failed: {str(e)}") from e
    finally:
        db.close()


async def app_register(register_request: RegisterRequest, session_key_b64: str) -> TokenResponse:
    """Self-register an ``AppUser`` from encrypted credentials and issue its tokens.

    Raises:
        HTTPException: 400 for undecryptable input or an email already present in either table.
    """
    username, password = _decrypt_credentials(register_request.username, register_request.password, session_key_b64)

    db = SessionLocal()
    try:
        # Check both tables to prevent duplicate usernames
        if _find_any_user(db, username) is not None:
            logger.warning(f"App registration failed: Email {username} already exists")
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered")

        hashed_password = pwd_context.hash(password)
        db.add(AppUser(username=username, password=hashed_password, session_key=session_key_b64))
        db.commit()
    finally:
        db.close()

    tokens = await create_access_token(user_id=username)
    logger.info(f"App registered new user: {username}")
    return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")


async def refresh_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> TokenResponse:
    """Exchange a valid, unexpired refresh token for a fresh access/refresh pair.

    Raises:
        HTTPException: 401 if the token is invalid/expired, is not a refresh token,
            or names an unknown user.
    """
    token = credentials.credentials
    try:
        payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"])
        token_data = TokenPayload(**payload)
    except (jwt.InvalidTokenError, ValidationError):
        # ValidationError: signed payload missing required claims; previously a 500.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

    if token_data.type != "refresh":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type; refresh token required")

    user_id = token_data.sub
    db = SessionLocal()
    try:
        known_user = _find_any_user(db, user_id) is not None
    finally:
        db.close()
    if not known_user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")

    tokens = await create_access_token(user_id=user_id)
    return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")