sachin commited on
Commit
8a99693
·
1 Parent(s): c0849df

update-refresh token

Browse files
Files changed (2) hide show
  1. src/server/main.py +30 -39
  2. src/server/utils/auth.py +34 -20
src/server/main.py CHANGED
@@ -5,7 +5,7 @@ from typing import List, Optional
5
  from abc import ABC, abstractmethod
6
 
7
  import uvicorn
8
- from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form, Security
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
@@ -147,7 +147,7 @@ async def home():
147
  @app.post("/v1/token",
148
  response_model=TokenResponse,
149
  summary="User Login",
150
- description="Authenticate a user with username and password to obtain an access token. Copy the token and use it in the 'Authorize' button above.",
151
  tags=["Authentication"],
152
  responses={
153
  200: {"description": "Successful login", "model": TokenResponse},
@@ -159,11 +159,11 @@ async def token(login_request: LoginRequest):
159
  @app.post("/v1/refresh",
160
  response_model=TokenResponse,
161
  summary="Refresh Access Token",
162
- description="Generate a new access token using an existing valid token.",
163
  tags=["Authentication"],
164
  responses={
165
- 200: {"description": "New token issued", "model": TokenResponse},
166
- 401: {"description": "Invalid or expired token"}
167
  })
168
  async def refresh(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
169
  return await refresh_token(credentials)
@@ -171,7 +171,7 @@ async def refresh(credentials: HTTPAuthorizationCredentials = Depends(bearer_sch
171
  @app.post("/v1/register",
172
  response_model=TokenResponse,
173
  summary="Register New User",
174
- description="Create a new user account and return an access token. Requires admin access (use 'admin' user with password 'adminpass' initially).",
175
  tags=["Authentication"],
176
  responses={
177
  200: {"description": "User registered successfully", "model": TokenResponse},
@@ -180,9 +180,9 @@ async def refresh(credentials: HTTPAuthorizationCredentials = Depends(bearer_sch
180
  })
181
  async def register_user(
182
  register_request: RegisterRequest,
183
- current_user: str = Depends(get_current_user_with_admin) # Enforce admin-only access
184
  ):
185
- return await register(register_request, current_user) # Pass current_user explicitly
186
 
187
  @app.post("/v1/audio/speech",
188
  summary="Generate Speech from Text",
@@ -520,9 +520,7 @@ async def translate(
520
  except ValueError as e:
521
  logger.error(f"Invalid JSON response: {str(e)}")
522
  raise HTTPException(status_code=500, detail="Invalid response format from translation service")
523
-
524
 
525
- # Request/Response Models for Visual Query
526
  class VisualQueryRequest(BaseModel):
527
  query: str
528
  src_lang: str = "kan_Knda" # Default to Kannada
@@ -537,49 +535,48 @@ class VisualQueryRequest(BaseModel):
537
  class VisualQueryResponse(BaseModel):
538
  answer: str
539
 
540
-
541
- @app.post("/v1/visual_query", response_model=VisualQueryResponse)
 
 
 
 
 
 
 
 
 
 
542
  @limiter.limit(settings.chat_rate_limit)
543
  async def visual_query(
544
  request: Request,
545
- query: str = Form(...),
546
- file: UploadFile = File(...),
547
- src_lang: str = Query(default="kan_Knda"),
548
- tgt_lang: str = Query(default="kan_Knda"),
549
- #api_key: str = Depends(get_api_key) # Uncomment if authentication is needed
550
  ):
551
- """
552
- Endpoint to process visual queries with an image and text question.
553
- Calls an external API to analyze the image and provide a response.
554
- """
555
  if not query.strip():
556
  raise HTTPException(status_code=400, detail="Query cannot be empty")
557
 
558
  logger.info("Processing visual query request", extra={
559
  "endpoint": "/v1/visual_query",
560
  "query_length": len(query),
561
- "file_name": file.filename, # Changed from "filename" to "file_name"
562
  "client_ip": get_remote_address(request),
 
563
  "src_lang": src_lang,
564
  "tgt_lang": tgt_lang
565
  })
566
 
567
- # External API URL
568
  external_url = f"https://slabstech-dhwani-internal-api-server.hf.space/v1/visual_query/?src_lang={src_lang}&tgt_lang={tgt_lang}"
569
 
570
  try:
571
- # Read file content
572
  file_content = await file.read()
 
 
573
 
574
- # Prepare multipart/form-data
575
- files = {
576
- "file": (file.filename, file_content, file.content_type)
577
- }
578
- data = {
579
- "query": query
580
- }
581
-
582
- # Make the POST request to external API
583
  response = requests.post(
584
  external_url,
585
  files=files,
@@ -587,11 +584,8 @@ async def visual_query(
587
  headers={"accept": "application/json"},
588
  timeout=60
589
  )
590
-
591
- # Raise an exception for bad status codes
592
  response.raise_for_status()
593
 
594
- # Parse the response
595
  response_data = response.json()
596
  answer = response_data.get("answer", "")
597
 
@@ -611,9 +605,6 @@ async def visual_query(
611
  except ValueError as e:
612
  logger.error(f"Invalid JSON response: {str(e)}")
613
  raise HTTPException(status_code=500, detail="Invalid response format from visual query service")
614
- except Exception as e:
615
- logger.error(f"Unexpected error in visual query: {str(e)}", exc_info=True)
616
- raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
617
 
618
  if __name__ == "__main__":
619
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
 
5
  from abc import ABC, abstractmethod
6
 
7
  import uvicorn
8
+ from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
147
  @app.post("/v1/token",
148
  response_model=TokenResponse,
149
  summary="User Login",
150
+ description="Authenticate a user with username and password to obtain an access token and refresh token. Copy the access token and use it in the 'Authorize' button above.",
151
  tags=["Authentication"],
152
  responses={
153
  200: {"description": "Successful login", "model": TokenResponse},
 
159
  @app.post("/v1/refresh",
160
  response_model=TokenResponse,
161
  summary="Refresh Access Token",
162
+ description="Generate a new access token and refresh token using an existing valid refresh token.",
163
  tags=["Authentication"],
164
  responses={
165
+ 200: {"description": "New tokens issued", "model": TokenResponse},
166
+ 401: {"description": "Invalid or expired refresh token"}
167
  })
168
  async def refresh(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
169
  return await refresh_token(credentials)
 
171
  @app.post("/v1/register",
172
  response_model=TokenResponse,
173
  summary="Register New User",
174
+ description="Create a new user account and return an access token and refresh token. Requires admin access (use 'admin' user with password 'adminpass' initially).",
175
  tags=["Authentication"],
176
  responses={
177
  200: {"description": "User registered successfully", "model": TokenResponse},
 
180
  })
181
  async def register_user(
182
  register_request: RegisterRequest,
183
+ current_user: str = Depends(get_current_user_with_admin)
184
  ):
185
+ return await register(register_request, current_user)
186
 
187
  @app.post("/v1/audio/speech",
188
  summary="Generate Speech from Text",
 
520
  except ValueError as e:
521
  logger.error(f"Invalid JSON response: {str(e)}")
522
  raise HTTPException(status_code=500, detail="Invalid response format from translation service")
 
523
 
 
524
  class VisualQueryRequest(BaseModel):
525
  query: str
526
  src_lang: str = "kan_Knda" # Default to Kannada
 
535
  class VisualQueryResponse(BaseModel):
536
  answer: str
537
 
538
+ @app.post("/v1/visual_query",
539
+ response_model=VisualQueryResponse,
540
+ summary="Visual Query with Image",
541
+ description="Process a visual query with an image and text question. Rate limited to 100 requests per minute per user. Requires authentication.",
542
+ tags=["Chat"],
543
+ responses={
544
+ 200: {"description": "Query response", "model": VisualQueryResponse},
545
+ 400: {"description": "Invalid query"},
546
+ 401: {"description": "Unauthorized - Token required"},
547
+ 429: {"description": "Rate limit exceeded"},
548
+ 504: {"description": "Visual query service timeout"}
549
+ })
550
  @limiter.limit(settings.chat_rate_limit)
551
  async def visual_query(
552
  request: Request,
553
+ query: str = Form(..., description="Text query for the visual content"),
554
+ file: UploadFile = File(..., description="Image file to analyze"),
555
+ src_lang: str = Query(default="kan_Knda", description="Source language code"),
556
+ tgt_lang: str = Query(default="kan_Knda", description="Target language code"),
557
+ credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
558
  ):
559
+ user_id = await get_current_user(credentials)
 
 
 
560
  if not query.strip():
561
  raise HTTPException(status_code=400, detail="Query cannot be empty")
562
 
563
  logger.info("Processing visual query request", extra={
564
  "endpoint": "/v1/visual_query",
565
  "query_length": len(query),
566
+ "file_name": file.filename,
567
  "client_ip": get_remote_address(request),
568
+ "user_id": user_id,
569
  "src_lang": src_lang,
570
  "tgt_lang": tgt_lang
571
  })
572
 
 
573
  external_url = f"https://slabstech-dhwani-internal-api-server.hf.space/v1/visual_query/?src_lang={src_lang}&tgt_lang={tgt_lang}"
574
 
575
  try:
 
576
  file_content = await file.read()
577
+ files = {"file": (file.filename, file_content, file.content_type)}
578
+ data = {"query": query}
579
 
 
 
 
 
 
 
 
 
 
580
  response = requests.post(
581
  external_url,
582
  files=files,
 
584
  headers={"accept": "application/json"},
585
  timeout=60
586
  )
 
 
587
  response.raise_for_status()
588
 
 
589
  response_data = response.json()
590
  answer = response_data.get("answer", "")
591
 
 
605
  except ValueError as e:
606
  logger.error(f"Invalid JSON response: {str(e)}")
607
  raise HTTPException(status_code=500, detail="Invalid response format from visual query service")
 
 
 
608
 
609
  if __name__ == "__main__":
610
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
src/server/utils/auth.py CHANGED
@@ -29,7 +29,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
29
 
30
  class Settings(BaseSettings):
31
  api_key_secret: str = Field(..., env="API_KEY_SECRET")
32
- token_expiration_minutes: int = Field(30, env="TOKEN_EXPIRATION_MINUTES")
 
33
  llm_model_name: str = "google/gemma-3-4b-it"
34
  max_tokens: int = 512
35
  host: str = "0.0.0.0"
@@ -40,7 +41,6 @@ class Settings(BaseSettings):
40
  external_asr_url: str = Field(..., env="EXTERNAL_ASR_URL")
41
  external_text_gen_url: str = Field(..., env="EXTERNAL_TEXT_GEN_URL")
42
  external_audio_proc_url: str = Field(..., env="EXTERNAL_AUDIO_PROC_URL")
43
- # Admin credentials required from environment variables, no defaults
44
  default_admin_username: str = Field("admin", env="DEFAULT_ADMIN_USERNAME")
45
  default_admin_password: str = Field("admin54321", env="DEFAULT_ADMIN_PASSWORD")
46
 
@@ -49,7 +49,6 @@ class Settings(BaseSettings):
49
  env_file_encoding = "utf-8"
50
 
51
  settings = Settings()
52
- #logger.info(f"Loaded API_KEY_SECRET at startup: {settings.api_key_secret}")
53
 
54
  # Seed initial data (optional)
55
  def seed_initial_data():
@@ -77,9 +76,11 @@ bearer_scheme = HTTPBearer()
77
  class TokenPayload(BaseModel):
78
  sub: str
79
  exp: float
 
80
 
81
  class TokenResponse(BaseModel):
82
  access_token: str
 
83
  token_type: str
84
 
85
  class LoginRequest(BaseModel):
@@ -90,13 +91,15 @@ class RegisterRequest(BaseModel):
90
  username: str
91
  password: str
92
 
93
- async def create_access_token(user_id: str) -> str:
94
  expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
95
- payload = {"sub": user_id, "exp": expire.timestamp()}
96
- #logger.info(f"Signing token with API_KEY_SECRET: {settings.api_key_secret}")
97
  token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
98
- #logger.info(f"Generated access token for user: {user_id}")
99
- return token
 
 
 
100
 
101
  async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
102
  token = credentials.credentials
@@ -106,10 +109,7 @@ async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(b
106
  headers={"WWW-Authenticate": "Bearer"},
107
  )
108
  try:
109
- #logger.info(f"Received token: {token}")
110
- #logger.info(f"Verifying token with API_KEY_SECRET: {settings.api_key_secret}")
111
  payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"], options={"verify_exp": False})
112
- #logger.info(f"Decoded payload: {payload}")
113
  token_data = TokenPayload(**payload)
114
  user_id = token_data.sub
115
 
@@ -121,9 +121,8 @@ async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(b
121
  raise credentials_exception
122
 
123
  current_time = datetime.utcnow().timestamp()
124
- logger.info(f"Current time: {current_time}, Token exp: {token_data.exp}")
125
  if current_time > token_data.exp:
126
- #logger.warning(f"Token expired: current_time={current_time}, exp={token_data.exp}")
127
  raise HTTPException(
128
  status_code=status.HTTP_401_UNAUTHORIZED,
129
  detail="Token has expired",
@@ -159,8 +158,8 @@ async def login(login_request: LoginRequest) -> TokenResponse:
159
  if not user or not pwd_context.verify(login_request.password, user.password):
160
  logger.warning(f"Login failed for user: {login_request.username}")
161
  raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
162
- token = await create_access_token(user_id=user.username)
163
- return TokenResponse(access_token=token, token_type="bearer")
164
 
165
  async def register(register_request: RegisterRequest, current_user: str = Depends(get_current_user_with_admin)) -> TokenResponse:
166
  db = SessionLocal()
@@ -176,11 +175,26 @@ async def register(register_request: RegisterRequest, current_user: str = Depend
176
  db.commit()
177
  db.close()
178
 
179
- token = await create_access_token(user_id=register_request.username)
180
  logger.info(f"Registered and generated token for user: {register_request.username} by admin {current_user}")
181
- return TokenResponse(access_token=token, token_type="bearer")
182
 
183
  async def refresh_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> TokenResponse:
184
- user_id = await get_current_user(credentials)
185
- new_token = await create_access_token(user_id=user_id)
186
- return TokenResponse(access_token=new_token, token_type="bearer")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  class Settings(BaseSettings):
31
  api_key_secret: str = Field(..., env="API_KEY_SECRET")
32
+ token_expiration_minutes: int = Field(1440, env="TOKEN_EXPIRATION_MINUTES") # 24 hours
33
+ refresh_token_expiration_days: int = Field(7, env="REFRESH_TOKEN_EXPIRATION_DAYS") # 7 days
34
  llm_model_name: str = "google/gemma-3-4b-it"
35
  max_tokens: int = 512
36
  host: str = "0.0.0.0"
 
41
  external_asr_url: str = Field(..., env="EXTERNAL_ASR_URL")
42
  external_text_gen_url: str = Field(..., env="EXTERNAL_TEXT_GEN_URL")
43
  external_audio_proc_url: str = Field(..., env="EXTERNAL_AUDIO_PROC_URL")
 
44
  default_admin_username: str = Field("admin", env="DEFAULT_ADMIN_USERNAME")
45
  default_admin_password: str = Field("admin54321", env="DEFAULT_ADMIN_PASSWORD")
46
 
 
49
  env_file_encoding = "utf-8"
50
 
51
  settings = Settings()
 
52
 
53
  # Seed initial data (optional)
54
  def seed_initial_data():
 
76
  class TokenPayload(BaseModel):
77
  sub: str
78
  exp: float
79
+ type: str
80
 
81
  class TokenResponse(BaseModel):
82
  access_token: str
83
+ refresh_token: str
84
  token_type: str
85
 
86
  class LoginRequest(BaseModel):
 
91
  username: str
92
  password: str
93
 
94
+ async def create_access_token(user_id: str) -> dict:
95
  expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
96
+ payload = {"sub": user_id, "exp": expire.timestamp(), "type": "access"}
 
97
  token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
98
+ refresh_expire = datetime.utcnow() + timedelta(days=settings.refresh_token_expiration_days)
99
+ refresh_payload = {"sub": user_id, "exp": refresh_expire.timestamp(), "type": "refresh"}
100
+ refresh_token = jwt.encode(refresh_payload, settings.api_key_secret, algorithm="HS256")
101
+ logger.info(f"Generated tokens for user: {user_id}")
102
+ return {"access_token": token, "refresh_token": refresh_token}
103
 
104
  async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
105
  token = credentials.credentials
 
109
  headers={"WWW-Authenticate": "Bearer"},
110
  )
111
  try:
 
 
112
  payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"], options={"verify_exp": False})
 
113
  token_data = TokenPayload(**payload)
114
  user_id = token_data.sub
115
 
 
121
  raise credentials_exception
122
 
123
  current_time = datetime.utcnow().timestamp()
 
124
  if current_time > token_data.exp:
125
+ logger.warning(f"Token expired: current_time={current_time}, exp={token_data.exp}")
126
  raise HTTPException(
127
  status_code=status.HTTP_401_UNAUTHORIZED,
128
  detail="Token has expired",
 
158
  if not user or not pwd_context.verify(login_request.password, user.password):
159
  logger.warning(f"Login failed for user: {login_request.username}")
160
  raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
161
+ tokens = await create_access_token(user_id=user.username)
162
+ return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")
163
 
164
  async def register(register_request: RegisterRequest, current_user: str = Depends(get_current_user_with_admin)) -> TokenResponse:
165
  db = SessionLocal()
 
175
  db.commit()
176
  db.close()
177
 
178
+ tokens = await create_access_token(user_id=register_request.username)
179
  logger.info(f"Registered and generated token for user: {register_request.username} by admin {current_user}")
180
+ return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")
181
 
182
  async def refresh_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> TokenResponse:
183
+ token = credentials.credentials
184
+ try:
185
+ payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"])
186
+ token_data = TokenPayload(**payload)
187
+ if payload.get("type") != "refresh":
188
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type; refresh token required")
189
+ user_id = token_data.sub
190
+ db = SessionLocal()
191
+ user = db.query(User).filter_by(username=user_id).first()
192
+ db.close()
193
+ if not user:
194
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
195
+ if datetime.utcnow().timestamp() > token_data.exp:
196
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token has expired")
197
+ tokens = await create_access_token(user_id=user_id)
198
+ return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")
199
+ except jwt.InvalidTokenError:
200
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")