Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
sachin
commited on
Commit
·
8d6faeb
1
Parent(s):
90b38cc
add-user/
Browse files- requirements.txt +2 -1
- src/server/main.py +11 -7
- src/server/utils/auth.py +43 -15
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ slowapi
|
|
5 |
requests
|
6 |
python-multipart
|
7 |
pillow
|
8 |
-
pyjwt
|
|
|
|
5 |
requests
|
6 |
python-multipart
|
7 |
pillow
|
8 |
+
pyjwt
|
9 |
+
sqlalchemy
|
src/server/main.py
CHANGED
@@ -14,9 +14,7 @@ from slowapi.util import get_remote_address
|
|
14 |
import requests
|
15 |
from PIL import Image
|
16 |
|
17 |
-
|
18 |
-
#from utils.auth import get_current_user, login, TokenResponse, Settings
|
19 |
-
from utils.auth import get_current_user, login, TokenResponse, Settings, LoginRequest
|
20 |
|
21 |
# Assuming these are in your project structure
|
22 |
from config.tts_config import SPEED, ResponseFormat, config as tts_config
|
@@ -96,7 +94,16 @@ class ExternalTTSService(TTSService):
|
|
96 |
def get_tts_service() -> TTSService:
|
97 |
return ExternalTTSService()
|
98 |
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
@app.get("/v1/health")
|
101 |
async def health_check():
|
102 |
return {"status": "healthy", "model": settings.llm_model_name}
|
@@ -105,9 +112,6 @@ async def health_check():
|
|
105 |
async def home():
|
106 |
return RedirectResponse(url="/docs")
|
107 |
|
108 |
-
@app.post("/v1/token", response_model=TokenResponse)
|
109 |
-
async def token(login_request: LoginRequest):
|
110 |
-
return await login(login_request)
|
111 |
|
112 |
@app.post("/v1/audio/speech")
|
113 |
@limiter.limit(settings.speech_rate_limit)
|
|
|
14 |
import requests
|
15 |
from PIL import Image
|
16 |
|
17 |
+
from utils.auth import get_current_user, login, refresh_token, TokenResponse, Settings, LoginRequest
|
|
|
|
|
18 |
|
19 |
# Assuming these are in your project structure
|
20 |
from config.tts_config import SPEED, ResponseFormat, config as tts_config
|
|
|
94 |
def get_tts_service() -> TTSService:
|
95 |
return ExternalTTSService()
|
96 |
|
97 |
+
|
98 |
+
|
99 |
+
@app.post("/v1/token", response_model=TokenResponse)
|
100 |
+
async def token(login_request: LoginRequest):
|
101 |
+
return await login(login_request)
|
102 |
+
|
103 |
+
@app.post("/v1/refresh", response_model=TokenResponse)
|
104 |
+
async def refresh(token_response: TokenResponse = Depends(refresh_token)):
|
105 |
+
return token_response
|
106 |
+
|
107 |
@app.get("/v1/health")
|
108 |
async def health_check():
|
109 |
return {"status": "healthy", "model": settings.llm_model_name}
|
|
|
112 |
async def home():
|
113 |
return RedirectResponse(url="/docs")
|
114 |
|
|
|
|
|
|
|
115 |
|
116 |
@app.post("/v1/audio/speech")
|
117 |
@limiter.limit(settings.speech_rate_limit)
|
src/server/utils/auth.py
CHANGED
@@ -5,7 +5,33 @@ from fastapi import HTTPException, status, Depends
|
|
5 |
from pydantic import BaseModel, Field
|
6 |
from pydantic_settings import BaseSettings
|
7 |
from config.logging_config import logger
|
8 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
class Settings(BaseSettings):
|
11 |
api_key_secret: str = Field(..., env="API_KEY_SECRET")
|
@@ -38,13 +64,6 @@ class TokenResponse(BaseModel):
|
|
38 |
access_token: str
|
39 |
token_type: str
|
40 |
|
41 |
-
# Simple in-memory user store (replace with database in production)
|
42 |
-
# Format: {username: password}
|
43 |
-
USERS_DB: Dict[str, str] = {
|
44 |
-
"testuser": "password123",
|
45 |
-
"admin": "adminpass"
|
46 |
-
}
|
47 |
-
|
48 |
class LoginRequest(BaseModel):
|
49 |
username: str
|
50 |
password: str
|
@@ -71,7 +90,10 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
|
|
71 |
token_data = TokenPayload(**payload)
|
72 |
user_id = token_data.sub
|
73 |
|
74 |
-
|
|
|
|
|
|
|
75 |
logger.warning(f"Invalid or unknown user: {user_id}")
|
76 |
raise credentials_exception
|
77 |
|
@@ -98,10 +120,16 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
|
|
98 |
raise credentials_exception
|
99 |
|
100 |
async def login(login_request: LoginRequest) -> TokenResponse:
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
|
|
105 |
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
|
106 |
-
token = await create_access_token(user_id=username)
|
107 |
-
return TokenResponse(access_token=token, token_type="bearer")
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from pydantic import BaseModel, Field
|
6 |
from pydantic_settings import BaseSettings
|
7 |
from config.logging_config import logger
|
8 |
+
from sqlalchemy import create_engine, Column, String
|
9 |
+
from sqlalchemy.ext.declarative import declarative_base
|
10 |
+
from sqlalchemy.orm import sessionmaker
|
11 |
+
|
12 |
+
# SQLite database setup
|
13 |
+
DATABASE_URL = "sqlite:///users.db"
|
14 |
+
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) # For SQLite threading
|
15 |
+
Base = declarative_base()
|
16 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
17 |
+
|
18 |
+
class User(Base):
|
19 |
+
__tablename__ = "users"
|
20 |
+
username = Column(String, primary_key=True, index=True)
|
21 |
+
password = Column(String) # In production, hash this with bcrypt
|
22 |
+
|
23 |
+
# Create the database tables
|
24 |
+
Base.metadata.create_all(bind=engine)
|
25 |
+
|
26 |
+
# Seed initial data (optional, for testing)
|
27 |
+
def seed_initial_data():
|
28 |
+
db = SessionLocal()
|
29 |
+
if not db.query(User).filter_by(username="testuser").first():
|
30 |
+
db.add(User(username="testuser", password="password123"))
|
31 |
+
db.commit()
|
32 |
+
db.close()
|
33 |
+
|
34 |
+
seed_initial_data() # Run once at startup
|
35 |
|
36 |
class Settings(BaseSettings):
|
37 |
api_key_secret: str = Field(..., env="API_KEY_SECRET")
|
|
|
64 |
access_token: str
|
65 |
token_type: str
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
class LoginRequest(BaseModel):
|
68 |
username: str
|
69 |
password: str
|
|
|
90 |
token_data = TokenPayload(**payload)
|
91 |
user_id = token_data.sub
|
92 |
|
93 |
+
db = SessionLocal()
|
94 |
+
user = db.query(User).filter_by(username=user_id).first()
|
95 |
+
db.close()
|
96 |
+
if user_id is None or not user:
|
97 |
logger.warning(f"Invalid or unknown user: {user_id}")
|
98 |
raise credentials_exception
|
99 |
|
|
|
120 |
raise credentials_exception
|
121 |
|
122 |
async def login(login_request: LoginRequest) -> TokenResponse:
|
123 |
+
db = SessionLocal()
|
124 |
+
user = db.query(User).filter_by(username=login_request.username).first()
|
125 |
+
db.close()
|
126 |
+
if not user or user.password != login_request.password: # Use password hashing in production
|
127 |
+
logger.warning(f"Login failed for user: {login_request.username}")
|
128 |
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
|
129 |
+
token = await create_access_token(user_id=user.username)
|
130 |
+
return TokenResponse(access_token=token, token_type="bearer")
|
131 |
+
|
132 |
+
async def refresh_token(token: str = Depends(oauth2_scheme)) -> TokenResponse:
|
133 |
+
user_id = await get_current_user(token)
|
134 |
+
new_token = await create_access_token(user_id=user_id)
|
135 |
+
return TokenResponse(access_token=new_token, token_type="bearer")
|