File size: 3,776 Bytes
844386f
 
 
9781b82
4335561
9781b82
4335561
9781b82
 
4335561
 
844386f
 
 
 
 
 
4335561
 
 
 
844386f
9781b82
 
844386f
9781b82
 
4335561
844386f
 
 
 
4335561
 
844386f
 
 
 
 
 
 
 
4335561
844386f
 
 
9781b82
844386f
 
 
 
 
 
 
4335561
 
 
 
 
844386f
 
4335561
844386f
4335561
844386f
4335561
 
 
 
 
 
844386f
 
 
 
 
4335561
844386f
 
4335561
 
 
 
 
844386f
 
4335561
844386f
9781b82
844386f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from datetime import datetime, timedelta, timezone

import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, Field, ValidationError
from pydantic_settings import BaseSettings, SettingsConfigDict

from config.logging_config import logger

class Settings(BaseSettings):
    """Application configuration read from environment variables and ``.env``.

    pydantic-settings matches each field name to an environment variable
    case-insensitively by default, so ``api_key_secret`` is populated from
    ``API_KEY_SECRET``, ``external_tts_url`` from ``EXTERNAL_TTS_URL``, etc.
    The pydantic-v1 ``Field(..., env=...)`` spelling is therefore unnecessary
    (and deprecated in pydantic v2, where the ``env`` kwarg is ignored).
    Fields without a default are required; instantiation fails if they are
    missing from both the environment and ``.env``.
    """

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    # HMAC key used to sign and verify the HS256 access tokens (required).
    api_key_secret: str
    # Lifetime of newly issued access tokens, in minutes.
    token_expiration_minutes: int = 30

    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"

    # Upstream service endpoints (all required, no defaults).
    external_tts_url: str
    external_asr_url: str
    external_text_gen_url: str
    external_audio_proc_url: str

settings = Settings()
# Only confirm that the secret was loaded. Writing its value to the logs would
# let anyone with log access mint valid tokens.
logger.info("Loaded API_KEY_SECRET at startup (value redacted)")

# Extracts the bearer token from the Authorization header; tokenUrl is the
# endpoint advertised to clients / the OpenAPI docs for obtaining a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v1/token")

class TokenPayload(BaseModel):
    """Claims read back from a decoded access token."""

    # Subject: the user id the token was issued for.
    sub: str
    # Expiry as a POSIX timestamp in seconds.
    exp: float

class TokenResponse(BaseModel):
    """Response body for a successful login: the signed token and its scheme."""

    # Encoded JWT to send back in the Authorization header.
    access_token: str
    # Token scheme name; login() sets this to "bearer".
    token_type: str

async def create_access_token(user_id: str) -> str:
    """Create a signed HS256 access token for ``user_id``.

    The token carries two claims: ``sub`` (the user id) and ``exp`` (a POSIX
    timestamp ``settings.token_expiration_minutes`` minutes from now).

    ``exp`` is computed from a timezone-aware UTC datetime. The previous
    ``datetime.utcnow().timestamp()`` interpreted the naive UTC value as
    *local* time, which skewed ``exp`` by the host's UTC offset and made
    PyJWT's own expiry check disagree with it.

    Args:
        user_id: Identifier stored in the ``sub`` claim.

    Returns:
        The encoded JWT string.
    """
    expire = datetime.now(timezone.utc) + timedelta(minutes=settings.token_expiration_minutes)
    payload = {"sub": user_id, "exp": expire.timestamp()}
    # The signing secret is deliberately never logged.
    token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
    logger.info(f"Generated access token for user: {user_id}")
    return token


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 response carrying the Bearer challenge header."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
    """FastAPI dependency that validates a bearer token and returns its ``sub``.

    PyJWT verifies the signature (HS256 only) and the expiry, and requires
    both the ``exp`` and ``sub`` claims to be present. Neither the token nor
    the signing secret is written to the logs.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``.

    Returns:
        The user id from the token's ``sub`` claim.

    Raises:
        HTTPException: 401 with detail ``"Token has expired"`` when the token's
            ``exp`` has passed. 401 with ``"Invalid authentication
            credentials"`` for any other failure: bad signature, malformed
            token, missing or invalid claims, or an empty subject.
    """
    credentials_exception = _unauthorized("Invalid authentication credentials")
    try:
        # Expiry is checked by PyJWT now that exp is a true UTC epoch value,
        # so the old verify_exp=False workaround and manual check are gone.
        payload = jwt.decode(
            token,
            settings.api_key_secret,
            algorithms=["HS256"],
            options={"require": ["exp", "sub"]},
        )
        token_data = TokenPayload(**payload)
    except jwt.ExpiredSignatureError as e:
        # Must precede InvalidTokenError (its base class). Raised here, outside
        # the try body, so the generic handlers below cannot swallow it.
        logger.warning(f"Token expired: {str(e)}")
        raise _unauthorized("Token has expired") from e
    except jwt.InvalidSignatureError as e:
        logger.error(f"Invalid signature error: {str(e)}")
        raise credentials_exception from e
    except jwt.InvalidTokenError as e:
        logger.error(f"Other token error: {str(e)}")
        raise credentials_exception from e
    except ValidationError as e:
        # Claims decoded but do not fit TokenPayload (e.g. non-string sub).
        logger.error(f"Token payload failed validation: {str(e)}")
        raise credentials_exception from e
    except Exception as e:
        # Last resort: report any unexpected failure as 401, not 500.
        logger.error(f"Unexpected token validation error: {str(e)}")
        raise credentials_exception from e

    user_id = token_data.sub
    if not user_id:
        # sub is required and typed str, so only an empty subject gets here.
        logger.warning("Token has no 'sub' claim")
        raise credentials_exception

    logger.info(f"Validated token for user: {user_id}")
    return user_id

async def login(user_id: str) -> TokenResponse:
    """Issue a new bearer access token for ``user_id``.

    Args:
        user_id: Identifier to embed as the token's subject.

    Returns:
        A TokenResponse holding the signed token with token_type "bearer".
    """
    signed = await create_access_token(user_id=user_id)
    response = TokenResponse(token_type="bearer", access_token=signed)
    return response