File size: 8,039 Bytes
844386f
 
9781b82
26b1877
4335561
9781b82
4335561
26b1877
8d6faeb
 
5a8554e
 
8d6faeb
 
445a506
8d6faeb
 
 
 
 
 
445a506
26b1877
8d6faeb
 
 
445a506
 
 
9781b82
4335561
 
844386f
 
 
 
 
 
4335561
 
 
 
a403177
3fb3087
 
844386f
9781b82
 
844386f
9781b82
 
4335561
844386f
a403177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26b1877
 
844386f
 
4335561
 
844386f
 
 
 
 
90b38cc
 
 
 
445a506
 
 
 
844386f
 
 
4335561
844386f
 
 
9781b82
26b1877
 
844386f
 
 
 
 
 
4335561
 
 
 
844386f
 
4335561
8d6faeb
 
 
 
90b38cc
844386f
4335561
 
 
 
 
844386f
 
 
 
 
4335561
844386f
 
4335561
 
 
 
 
844386f
 
4335561
844386f
9781b82
26b1877
 
 
 
 
 
 
 
 
 
90b38cc
8d6faeb
 
 
5a8554e
8d6faeb
90b38cc
8d6faeb
 
 
26b1877
445a506
 
 
 
 
 
 
 
26b1877
445a506
 
 
 
 
26b1877
445a506
 
26b1877
 
8d6faeb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import jwt
from datetime import datetime, timedelta
from fastapi import HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from config.logging_config import logger
from sqlalchemy import create_engine, Column, String, Boolean
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from passlib.context import CryptContext

# SQLite database setup: engine, declarative base class, and session factory.
DATABASE_URL = "sqlite:///users.db"
engine = create_engine(
    DATABASE_URL,
    # Passed through to the sqlite3 driver: allow a connection to be used
    # from a thread other than the one that created it.
    connect_args={"check_same_thread": False},
)
Base = declarative_base()
# Sessions neither autoflush nor autocommit; callers commit explicitly.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

class User(Base):
    """ORM mapping for a row of the ``users`` table."""

    __tablename__ = "users"

    # Login name; also the table's primary key.
    username = Column(String, index=True, primary_key=True)
    # Hashed password (never the plaintext).
    password = Column(String)
    # Admin flag; new rows are non-admin unless stated otherwise.
    is_admin = Column(Boolean, default=False)

# Create the tables declared on Base (no-op for tables that already exist).
Base.metadata.create_all(bind=engine)

# Password hashing: bcrypt is the only configured scheme.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

class Settings(BaseSettings):
    # Application configuration loaded from the process environment and the
    # ``.env`` file named in ``Config`` below.
    #
    # pydantic-settings maps every field to the environment variable of the
    # same name (case-insensitive), e.g. ``api_key_secret`` <- API_KEY_SECRET.
    # ``Field(..., env=...)`` is not honoured by pydantic-settings and only
    # emits a deprecation warning, so plain annotations are used instead.
    # Fields without a default are required: instantiation fails if the
    # corresponding variable is missing.

    # Secret used to sign and verify JWT access tokens (API_KEY_SECRET).
    api_key_secret: str
    # Access-token lifetime in minutes (TOKEN_EXPIRATION_MINUTES).
    token_expiration_minutes: int = 30
    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"
    # Required external service endpoints (EXTERNAL_*_URL).
    external_tts_url: str
    external_asr_url: str
    external_text_gen_url: str
    external_audio_proc_url: str
    # Bootstrap admin credentials (DEFAULT_ADMIN_USERNAME / DEFAULT_ADMIN_PASSWORD).
    # NOTE(review): these DO fall back to the well-known values below when the
    # environment does not provide them; always override them outside of local
    # development.
    default_admin_username: str = "admin"
    default_admin_password: str = "admin54321"

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"

settings = Settings()
# The signing secret must never appear in log output; only record that the
# configuration was loaded.
logger.info("Loaded application settings at startup")

