sachin commited on
Commit
8d6faeb
·
1 Parent(s): 90b38cc
Files changed (3) hide show
  1. requirements.txt +2 -1
  2. src/server/main.py +11 -7
  3. src/server/utils/auth.py +43 -15
requirements.txt CHANGED
@@ -5,4 +5,5 @@ slowapi
5
  requests
6
  python-multipart
7
  pillow
8
- pyjwt
 
 
5
  requests
6
  python-multipart
7
  pillow
8
+ pyjwt
9
+ sqlalchemy
src/server/main.py CHANGED
@@ -14,9 +14,7 @@ from slowapi.util import get_remote_address
14
  import requests
15
  from PIL import Image
16
 
17
- # Import from auth.py
18
- #from utils.auth import get_current_user, login, TokenResponse, Settings
19
- from utils.auth import get_current_user, login, TokenResponse, Settings, LoginRequest
20
 
21
  # Assuming these are in your project structure
22
  from config.tts_config import SPEED, ResponseFormat, config as tts_config
@@ -96,7 +94,16 @@ class ExternalTTSService(TTSService):
96
  def get_tts_service() -> TTSService:
97
  return ExternalTTSService()
98
 
99
- # Endpoints
 
 
 
 
 
 
 
 
 
100
  @app.get("/v1/health")
101
  async def health_check():
102
  return {"status": "healthy", "model": settings.llm_model_name}
@@ -105,9 +112,6 @@ async def health_check():
105
  async def home():
106
  return RedirectResponse(url="/docs")
107
 
108
- @app.post("/v1/token", response_model=TokenResponse)
109
- async def token(login_request: LoginRequest):
110
- return await login(login_request)
111
 
112
  @app.post("/v1/audio/speech")
113
  @limiter.limit(settings.speech_rate_limit)
 
14
  import requests
15
  from PIL import Image
16
 
17
+ from utils.auth import get_current_user, login, refresh_token, TokenResponse, Settings, LoginRequest
 
 
18
 
19
  # Assuming these are in your project structure
20
  from config.tts_config import SPEED, ResponseFormat, config as tts_config
 
94
  def get_tts_service() -> TTSService:
95
  return ExternalTTSService()
96
 
97
+
98
+
99
+ @app.post("/v1/token", response_model=TokenResponse)
100
+ async def token(login_request: LoginRequest):
101
+ return await login(login_request)
102
+
103
+ @app.post("/v1/refresh", response_model=TokenResponse)
104
+ async def refresh(token_response: TokenResponse = Depends(refresh_token)):
105
+ return token_response
106
+
107
  @app.get("/v1/health")
108
  async def health_check():
109
  return {"status": "healthy", "model": settings.llm_model_name}
 
112
  async def home():
113
  return RedirectResponse(url="/docs")
114
 
 
 
 
115
 
116
  @app.post("/v1/audio/speech")
117
  @limiter.limit(settings.speech_rate_limit)
src/server/utils/auth.py CHANGED
@@ -5,7 +5,33 @@ from fastapi import HTTPException, status, Depends
5
  from pydantic import BaseModel, Field
6
  from pydantic_settings import BaseSettings
7
  from config.logging_config import logger
8
- from typing import Dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  class Settings(BaseSettings):
11
  api_key_secret: str = Field(..., env="API_KEY_SECRET")
@@ -38,13 +64,6 @@ class TokenResponse(BaseModel):
38
  access_token: str
39
  token_type: str
40
 
41
- # Simple in-memory user store (replace with database in production)
42
- # Format: {username: password}
43
- USERS_DB: Dict[str, str] = {
44
- "testuser": "password123",
45
- "admin": "adminpass"
46
- }
47
-
48
  class LoginRequest(BaseModel):
49
  username: str
50
  password: str
@@ -71,7 +90,10 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
71
  token_data = TokenPayload(**payload)
72
  user_id = token_data.sub
73
 
74
- if user_id is None or user_id not in USERS_DB:
 
 
 
75
  logger.warning(f"Invalid or unknown user: {user_id}")
76
  raise credentials_exception
77
 
@@ -98,10 +120,16 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
98
  raise credentials_exception
99
 
100
  async def login(login_request: LoginRequest) -> TokenResponse:
101
- username = login_request.username
102
- password = login_request.password
103
- if username not in USERS_DB or USERS_DB[username] != password:
104
- logger.warning(f"Login failed for user: {username}")
 
105
  raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
106
- token = await create_access_token(user_id=username)
107
- return TokenResponse(access_token=token, token_type="bearer")
 
 
 
 
 
 
5
  from pydantic import BaseModel, Field
6
  from pydantic_settings import BaseSettings
7
  from config.logging_config import logger
8
+ from sqlalchemy import create_engine, Column, String
9
+ from sqlalchemy.ext.declarative import declarative_base
10
+ from sqlalchemy.orm import sessionmaker
11
+
12
+ # SQLite database setup
13
+ DATABASE_URL = "sqlite:///users.db"
14
+ engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) # For SQLite threading
15
+ Base = declarative_base()
16
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
17
+
18
+ class User(Base):
19
+ __tablename__ = "users"
20
+ username = Column(String, primary_key=True, index=True)
21
+ password = Column(String) # In production, hash this with bcrypt
22
+
23
+ # Create the database tables
24
+ Base.metadata.create_all(bind=engine)
25
+
26
+ # Seed initial data (optional, for testing)
27
+ def seed_initial_data():
28
+ db = SessionLocal()
29
+ if not db.query(User).filter_by(username="testuser").first():
30
+ db.add(User(username="testuser", password="password123"))
31
+ db.commit()
32
+ db.close()
33
+
34
+ seed_initial_data() # Run once at startup
35
 
36
  class Settings(BaseSettings):
37
  api_key_secret: str = Field(..., env="API_KEY_SECRET")
 
64
  access_token: str
65
  token_type: str
66
 
 
 
 
 
 
 
 
67
  class LoginRequest(BaseModel):
68
  username: str
69
  password: str
 
90
  token_data = TokenPayload(**payload)
91
  user_id = token_data.sub
92
 
93
+ db = SessionLocal()
94
+ user = db.query(User).filter_by(username=user_id).first()
95
+ db.close()
96
+ if user_id is None or not user:
97
  logger.warning(f"Invalid or unknown user: {user_id}")
98
  raise credentials_exception
99
 
 
120
  raise credentials_exception
121
 
122
  async def login(login_request: LoginRequest) -> TokenResponse:
123
+ db = SessionLocal()
124
+ user = db.query(User).filter_by(username=login_request.username).first()
125
+ db.close()
126
+ if not user or user.password != login_request.password: # Use password hashing in production
127
+ logger.warning(f"Login failed for user: {login_request.username}")
128
  raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
129
+ token = await create_access_token(user_id=user.username)
130
+ return TokenResponse(access_token=token, token_type="bearer")
131
+
132
+ async def refresh_token(token: str = Depends(oauth2_scheme)) -> TokenResponse:
133
+ user_id = await get_current_user(token)
134
+ new_token = await create_access_token(user_id=user_id)
135
+ return TokenResponse(access_token=new_token, token_type="bearer")