Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
sachin
commited on
Commit
·
8a99693
1
Parent(s):
c0849df
update-refresh token
Browse files- src/server/main.py +30 -39
- src/server/utils/auth.py +34 -20
src/server/main.py
CHANGED
@@ -5,7 +5,7 @@ from typing import List, Optional
|
|
5 |
from abc import ABC, abstractmethod
|
6 |
|
7 |
import uvicorn
|
8 |
-
from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
|
11 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
@@ -147,7 +147,7 @@ async def home():
|
|
147 |
@app.post("/v1/token",
|
148 |
response_model=TokenResponse,
|
149 |
summary="User Login",
|
150 |
-
description="Authenticate a user with username and password to obtain an access token. Copy the token and use it in the 'Authorize' button above.",
|
151 |
tags=["Authentication"],
|
152 |
responses={
|
153 |
200: {"description": "Successful login", "model": TokenResponse},
|
@@ -159,11 +159,11 @@ async def token(login_request: LoginRequest):
|
|
159 |
@app.post("/v1/refresh",
|
160 |
response_model=TokenResponse,
|
161 |
summary="Refresh Access Token",
|
162 |
-
description="Generate a new access token using an existing valid token.",
|
163 |
tags=["Authentication"],
|
164 |
responses={
|
165 |
-
200: {"description": "New
|
166 |
-
401: {"description": "Invalid or expired token"}
|
167 |
})
|
168 |
async def refresh(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
|
169 |
return await refresh_token(credentials)
|
@@ -171,7 +171,7 @@ async def refresh(credentials: HTTPAuthorizationCredentials = Depends(bearer_sch
|
|
171 |
@app.post("/v1/register",
|
172 |
response_model=TokenResponse,
|
173 |
summary="Register New User",
|
174 |
-
description="Create a new user account and return an access token. Requires admin access (use 'admin' user with password 'adminpass' initially).",
|
175 |
tags=["Authentication"],
|
176 |
responses={
|
177 |
200: {"description": "User registered successfully", "model": TokenResponse},
|
@@ -180,9 +180,9 @@ async def refresh(credentials: HTTPAuthorizationCredentials = Depends(bearer_sch
|
|
180 |
})
|
181 |
async def register_user(
|
182 |
register_request: RegisterRequest,
|
183 |
-
current_user: str = Depends(get_current_user_with_admin)
|
184 |
):
|
185 |
-
return await register(register_request, current_user)
|
186 |
|
187 |
@app.post("/v1/audio/speech",
|
188 |
summary="Generate Speech from Text",
|
@@ -520,9 +520,7 @@ async def translate(
|
|
520 |
except ValueError as e:
|
521 |
logger.error(f"Invalid JSON response: {str(e)}")
|
522 |
raise HTTPException(status_code=500, detail="Invalid response format from translation service")
|
523 |
-
|
524 |
|
525 |
-
# Request/Response Models for Visual Query
|
526 |
class VisualQueryRequest(BaseModel):
|
527 |
query: str
|
528 |
src_lang: str = "kan_Knda" # Default to Kannada
|
@@ -537,49 +535,48 @@ class VisualQueryRequest(BaseModel):
|
|
537 |
class VisualQueryResponse(BaseModel):
|
538 |
answer: str
|
539 |
|
540 |
-
|
541 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
542 |
@limiter.limit(settings.chat_rate_limit)
|
543 |
async def visual_query(
|
544 |
request: Request,
|
545 |
-
query: str = Form(
|
546 |
-
file: UploadFile = File(
|
547 |
-
src_lang: str = Query(default="kan_Knda"),
|
548 |
-
tgt_lang: str = Query(default="kan_Knda"),
|
549 |
-
|
550 |
):
|
551 |
-
|
552 |
-
Endpoint to process visual queries with an image and text question.
|
553 |
-
Calls an external API to analyze the image and provide a response.
|
554 |
-
"""
|
555 |
if not query.strip():
|
556 |
raise HTTPException(status_code=400, detail="Query cannot be empty")
|
557 |
|
558 |
logger.info("Processing visual query request", extra={
|
559 |
"endpoint": "/v1/visual_query",
|
560 |
"query_length": len(query),
|
561 |
-
"file_name": file.filename,
|
562 |
"client_ip": get_remote_address(request),
|
|
|
563 |
"src_lang": src_lang,
|
564 |
"tgt_lang": tgt_lang
|
565 |
})
|
566 |
|
567 |
-
# External API URL
|
568 |
external_url = f"https://slabstech-dhwani-internal-api-server.hf.space/v1/visual_query/?src_lang={src_lang}&tgt_lang={tgt_lang}"
|
569 |
|
570 |
try:
|
571 |
-
# Read file content
|
572 |
file_content = await file.read()
|
|
|
|
|
573 |
|
574 |
-
# Prepare multipart/form-data
|
575 |
-
files = {
|
576 |
-
"file": (file.filename, file_content, file.content_type)
|
577 |
-
}
|
578 |
-
data = {
|
579 |
-
"query": query
|
580 |
-
}
|
581 |
-
|
582 |
-
# Make the POST request to external API
|
583 |
response = requests.post(
|
584 |
external_url,
|
585 |
files=files,
|
@@ -587,11 +584,8 @@ async def visual_query(
|
|
587 |
headers={"accept": "application/json"},
|
588 |
timeout=60
|
589 |
)
|
590 |
-
|
591 |
-
# Raise an exception for bad status codes
|
592 |
response.raise_for_status()
|
593 |
|
594 |
-
# Parse the response
|
595 |
response_data = response.json()
|
596 |
answer = response_data.get("answer", "")
|
597 |
|
@@ -611,9 +605,6 @@ async def visual_query(
|
|
611 |
except ValueError as e:
|
612 |
logger.error(f"Invalid JSON response: {str(e)}")
|
613 |
raise HTTPException(status_code=500, detail="Invalid response format from visual query service")
|
614 |
-
except Exception as e:
|
615 |
-
logger.error(f"Unexpected error in visual query: {str(e)}", exc_info=True)
|
616 |
-
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
617 |
|
618 |
if __name__ == "__main__":
|
619 |
parser = argparse.ArgumentParser(description="Run the FastAPI server.")
|
|
|
5 |
from abc import ABC, abstractmethod
|
6 |
|
7 |
import uvicorn
|
8 |
+
from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Form
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
|
11 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
|
147 |
@app.post("/v1/token",
|
148 |
response_model=TokenResponse,
|
149 |
summary="User Login",
|
150 |
+
description="Authenticate a user with username and password to obtain an access token and refresh token. Copy the access token and use it in the 'Authorize' button above.",
|
151 |
tags=["Authentication"],
|
152 |
responses={
|
153 |
200: {"description": "Successful login", "model": TokenResponse},
|
|
|
159 |
@app.post("/v1/refresh",
|
160 |
response_model=TokenResponse,
|
161 |
summary="Refresh Access Token",
|
162 |
+
description="Generate a new access token and refresh token using an existing valid refresh token.",
|
163 |
tags=["Authentication"],
|
164 |
responses={
|
165 |
+
200: {"description": "New tokens issued", "model": TokenResponse},
|
166 |
+
401: {"description": "Invalid or expired refresh token"}
|
167 |
})
|
168 |
async def refresh(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)):
|
169 |
return await refresh_token(credentials)
|
|
|
171 |
@app.post("/v1/register",
|
172 |
response_model=TokenResponse,
|
173 |
summary="Register New User",
|
174 |
+
description="Create a new user account and return an access token and refresh token. Requires admin access (use 'admin' user with password 'adminpass' initially).",
|
175 |
tags=["Authentication"],
|
176 |
responses={
|
177 |
200: {"description": "User registered successfully", "model": TokenResponse},
|
|
|
180 |
})
|
181 |
async def register_user(
|
182 |
register_request: RegisterRequest,
|
183 |
+
current_user: str = Depends(get_current_user_with_admin)
|
184 |
):
|
185 |
+
return await register(register_request, current_user)
|
186 |
|
187 |
@app.post("/v1/audio/speech",
|
188 |
summary="Generate Speech from Text",
|
|
|
520 |
except ValueError as e:
|
521 |
logger.error(f"Invalid JSON response: {str(e)}")
|
522 |
raise HTTPException(status_code=500, detail="Invalid response format from translation service")
|
|
|
523 |
|
|
|
524 |
class VisualQueryRequest(BaseModel):
|
525 |
query: str
|
526 |
src_lang: str = "kan_Knda" # Default to Kannada
|
|
|
535 |
class VisualQueryResponse(BaseModel):
|
536 |
answer: str
|
537 |
|
538 |
+
@app.post("/v1/visual_query",
|
539 |
+
response_model=VisualQueryResponse,
|
540 |
+
summary="Visual Query with Image",
|
541 |
+
description="Process a visual query with an image and text question. Rate limited to 100 requests per minute per user. Requires authentication.",
|
542 |
+
tags=["Chat"],
|
543 |
+
responses={
|
544 |
+
200: {"description": "Query response", "model": VisualQueryResponse},
|
545 |
+
400: {"description": "Invalid query"},
|
546 |
+
401: {"description": "Unauthorized - Token required"},
|
547 |
+
429: {"description": "Rate limit exceeded"},
|
548 |
+
504: {"description": "Visual query service timeout"}
|
549 |
+
})
|
550 |
@limiter.limit(settings.chat_rate_limit)
|
551 |
async def visual_query(
|
552 |
request: Request,
|
553 |
+
query: str = Form(..., description="Text query for the visual content"),
|
554 |
+
file: UploadFile = File(..., description="Image file to analyze"),
|
555 |
+
src_lang: str = Query(default="kan_Knda", description="Source language code"),
|
556 |
+
tgt_lang: str = Query(default="kan_Knda", description="Target language code"),
|
557 |
+
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)
|
558 |
):
|
559 |
+
user_id = await get_current_user(credentials)
|
|
|
|
|
|
|
560 |
if not query.strip():
|
561 |
raise HTTPException(status_code=400, detail="Query cannot be empty")
|
562 |
|
563 |
logger.info("Processing visual query request", extra={
|
564 |
"endpoint": "/v1/visual_query",
|
565 |
"query_length": len(query),
|
566 |
+
"file_name": file.filename,
|
567 |
"client_ip": get_remote_address(request),
|
568 |
+
"user_id": user_id,
|
569 |
"src_lang": src_lang,
|
570 |
"tgt_lang": tgt_lang
|
571 |
})
|
572 |
|
|
|
573 |
external_url = f"https://slabstech-dhwani-internal-api-server.hf.space/v1/visual_query/?src_lang={src_lang}&tgt_lang={tgt_lang}"
|
574 |
|
575 |
try:
|
|
|
576 |
file_content = await file.read()
|
577 |
+
files = {"file": (file.filename, file_content, file.content_type)}
|
578 |
+
data = {"query": query}
|
579 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
580 |
response = requests.post(
|
581 |
external_url,
|
582 |
files=files,
|
|
|
584 |
headers={"accept": "application/json"},
|
585 |
timeout=60
|
586 |
)
|
|
|
|
|
587 |
response.raise_for_status()
|
588 |
|
|
|
589 |
response_data = response.json()
|
590 |
answer = response_data.get("answer", "")
|
591 |
|
|
|
605 |
except ValueError as e:
|
606 |
logger.error(f"Invalid JSON response: {str(e)}")
|
607 |
raise HTTPException(status_code=500, detail="Invalid response format from visual query service")
|
|
|
|
|
|
|
608 |
|
609 |
if __name__ == "__main__":
|
610 |
parser = argparse.ArgumentParser(description="Run the FastAPI server.")
|
src/server/utils/auth.py
CHANGED
@@ -29,7 +29,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
29 |
|
30 |
class Settings(BaseSettings):
|
31 |
api_key_secret: str = Field(..., env="API_KEY_SECRET")
|
32 |
-
token_expiration_minutes: int = Field(
|
|
|
33 |
llm_model_name: str = "google/gemma-3-4b-it"
|
34 |
max_tokens: int = 512
|
35 |
host: str = "0.0.0.0"
|
@@ -40,7 +41,6 @@ class Settings(BaseSettings):
|
|
40 |
external_asr_url: str = Field(..., env="EXTERNAL_ASR_URL")
|
41 |
external_text_gen_url: str = Field(..., env="EXTERNAL_TEXT_GEN_URL")
|
42 |
external_audio_proc_url: str = Field(..., env="EXTERNAL_AUDIO_PROC_URL")
|
43 |
-
# Admin credentials required from environment variables, no defaults
|
44 |
default_admin_username: str = Field("admin", env="DEFAULT_ADMIN_USERNAME")
|
45 |
default_admin_password: str = Field("admin54321", env="DEFAULT_ADMIN_PASSWORD")
|
46 |
|
@@ -49,7 +49,6 @@ class Settings(BaseSettings):
|
|
49 |
env_file_encoding = "utf-8"
|
50 |
|
51 |
settings = Settings()
|
52 |
-
#logger.info(f"Loaded API_KEY_SECRET at startup: {settings.api_key_secret}")
|
53 |
|
54 |
# Seed initial data (optional)
|
55 |
def seed_initial_data():
|
@@ -77,9 +76,11 @@ bearer_scheme = HTTPBearer()
|
|
77 |
class TokenPayload(BaseModel):
|
78 |
sub: str
|
79 |
exp: float
|
|
|
80 |
|
81 |
class TokenResponse(BaseModel):
|
82 |
access_token: str
|
|
|
83 |
token_type: str
|
84 |
|
85 |
class LoginRequest(BaseModel):
|
@@ -90,13 +91,15 @@ class RegisterRequest(BaseModel):
|
|
90 |
username: str
|
91 |
password: str
|
92 |
|
93 |
-
async def create_access_token(user_id: str) ->
|
94 |
expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
|
95 |
-
payload = {"sub": user_id, "exp": expire.timestamp()}
|
96 |
-
#logger.info(f"Signing token with API_KEY_SECRET: {settings.api_key_secret}")
|
97 |
token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
100 |
|
101 |
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
|
102 |
token = credentials.credentials
|
@@ -106,10 +109,7 @@ async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(b
|
|
106 |
headers={"WWW-Authenticate": "Bearer"},
|
107 |
)
|
108 |
try:
|
109 |
-
#logger.info(f"Received token: {token}")
|
110 |
-
#logger.info(f"Verifying token with API_KEY_SECRET: {settings.api_key_secret}")
|
111 |
payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"], options={"verify_exp": False})
|
112 |
-
#logger.info(f"Decoded payload: {payload}")
|
113 |
token_data = TokenPayload(**payload)
|
114 |
user_id = token_data.sub
|
115 |
|
@@ -121,9 +121,8 @@ async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(b
|
|
121 |
raise credentials_exception
|
122 |
|
123 |
current_time = datetime.utcnow().timestamp()
|
124 |
-
logger.info(f"Current time: {current_time}, Token exp: {token_data.exp}")
|
125 |
if current_time > token_data.exp:
|
126 |
-
|
127 |
raise HTTPException(
|
128 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
129 |
detail="Token has expired",
|
@@ -159,8 +158,8 @@ async def login(login_request: LoginRequest) -> TokenResponse:
|
|
159 |
if not user or not pwd_context.verify(login_request.password, user.password):
|
160 |
logger.warning(f"Login failed for user: {login_request.username}")
|
161 |
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
|
162 |
-
|
163 |
-
return TokenResponse(access_token=
|
164 |
|
165 |
async def register(register_request: RegisterRequest, current_user: str = Depends(get_current_user_with_admin)) -> TokenResponse:
|
166 |
db = SessionLocal()
|
@@ -176,11 +175,26 @@ async def register(register_request: RegisterRequest, current_user: str = Depend
|
|
176 |
db.commit()
|
177 |
db.close()
|
178 |
|
179 |
-
|
180 |
logger.info(f"Registered and generated token for user: {register_request.username} by admin {current_user}")
|
181 |
-
return TokenResponse(access_token=
|
182 |
|
183 |
async def refresh_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> TokenResponse:
|
184 |
-
|
185 |
-
|
186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
class Settings(BaseSettings):
|
31 |
api_key_secret: str = Field(..., env="API_KEY_SECRET")
|
32 |
+
token_expiration_minutes: int = Field(1440, env="TOKEN_EXPIRATION_MINUTES") # 24 hours
|
33 |
+
refresh_token_expiration_days: int = Field(7, env="REFRESH_TOKEN_EXPIRATION_DAYS") # 7 days
|
34 |
llm_model_name: str = "google/gemma-3-4b-it"
|
35 |
max_tokens: int = 512
|
36 |
host: str = "0.0.0.0"
|
|
|
41 |
external_asr_url: str = Field(..., env="EXTERNAL_ASR_URL")
|
42 |
external_text_gen_url: str = Field(..., env="EXTERNAL_TEXT_GEN_URL")
|
43 |
external_audio_proc_url: str = Field(..., env="EXTERNAL_AUDIO_PROC_URL")
|
|
|
44 |
default_admin_username: str = Field("admin", env="DEFAULT_ADMIN_USERNAME")
|
45 |
default_admin_password: str = Field("admin54321", env="DEFAULT_ADMIN_PASSWORD")
|
46 |
|
|
|
49 |
env_file_encoding = "utf-8"
|
50 |
|
51 |
settings = Settings()
|
|
|
52 |
|
53 |
# Seed initial data (optional)
|
54 |
def seed_initial_data():
|
|
|
76 |
class TokenPayload(BaseModel):
|
77 |
sub: str
|
78 |
exp: float
|
79 |
+
type: str
|
80 |
|
81 |
class TokenResponse(BaseModel):
|
82 |
access_token: str
|
83 |
+
refresh_token: str
|
84 |
token_type: str
|
85 |
|
86 |
class LoginRequest(BaseModel):
|
|
|
91 |
username: str
|
92 |
password: str
|
93 |
|
94 |
+
async def create_access_token(user_id: str) -> dict:
|
95 |
expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
|
96 |
+
payload = {"sub": user_id, "exp": expire.timestamp(), "type": "access"}
|
|
|
97 |
token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
|
98 |
+
refresh_expire = datetime.utcnow() + timedelta(days=settings.refresh_token_expiration_days)
|
99 |
+
refresh_payload = {"sub": user_id, "exp": refresh_expire.timestamp(), "type": "refresh"}
|
100 |
+
refresh_token = jwt.encode(refresh_payload, settings.api_key_secret, algorithm="HS256")
|
101 |
+
logger.info(f"Generated tokens for user: {user_id}")
|
102 |
+
return {"access_token": token, "refresh_token": refresh_token}
|
103 |
|
104 |
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
|
105 |
token = credentials.credentials
|
|
|
109 |
headers={"WWW-Authenticate": "Bearer"},
|
110 |
)
|
111 |
try:
|
|
|
|
|
112 |
payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"], options={"verify_exp": False})
|
|
|
113 |
token_data = TokenPayload(**payload)
|
114 |
user_id = token_data.sub
|
115 |
|
|
|
121 |
raise credentials_exception
|
122 |
|
123 |
current_time = datetime.utcnow().timestamp()
|
|
|
124 |
if current_time > token_data.exp:
|
125 |
+
logger.warning(f"Token expired: current_time={current_time}, exp={token_data.exp}")
|
126 |
raise HTTPException(
|
127 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
128 |
detail="Token has expired",
|
|
|
158 |
if not user or not pwd_context.verify(login_request.password, user.password):
|
159 |
logger.warning(f"Login failed for user: {login_request.username}")
|
160 |
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
|
161 |
+
tokens = await create_access_token(user_id=user.username)
|
162 |
+
return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")
|
163 |
|
164 |
async def register(register_request: RegisterRequest, current_user: str = Depends(get_current_user_with_admin)) -> TokenResponse:
|
165 |
db = SessionLocal()
|
|
|
175 |
db.commit()
|
176 |
db.close()
|
177 |
|
178 |
+
tokens = await create_access_token(user_id=register_request.username)
|
179 |
logger.info(f"Registered and generated token for user: {register_request.username} by admin {current_user}")
|
180 |
+
return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")
|
181 |
|
182 |
async def refresh_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> TokenResponse:
|
183 |
+
token = credentials.credentials
|
184 |
+
try:
|
185 |
+
payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"])
|
186 |
+
token_data = TokenPayload(**payload)
|
187 |
+
if payload.get("type") != "refresh":
|
188 |
+
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type; refresh token required")
|
189 |
+
user_id = token_data.sub
|
190 |
+
db = SessionLocal()
|
191 |
+
user = db.query(User).filter_by(username=user_id).first()
|
192 |
+
db.close()
|
193 |
+
if not user:
|
194 |
+
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
|
195 |
+
if datetime.utcnow().timestamp() > token_data.exp:
|
196 |
+
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token has expired")
|
197 |
+
tokens = await create_access_token(user_id=user_id)
|
198 |
+
return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")
|
199 |
+
except jwt.InvalidTokenError:
|
200 |
+
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
|