# Seed initial data (optional)
def seed_initial_data():
    """Create the built-in accounts if they are not already present.

    Two accounts are ensured, in this order:

    * ``testuser`` (non-admin) with a fixed password, and
    * the admin account named by ``settings.default_admin_username`` with
      ``settings.default_admin_password``.

    Existing rows are left untouched, so passwords are never reset here.
    The session is always closed, even when a query, hash, or commit fails.
    """
    admin_username = settings.default_admin_username
    admin_password = settings.default_admin_password
    # NOTE(review): the test account uses a hard-coded, well-known password;
    # confirm it is intended to exist outside of development environments.
    # Each entry: (username, plaintext password, is_admin).
    accounts = (
        ("testuser", "password123", False),
        (admin_username, admin_password, True),
    )

    db = SessionLocal()
    try:
        for username, plain_password, is_admin in accounts:
            if db.query(User).filter_by(username=username).first():
                continue
            hashed_password = pwd_context.hash(plain_password)
            db.add(User(username=username, password=hashed_password, is_admin=is_admin))
            # Commit per account so an earlier insert survives a later failure.
            db.commit()
    finally:
        db.close()
    logger.info(f"Seeded initial data: admin user '{admin_username}'")

# Runs at import time so the seed accounts exist before any request is served.
seed_initial_data()

# Use HTTPBearer
# Shared security dependency; the handlers below receive its parsed
# credentials via Depends(bearer_scheme).
bearer_scheme = HTTPBearer()

# Claims expected inside a decoded access token.
class TokenPayload(BaseModel):
    sub: str  # username the token was issued for
    exp: float  # expiry, as the timestamp written by create_access_token

# Response body returned by login, register, and refresh_token.
class TokenResponse(BaseModel):
    access_token: str  # the encoded JWT
    token_type: str  # set to "bearer" by every producer in this module

# Credentials submitted to login.
class LoginRequest(BaseModel):
    username: str
    password: str  # plaintext; verified against the stored hash

# Credentials for a new account submitted to register.
class RegisterRequest(BaseModel):
    username: str
    password: str  # plaintext; hashed before it is stored

async def create_access_token(user_id: str) -> str:
    """Create a signed HS256 JWT for *user_id*.

    The token carries ``sub`` (the user id) and ``exp`` (expiry timestamp,
    ``settings.token_expiration_minutes`` from now) and is signed with
    ``settings.api_key_secret``.

    Returns:
        The encoded token.
    """
    # NOTE(review): ``utcnow()`` is naive, so ``.timestamp()`` interprets it in
    # the server's local zone. get_current_user compares ``exp`` against a
    # value computed the same way, so the two stay consistent; if this is
    # switched to an aware datetime, change both places together.
    expire = datetime.utcnow() + timedelta(minutes=settings.token_expiration_minutes)
    payload = {"sub": user_id, "exp": expire.timestamp()}
    # The signing secret is deliberately not logged.
    token = jwt.encode(payload, settings.api_key_secret, algorithm="HS256")
    logger.info(f"Generated access token for user: {user_id}")
    return token

