sachin commited on
Commit
4335561
·
1 Parent(s): 844386f
Files changed (1) hide show
  1. src/server/utils/auth.py +32 -34
src/server/utils/auth.py CHANGED
@@ -1,94 +1,92 @@
1
  import jwt
2
  from datetime import datetime, timedelta
3
- from pydantic import BaseModel, Field, field_validator
4
-
5
  from fastapi.security import OAuth2PasswordBearer
6
  from fastapi import HTTPException, status, Depends
7
- from pydantic import BaseModel
8
  from pydantic_settings import BaseSettings
9
- from config.logging_config import logger # Assuming this is available
10
- from typing import Optional
11
 
12
- # Centralized Settings class (can be moved to a separate config file later)
13
  class Settings(BaseSettings):
14
- api_key_secret: str = Field(..., env="API_KEY_SECRET") # Secret key for signing JWTs
15
- token_expiration_minutes: int = Field(30, env="TOKEN_EXPIRATION_MINUTES") # Default to 30 minutes
16
  llm_model_name: str = "google/gemma-3-4b-it"
17
  max_tokens: int = 512
18
  host: str = "0.0.0.0"
19
  port: int = 7860
20
  chat_rate_limit: str = "100/minute"
21
  speech_rate_limit: str = "5/minute"
 
 
 
 
22
 
23
  class Config:
24
  env_file = ".env"
25
  env_file_encoding = "utf-8"
26
 
27
  settings = Settings()
28
- logger.info(f"Loaded API_KEY_SECRET at startup: {settings.api_key_secret}") # Add this line
29
 
30
- # OAuth2 scheme with Bearer token
31
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v1/token")
32
 
33
- # Model for token payload
34
  class TokenPayload(BaseModel):
35
- sub: str # Subject (user identifier)
36
- exp: int # Expiration timestamp
37
 
38
- # Model for token response
39
  class TokenResponse(BaseModel):
40
  access_token: str
41
  token_type: str
42
 
43
  async def create_access_token(user_id: str) -> str:
44
- """
45
- Create a JWT access token for a given user.
46
- """
47
  expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
48
  payload = {"sub": user_id, "exp": expire.timestamp()}
49
- logger.info(f"Signing token with API_KEY_SECRET: {settings.api_key_secret}") # Add this line
50
  token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
51
  logger.info(f"Generated access token for user: {user_id}")
52
  return token
53
 
54
  async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
55
- """
56
- Validate the Bearer token and return the user ID.
57
- """
58
  credentials_exception = HTTPException(
59
  status_code=status.HTTP_401_UNAUTHORIZED,
60
  detail="Invalid authentication credentials",
61
  headers={"WWW-Authenticate": "Bearer"},
62
  )
63
-
64
  try:
65
- logger.info(f"Verifying token with API_KEY_SECRET: {settings.api_key_secret}") # Add this line
66
- payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"])
 
 
 
67
  token_data = TokenPayload(**payload)
68
  user_id = token_data.sub
 
69
  if user_id is None:
 
70
  raise credentials_exception
71
- if datetime.utcnow().timestamp() > token_data.exp:
 
 
 
 
 
72
  raise HTTPException(
73
  status_code=status.HTTP_401_UNAUTHORIZED,
74
  detail="Token has expired",
75
  headers={"WWW-Authenticate": "Bearer"},
76
  )
 
77
  logger.info(f"Validated token for user: {user_id}")
78
  return user_id
79
- except jwt.InvalidTokenError:
80
- logger.warning(f"Invalid token attempt: {token[:10]}...")
 
 
 
81
  raise credentials_exception
82
  except Exception as e:
83
- logger.error(f"Token validation error: {str(e)}")
84
  raise credentials_exception
85
 
86
- # For demonstration purposes, a simple login function
87
- # In production, replace with proper user authentication (e.g., database lookup)
88
  async def login(user_id: str) -> TokenResponse:
89
- """
90
- Generate a token for a user. In production, validate credentials here.
91
- """
92
- # Placeholder: Assume user_id is valid; in reality, check against a database
93
  token = await create_access_token(user_id=user_id)
94
  return TokenResponse(access_token=token, token_type="bearer")
 
1
  import jwt
