from datetime import datetime, timedelta, timezone

import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict

from config.logging_config import logger


class Settings(BaseSettings):
    """Application settings loaded from the environment and an optional ``.env`` file.

    pydantic-settings matches each field to the environment variable of the same
    name, case-insensitively (e.g. ``api_key_secret`` <- ``API_KEY_SECRET``), so
    no explicit alias is needed. The former ``Field(env=...)`` keyword is a
    pydantic v1 feature that pydantic v2 ignores (with a deprecation warning).
    """

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    # HMAC secret used to sign and verify HS256 access tokens. Never log it.
    api_key_secret: str = Field(...)
    # Lifetime of issued access tokens, in minutes.
    token_expiration_minutes: int = Field(30)
    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860
    # Rate-limit strings consumed elsewhere; presumably "<count>/<period>" — confirm with the limiter.
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"
    external_tts_url: str = Field(...)
    external_asr_url: str = Field(...)
    external_text_gen_url: str = Field(...)
    external_audio_proc_url: str = Field(...)


settings = Settings()
# Do not log the secret itself; confirming the settings loaded is enough.
logger.info(
    f"Loaded settings at startup (token expiration: "
    f"{settings.token_expiration_minutes} minutes)"
)

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v1/token")


class TokenPayload(BaseModel):
    """Claims this module puts into, and expects back from, an access token."""

    sub: str  # user identifier
    exp: float  # expiry as a POSIX timestamp (seconds since the epoch, UTC)


class TokenResponse(BaseModel):
    """Response body returned by :func:`login`."""

    access_token: str
    token_type: str


def _bearer_error(detail: str) -> HTTPException:
    """Build a 401 response carrying the ``WWW-Authenticate: Bearer`` header."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def create_access_token(user_id: str) -> str:
    """Create a signed HS256 JWT for *user_id*.

    The token carries ``sub`` (the user id) and ``exp`` (now +
    ``settings.token_expiration_minutes``).

    Args:
        user_id: Identifier stored in the ``sub`` claim.

    Returns:
        The encoded JWT as returned by ``jwt.encode``.
    """
    # Use an aware UTC datetime: calling .timestamp() on the naive value of
    # datetime.utcnow() interprets it as *local* time and skews `exp` by the
    # host's UTC offset. PyJWT converts an aware datetime to a correct epoch.
    expire = datetime.now(timezone.utc) + timedelta(
        minutes=settings.token_expiration_minutes
    )
    payload = {"sub": user_id, "exp": expire}
    token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
    logger.info(f"Generated access token for user: {user_id}")
    return token


async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
    """FastAPI dependency: validate the bearer *token* and return its user id.

    Args:
        token: Raw bearer token extracted by ``oauth2_scheme``.

    Returns:
        The ``sub`` claim of a valid, unexpired token.

    Raises:
        HTTPException: 401 with detail ``"Token has expired"`` when ``exp`` is
            in the past, or ``"Invalid authentication credentials"`` for any
            other failure (bad signature, malformed token, missing/empty claims).
    """
    credentials_exception = _bearer_error("Invalid authentication credentials")

    # Keep the try body to the operations that can fail on bad input, so the
    # HTTPExceptions raised below are not swallowed by the broad handler.
    try:
        # PyJWT verifies `exp` itself; `require` rejects tokens lacking claims.
        payload = jwt.decode(
            token,
            settings.api_key_secret,
            algorithms=["HS256"],
            options={"require": ["exp", "sub"]},
        )
        token_data = TokenPayload(**payload)
    except jwt.ExpiredSignatureError as e:
        # Must precede InvalidTokenError: ExpiredSignatureError subclasses it.
        logger.warning(f"Token expired: {e}")
        raise _bearer_error("Token has expired") from e
    except jwt.InvalidSignatureError as e:
        logger.error(f"Invalid signature error: {e}")
        raise credentials_exception from e
    except jwt.InvalidTokenError as e:
        logger.error(f"Other token error: {e}")
        raise credentials_exception from e
    except Exception as e:
        # e.g. pydantic validation failure on claim types.
        logger.error(f"Unexpected token validation error: {e}")
        raise credentials_exception from e

    user_id = token_data.sub
    if not user_id:
        # An empty subject does not identify a user.
        logger.warning("Token has no 'sub' claim")
        raise credentials_exception

    logger.info(f"Validated token for user: {user_id}")
    return user_id


async def login(user_id: str) -> TokenResponse:
    """Issue a bearer access token for *user_id*."""
    token = await create_access_token(user_id=user_id)
    return TokenResponse(access_token=token, token_type="bearer")