async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
    """Validate a bearer token and return the username it belongs to.

    The token's signature is checked with ``settings.api_key_secret``, its
    claims are validated as a ``TokenPayload``, the user is looked up in the
    database, and finally the expiry is checked manually.

    Returns:
        The ``sub`` claim (username) of a valid, unexpired token.

    Raises:
        HTTPException: 401 "Token has expired" for an expired token, or
            401 "Invalid authentication credentials" for any other failure
            (bad signature, malformed token, unknown user, unexpected error).
    """
    token = credentials.credentials
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # Neither the raw token nor the signing secret is logged: both are
        # credentials. Expiry verification is disabled here because it is
        # done manually below, against the same clock create_access_token uses.
        payload = jwt.decode(
            token,
            settings.api_key_secret,
            algorithms=["HS256"],
            options={"verify_exp": False},
        )
        logger.info(f"Decoded payload: {payload}")
        token_data = TokenPayload(**payload)
        user_id = token_data.sub

        db = SessionLocal()
        try:
            user = db.query(User).filter_by(username=user_id).first()
        finally:
            # Release the session even if the query fails.
            db.close()
        if not user:
            logger.warning(f"Invalid or unknown user: {user_id}")
            raise credentials_exception

        current_time = datetime.utcnow().timestamp()
        logger.info(f"Current time: {current_time}, Token exp: {token_data.exp}")
        if current_time > token_data.exp:
            logger.warning(f"Token expired: current_time={current_time}, exp={token_data.exp}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            )

        logger.info(f"Validated token for user: {user_id}")
        return user_id
    except HTTPException:
        # Deliberate 401s raised above (unknown user, expired token) must reach
        # the client unchanged instead of being rewritten by the generic
        # handler at the bottom.
        raise
    except jwt.InvalidSignatureError as e:
        logger.error(f"Invalid signature error: {str(e)}")
        raise credentials_exception
    except jwt.InvalidTokenError as e:
        logger.error(f"Other token error: {str(e)}")
        raise credentials_exception
    except Exception as e:
        logger.error(f"Unexpected token validation error: {str(e)}")
        raise credentials_exception

async def get_current_user_with_admin(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> str:
    """Authenticate the caller and require the admin flag.

    Returns:
        The authenticated admin's username.

    Raises:
        HTTPException: whatever ``get_current_user`` raises for an invalid
            token, or 403 "Admin access required" when the user is missing
            or not an admin.
    """
    user_id = await get_current_user(credentials)
    db = SessionLocal()
    try:
        user = db.query(User).filter_by(username=user_id).first()
    finally:
        # Release the session even if the query fails.
        db.close()
    if not user or not user.is_admin:
        logger.warning(f"User {user_id} is not authorized as admin")
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")
    return user_id

async def login(login_request: LoginRequest) -> TokenResponse:
    """Exchange a username/password pair for a bearer token.

    Returns:
        A ``TokenResponse`` with a freshly issued access token.

    Raises:
        HTTPException: 401 "Invalid username or password" when the user does
            not exist or the password does not match the stored hash.
    """
    db = SessionLocal()
    try:
        user = db.query(User).filter_by(username=login_request.username).first()
    finally:
        # Release the session even if the query fails.
        db.close()
    # One shared error for both failure modes so the response does not reveal
    # whether the username exists.
    if not user or not pwd_context.verify(login_request.password, user.password):
        logger.warning(f"Login failed for user: {login_request.username}")
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
    token = await create_access_token(user_id=user.username)
    return TokenResponse(access_token=token, token_type="bearer")

async def register(register_request: RegisterRequest, current_user: str = Depends(get_current_user_with_admin)) -> TokenResponse:
    """Create a new non-admin user (admin only) and issue a token for it.

    Args:
        register_request: username and plaintext password of the new account.
        current_user: username of the authenticated admin making the request.

    Returns:
        A ``TokenResponse`` with an access token for the new user.

    Raises:
        HTTPException: 400 "Username already exists" when the name is taken,
            including when a concurrent request inserts it first.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import list.
    from sqlalchemy.exc import IntegrityError

    username = register_request.username
    db = SessionLocal()
    try:
        if db.query(User).filter_by(username=username).first():
            logger.warning(f"Registration failed: Username {username} already exists")
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists")

        hashed_password = pwd_context.hash(register_request.password)
        db.add(User(username=username, password=hashed_password, is_admin=False))
        try:
            db.commit()
        except IntegrityError:
            # Another request created the same username between the check
            # above and this commit; report it like any other duplicate.
            db.rollback()
            logger.warning(f"Registration failed: Username {username} already exists")
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists") from None
    finally:
        # Release the session on every path, including errors.
        db.close()

    token = await create_access_token(user_id=username)
    logger.info(f"Registered and generated token for user: {username} by admin {current_user}")
    return TokenResponse(access_token=token, token_type="bearer")

async def refresh_token(credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme)) -> TokenResponse:
    """Issue a fresh access token to the holder of a currently valid one.

    The presented bearer token is validated with ``get_current_user``; its
    failures propagate unchanged.
    """
    subject = await get_current_user(credentials)
    return TokenResponse(
        access_token=await create_access_token(user_id=subject),
        token_type="bearer",
    )