2
  from datetime import datetime, timedelta
 
 
3
  from fastapi.security import OAuth2PasswordBearer
4
  from fastapi import HTTPException, status, Depends
5
+ from pydantic import BaseModel, Field
6
  from pydantic_settings import BaseSettings
7
+ from config.logging_config import logger
 
8
 
 
9
  class Settings(BaseSettings):
10
+ api_key_secret: str = Field(..., env="API_KEY_SECRET")
11
+ token_expiration_minutes: int = Field(30, env="TOKEN_EXPIRATION_MINUTES")
12
  llm_model_name: str = "google/gemma-3-4b-it"
13
  max_tokens: int = 512
14
  host: str = "0.0.0.0"
15
  port: int = 7860
16
  chat_rate_limit: str = "100/minute"
17
  speech_rate_limit: str = "5/minute"
18
+ external_tts_url: str = Field(..., env="EXTERNAL_TTS_URL")
19
+ external_asr_url: str = Field(..., env="EXTERNAL_ASR_URL")
20
+ external_text_gen_url: str = Field(..., env="EXTERNAL_TEXT_GEN_URL")
21
+ external_audio_proc_url: str = Field(..., env="EXTERNAL_AUDIO_PROC_URL")
22
 
23
  class Config:
24
  env_file = ".env"
25
  env_file_encoding = "utf-8"
26
 
27
  settings = Settings()
28
+ logger.info(f"Loaded API_KEY_SECRET at startup: {settings.api_key_secret}")
29
 
 
30
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v1/token")
31
 
 
32
  class TokenPayload(BaseModel):
33
+ sub: str
34
+ exp: float
35
 
 
36
  class TokenResponse(BaseModel):
37
  access_token: str
38
  token_type: str
39
 
40
  async def create_access_token(user_id: str) -> str:
 
 
 
41
  expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
42
  payload = {"sub": user_id, "exp": expire.timestamp()}
43
+ logger.info(f"Signing token with API_KEY_SECRET: {settings.api_key_secret}")
44
  token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
45
  logger.info(f"Generated access token for user: {user_id}")
46
  return token
47
 
48
  async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
 
 
 
49
  credentials_exception = HTTPException(
50
  status_code=status.HTTP_401_UNAUTHORIZED,
51
  detail="Invalid authentication credentials",
52
  headers={"WWW-Authenticate": "Bearer"},
53
  )
 
54
  try:
55
+ logger.info(f"Received token: {token}")
56
+ logger.info(f"Verifying token with API_KEY_SECRET: {settings.api_key_secret}")
57
+ # Decode with expiration verification disabled to avoid PyJWT bug
58
+ payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"], options={"verify_exp": False})
59
+ logger.info(f"Decoded payload: {payload}")
60
  token_data = TokenPayload(**payload)
61
  user_id = token_data.sub
62
+
63
  if user_id is None:
64
+ logger.warning("Token has no 'sub' claim")
65
  raise credentials_exception
66
+
67
+ # Manual expiration check
68
+ current_time = datetime.utcnow().timestamp()
69
+ logger.info(f"Current time: {current_time}, Token exp: {token_data.exp}")
70
+ if current_time > token_data.exp:
71
+ logger.warning(f"Token expired: current_time={current_time}, exp={token_data.exp}")
72
  raise HTTPException(
73
  status_code=status.HTTP_401_UNAUTHORIZED,
74
  detail="Token has expired",
75
  headers={"WWW-Authenticate": "Bearer"},
76
  )
77
+
78
  logger.info(f"Validated token for user: {user_id}")
79
  return user_id
80
+ except jwt.InvalidSignatureError as e:
81
+ logger.error(f"Invalid signature error: {str(e)}")
82
+ raise credentials_exception
83
+ except jwt.InvalidTokenError as e:
84
+ logger.error(f"Other token error: {str(e)}")
85
  raise credentials_exception
86
  except Exception as e:
87
+ logger.error(f"Unexpected token validation error: {str(e)}")
88
  raise credentials_exception
89
 
 
 
90
  async def login(user_id: str) -> TokenResponse:
 
 
 
 
91
  token = await create_access_token(user_id=user_id)
92
  return TokenResponse(access_token=token, token_type="bearer")