# Hugging Face file-viewer residue (not part of the module), kept for reference:
# sachin | add-new table | 944b400 | raw | history blame | 14.4 kB
import jwt
from datetime import datetime, timedelta
from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from config.logging_config import logger
from sqlalchemy import create_engine, Column, String, Boolean
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from passlib.context import CryptContext
import os
import base64
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
# SQLite database setup with Hugging Face persistent storage
# Filesystem location of the SQLite database file.
DATABASE_PATH = "/data/users.db"
# SQLAlchemy connection URL derived from DATABASE_PATH.
DATABASE_URL = f"sqlite:///{DATABASE_PATH}"
# check_same_thread=False is passed through to the SQLite driver so the
# connection may be used from threads other than the one that opened it.
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
# Session factory bound to the engine; autocommit and autoflush are disabled,
# so callers must commit explicitly.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Model for admin-related users
class User(Base):
    """Row of the ``users`` table, keyed by username.

    Holds both regular and admin accounts; ``is_admin`` tells them apart.
    """

    __tablename__ = "users"

    # Primary key; also indexed for lookups by username.
    username = Column(String, primary_key=True, index=True)
    # Stores hashed passwords
    password = Column(String)
    # Defaults to a non-admin account.
    is_admin = Column(Boolean, default=False)
    # Stores base64-encoded session key (may be NULL)
    session_key = Column(String, nullable=True)
# Model for app users
class AppUser(Base):
    """Row of the ``app_users`` table, keyed by username.

    Same columns as ``User`` except that there is no ``is_admin`` flag.
    """

    __tablename__ = "app_users"

    # Primary key; also indexed for lookups by username.
    username = Column(String, primary_key=True, index=True)
    # Stores hashed passwords
    password = Column(String)
    # Stores base64-encoded session key (may be NULL)
    session_key = Column(String, nullable=True)
# Ensure the /data directory exists
# (exist_ok=True makes this a no-op when the directory is already present.)
os.makedirs(os.path.dirname(DATABASE_PATH), exist_ok=True)
# Create database tables
# Runs at import time for every model registered on Base.
Base.metadata.create_all(bind=engine)
# Password hashing
# bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class Settings(BaseSettings):
    # Application configuration. Values come from the process environment and
    # from the ".env" file named in Config below; fields declared with
    # Field(...) have no default and must be supplied.
    #
    # NOTE(review): with pydantic-settings on pydantic v2 the `env=` keyword on
    # Field is presumably ignored and the variable name is derived from the
    # field name instead -- the names happen to match here, but confirm
    # against the pinned versions.

    # --- Token signing and lifetimes ---
    api_key_secret: str = Field(..., env="API_KEY_SECRET")
    token_expiration_minutes: int = Field(1440, env="TOKEN_EXPIRATION_MINUTES")
    refresh_token_expiration_days: int = Field(7, env="REFRESH_TOKEN_EXPIRATION_DAYS")

    # --- Model / server defaults ---
    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860

    # --- Rate limits, written as "<count>/<period>" ---
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"

    # --- External service endpoints (all required) ---
    external_tts_url: str = Field(..., env="EXTERNAL_TTS_URL")
    external_asr_url: str = Field(..., env="EXTERNAL_ASR_URL")
    external_text_gen_url: str = Field(..., env="EXTERNAL_TEXT_GEN_URL")
    external_audio_proc_url: str = Field(..., env="EXTERNAL_AUDIO_PROC_URL")

    # --- Default admin account ---
    # NOTE(review): a fixed fallback password is baked in; DEFAULT_ADMIN_PASSWORD
    # should be overridden in any real deployment.
    default_admin_username: str = Field("admin", env="DEFAULT_ADMIN_USERNAME")
    default_admin_password: str = Field("admin54321", env="DEFAULT_ADMIN_PASSWORD")

    # Mirrors the module-level DATABASE_PATH constant.
    database_path: str = DATABASE_PATH

    class Config:
        # Dotenv file consulted in addition to the process environment.
        env_file = ".env"
        env_file_encoding = "utf-8"
# Module-wide settings instance, built once at import time. Fields declared
# without a default must be present in the environment or .env file.
settings = Settings()
# Seed initial data for users table only
def seed_initial_data():
    """Insert the test user and the default admin into ``users`` if absent.

    Each account is looked up by username first and only created (with a
    hashed password and a fresh random base64 session key) when missing.
    Each insert is committed on its own. Any exception is logged and the
    session rolled back; the session is always closed.
    """

    def _ensure_user(db, username, secret, is_admin):
        # Leave existing rows untouched.
        if db.query(User).filter_by(username=username).first():
            return
        hashed_password = pwd_context.hash(secret)
        session_key = base64.b64encode(get_random_bytes(16)).decode('utf-8')
        db.add(
            User(
                username=username,
                password=hashed_password,
                is_admin=is_admin,
                session_key=session_key,
            )
        )
        db.commit()

    db = SessionLocal()
    try:
        test_username = "[email protected]"
        _ensure_user(db, test_username, "550e8400-e29b-41d4-a716-446655440000", False)

        admin_username = settings.default_admin_username
        _ensure_user(db, admin_username, settings.default_admin_password, True)

        # Logged whether or not any row was actually inserted.
        logger.info(f"Seeded initial data: test user '{test_username}', admin user '{admin_username}'")
    except Exception as e:
        logger.error(f"Error seeding initial data: {str(e)}")
        db.rollback()
    finally:
        db.close()
# Runs at import time, so the seed accounts are created (if missing) as soon
# as this module is loaded.
seed_initial_data()
# HTTP Bearer credential extractor; the auth functions below receive it via
# Depends(bearer_scheme).
bearer_scheme = HTTPBearer()
class TokenPayload(BaseModel):
    # Claims carried inside a JWT issued by this module.

    # Subject: the username the token was issued for.
    sub: str
    # Expiry as a numeric timestamp.
    exp: float
    # Token kind; this module issues "access" and "refresh".
    type: str
class TokenResponse(BaseModel):
    # Token pair handed back by the login, register and refresh functions.

    access_token: str
    refresh_token: str
    # The functions in this module always set this to "bearer".
    token_type: str
class LoginRequest(BaseModel):
    # Login body. login() treats both fields as encrypted, base64-encoded
    # strings and decrypts them with the caller's session key.

    username: str
    password: str
class RegisterRequest(BaseModel):
    # Registration body. register() uses the fields as plain text, while
    # app_register() decrypts them with the caller's session key first.

    username: str
    password: str
def decrypt_data(encrypted_data: str, key: bytes) -> str:
    """Decrypt a base64-encoded AES-GCM message and return it as UTF-8 text.

    The decoded message is read as: 12-byte nonce, then ciphertext, then a
    trailing 16-byte authentication tag.

    Args:
        encrypted_data: Base64 text of nonce + ciphertext + tag.
        key: Raw AES key bytes.

    Returns:
        The decrypted plaintext decoded as UTF-8.

    Raises:
        HTTPException: 400 "Invalid encrypted data" on any failure (bad
            base64, bad key, failed tag verification, invalid UTF-8, ...).
    """
    try:
        blob = base64.b64decode(encrypted_data)
        nonce = blob[:12]
        sealed = blob[12:]
        # The last 16 bytes of the remainder are the GCM tag.
        body, tag = sealed[:-16], sealed[-16:]
        gcm = AES.new(key, AES.MODE_GCM, nonce=nonce)
        return gcm.decrypt_and_verify(body, tag).decode('utf-8')
    except Exception as e:
        logger.error(f"Decryption failed: {str(e)}")
        raise HTTPException(status_code=400, detail="Invalid encrypted data")
async def create_access_token(user_id: str) -> dict:
    """Issue a signed access/refresh JWT pair for ``user_id``.

    Both tokens are HS256-signed with ``settings.api_key_secret`` and carry
    ``sub``, ``exp`` and ``type`` claims. The access token lives for
    ``settings.token_expiration_minutes`` and the refresh token for
    ``settings.refresh_token_expiration_days``.

    Returns:
        ``{"access_token": <jwt>, "refresh_token": <jwt>}``.
    """
    # NOTE(review): datetime.utcnow() is naive, so .timestamp() interprets it in
    # the host's local time zone; the value is a true UTC epoch only when the
    # host runs in UTC -- confirm the deployment time zone.

    def _signed(token_type: str, lifetime: timedelta) -> str:
        # Build the claims for one token and sign them.
        expires_at = datetime.utcnow() + lifetime
        claims = {"sub": user_id, "exp": expires_at.timestamp(), "type": token_type}
        return jwt.encode(claims, settings.api_key_secret, algorithm="HS256")

    access_jwt = _signed("access", timedelta(minutes=settings.token_expiration_minutes))
    refresh_jwt = _signed("refresh", timedelta(days=settings.refresh_token_expiration_days))
    logger.info(f"Generated tokens for user: {user_id}")
    return {"access_token": access_jwt, "refresh_token": refresh_jwt}
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
    """Validate a bearer JWT and return the username it was issued for.

    The token's signature is checked with ``settings.api_key_secret``. The
    subject must exist in either the ``users`` or the ``app_users`` table.
    Expiry is checked manually against ``exp`` (library-side expiry
    verification is switched off), using the same naive-UTC timestamp
    arithmetic that token creation uses.

    NOTE(review): the ``type`` claim is not checked here, so a refresh token
    is accepted as well as an access token -- confirm whether that is
    intended.

    Args:
        credentials: Bearer credentials extracted by ``bearer_scheme``.

    Returns:
        The ``sub`` claim (username) of the token.

    Raises:
        HTTPException: 401 "Token has expired" for an expired token; 401
            "Invalid authentication credentials" for any other failure.
    """
    token = credentials.credentials
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"], options={"verify_exp": False})
        token_data = TokenPayload(**payload)
        user_id = token_data.sub

        db = SessionLocal()
        try:
            # Check both users and app_users tables
            user = db.query(User).filter_by(username=user_id).first()
            app_user = db.query(AppUser).filter_by(username=user_id).first()
        finally:
            # Always release the session, even if a query raises.
            db.close()

        if not user and not app_user:
            logger.warning(f"Invalid or unknown user: {user_id}")
            raise credentials_exception

        current_time = datetime.utcnow().timestamp()
        if current_time > token_data.exp:
            logger.warning(f"Token expired: current_time={current_time}, exp={token_data.exp}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            )

        logger.info(f"Validated token for user: {user_id}")
        return user_id
    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged;
        # without this clause the generic handler below would replace the
        # "Token has expired" response with the generic 401.
        raise
    except jwt.InvalidSignatureError as e:
        logger.error(f"Invalid signature error: {str(e)}")
        raise credentials_exception
    except jwt.InvalidTokenError as e:
        logger.error(f"Other token error: {str(e)}")
        raise credentials_exception
    except Exception as e:
        logger.error(f"Unexpected token validation error: {str(e)}")
        raise credentials_exception
async def get_current_user_with_admin(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
    """Validate a bearer JWT and require that its subject is an admin.

    The token is decoded with ``settings.api_key_secret`` using the JWT
    library's default validation. The subject must exist in the ``users``
    table (``app_users`` is not consulted) and have ``is_admin`` set.

    Args:
        credentials: Bearer credentials extracted by ``bearer_scheme``.

    Returns:
        The ``sub`` claim (username) of the token.

    Raises:
        HTTPException: 403 when the user exists but is not an admin; 401
            "Invalid authentication credentials" for any other failure.
    """
    token = credentials.credentials
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"])
        token_data = TokenPayload(**payload)
        user_id = token_data.sub

        db = SessionLocal()
        try:
            user = db.query(User).filter_by(username=user_id).first()
            # Read the flag while the session is still open.
            user_found = user is not None
            is_admin = bool(user.is_admin) if user_found else False
        finally:
            # Always release the session, even if the query raises.
            db.close()

        if not user_found:
            logger.warning(f"User not found in users table: {user_id}")
            raise credentials_exception
        if not is_admin:
            logger.warning(f"User {user_id} is not authorized as admin")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Admin access required; only admin accounts can perform this action"
            )

        logger.info(f"Validated admin user: {user_id}")
        return user_id
    except HTTPException:
        # Let the 403 (and the explicit 401) through unchanged; without this
        # clause the generic handler below would turn the 403 into a 401.
        raise
    except jwt.InvalidSignatureError as e:
        logger.error(f"Invalid signature error: {str(e)}")
        raise credentials_exception
    except jwt.InvalidTokenError as e:
        logger.error(f"Other token error: {str(e)}")
        raise credentials_exception
    except Exception as e:
        logger.error(f"Unexpected admin validation error: {str(e)}")
        raise credentials_exception
async def login(login_request: LoginRequest, session_key_b64: str) -> TokenResponse:
    """Authenticate an encrypted username/password pair and issue tokens.

    Both fields of ``login_request`` are decrypted with the AES key obtained
    by base64-decoding ``session_key_b64``. The account is looked up in
    ``users`` first and ``app_users`` second. On success the stored session
    key is updated when it differs from the one supplied.

    Args:
        login_request: Encrypted, base64-encoded username and password.
        session_key_b64: Base64 text of the AES session key.

    Returns:
        A ``TokenResponse`` holding a fresh access/refresh token pair.

    Raises:
        HTTPException: 400 when the session key or the encrypted fields
            cannot be decoded/decrypted; 401 when the user is unknown or the
            password does not match.
    """
    try:
        session_key = base64.b64decode(session_key_b64)
    except (ValueError, TypeError):
        # A malformed key is a client error, not a server fault.
        raise HTTPException(status_code=400, detail="Invalid encrypted data")

    # decrypt_data raises HTTPException(400, "Invalid encrypted data") itself
    # on any failure, so no extra wrapper is needed here.
    username = decrypt_data(login_request.username, session_key)
    password = decrypt_data(login_request.password, session_key)

    db = SessionLocal()
    try:
        # Check both users and app_users tables
        user = db.query(User).filter_by(username=username).first()
        app_user = db.query(AppUser).filter_by(username=username).first()
        target_user = user if user else app_user

        if target_user is None or not pwd_context.verify(password, target_user.password):
            logger.warning(f"Login failed for user: {username}")
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or device token")

        if target_user.session_key != session_key_b64:
            target_user.session_key = session_key_b64
            db.commit()
    finally:
        # Release the session on every path, including unexpected errors.
        db.close()

    tokens = await create_access_token(user_id=username)
    return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")
async def register(register_request: RegisterRequest, current_user: str = Depends(get_current_user_with_admin)) -> TokenResponse:
    """Create a non-admin row in ``users`` on behalf of an admin.

    The username must be unused in both ``users`` and ``app_users``. The
    second table is checked because login looks in ``users`` first, so a
    duplicate name would shadow the existing app account.

    Args:
        register_request: Plain-text username and password for the new user.
        current_user: Username of the calling admin, resolved by
            ``get_current_user_with_admin``.

    Returns:
        A ``TokenResponse`` holding tokens for the newly created user.

    Raises:
        HTTPException: 400 when the username is already taken; 500 for any
            other failure (the session is rolled back first).
    """
    db = SessionLocal()
    try:
        new_name = register_request.username
        existing_user = (
            db.query(User).filter_by(username=new_name).first()
            or db.query(AppUser).filter_by(username=new_name).first()
        )
        if existing_user:
            logger.warning(f"Registration failed: Username {register_request.username} already exists")
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists")

        hashed_password = pwd_context.hash(register_request.password)
        db.add(User(username=new_name, password=hashed_password, is_admin=False))
        db.commit()
        logger.info(f"Admin {current_user} successfully registered new user: {register_request.username}")

        tokens = await create_access_token(user_id=register_request.username)
        return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")
    except HTTPException:
        # Deliberate client errors (the 400 above) must not be converted into
        # a 500 by the generic handler below.
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Registration error by admin {current_user}: {str(e)}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Registration failed: {str(e)}")
    finally:
        db.close()
async def app_register(register_request: RegisterRequest, session_key_b64: str) -> TokenResponse:
    """Self-register an app user from encrypted credentials and issue tokens.

    Both fields of ``register_request`` are decrypted with the AES key
    obtained by base64-decoding ``session_key_b64``. The username must be
    unused in both ``users`` and ``app_users``. The new row is stored in
    ``app_users`` together with the supplied session key.

    Args:
        register_request: Encrypted, base64-encoded username and password.
        session_key_b64: Base64 text of the AES session key.

    Returns:
        A ``TokenResponse`` holding tokens for the new app user.

    Raises:
        HTTPException: 400 when the session key or the encrypted fields
            cannot be decoded/decrypted, or when the username already exists.
    """
    try:
        session_key = base64.b64decode(session_key_b64)
    except (ValueError, TypeError):
        # A malformed key is a client error, not a server fault.
        raise HTTPException(status_code=400, detail="Invalid encrypted data")

    # decrypt_data raises HTTPException(400, "Invalid encrypted data") itself
    # on any failure.
    username = decrypt_data(register_request.username, session_key)
    password = decrypt_data(register_request.password, session_key)

    db = SessionLocal()
    try:
        # Check both tables to prevent duplicate usernames
        existing_user = db.query(User).filter_by(username=username).first()
        existing_app_user = db.query(AppUser).filter_by(username=username).first()
        if existing_user or existing_app_user:
            logger.warning(f"App registration failed: Email {username} already exists")
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered")

        hashed_password = pwd_context.hash(password)
        db.add(AppUser(username=username, password=hashed_password, session_key=session_key_b64))
        db.commit()
    except HTTPException:
        raise
    except Exception:
        # Undo the pending insert (e.g. a concurrent registration of the same
        # name) before letting the original error propagate.
        db.rollback()
        raise
    finally:
        db.close()

    tokens = await create_access_token(user_id=username)
    logger.info(f"App registered new user: {username}")
    return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")
async def refresh_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> TokenResponse:
    """Exchange a valid refresh token for a new access/refresh token pair.

    The bearer token is decoded with ``settings.api_key_secret`` using the
    JWT library's default validation. Its ``type`` claim must be
    ``"refresh"`` and its subject must exist in ``users`` or ``app_users``.

    Args:
        credentials: Bearer credentials extracted by ``bearer_scheme``.

    Returns:
        A ``TokenResponse`` holding freshly issued tokens for the same user.

    Raises:
        HTTPException: 401 when the token is invalid, is not a refresh
            token, or names an unknown user.
    """
    token = credentials.credentials
    try:
        payload = jwt.decode(token, settings.api_key_secret, algorithms=["HS256"])
        token_data = TokenPayload(**payload)
        if payload.get("type") != "refresh":
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type; refresh token required")
        user_id = token_data.sub

        db = SessionLocal()
        try:
            # Check both users and app_users tables
            user = db.query(User).filter_by(username=user_id).first()
            app_user = db.query(AppUser).filter_by(username=user_id).first()
        finally:
            # Always release the session, even if a query raises.
            db.close()

        if not user and not app_user:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")

        tokens = await create_access_token(user_id=user_id)
        return TokenResponse(access_token=tokens["access_token"], refresh_token=tokens["refresh_token"], token_type="bearer")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")