Spaces:
Building
Building
"""Admin API endpoints for Flare | |
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
Provides authentication, project, version, and API management endpoints. | |
""" | |
import os | |
import sys | |
import hashlib | |
import json | |
import jwt | |
import httpx | |
from datetime import datetime, timedelta, timezone | |
from typing import Optional, List, Dict, Any | |
from pathlib import Path | |
import threading | |
import time | |
import bcrypt | |
from fastapi import APIRouter, HTTPException, Depends, Body, Query | |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
from pydantic import BaseModel, Field | |
from utils import log | |
from config_provider import ConfigProvider, ProviderSettings | |
from encryption_utils import encrypt, decrypt | |
# ===================== JWT Config ===================== | |
def get_jwt_config():
    """Return the JWT signing configuration for the current deployment.

    When running in a HuggingFace Space (detected via the ``SPACE_ID``
    environment variable) values are read straight from the environment.
    Otherwise (on-premise) a ``.env`` file is loaded first with
    ``python-dotenv``.

    Returns:
        dict with keys ``secret`` (str), ``algorithm`` (str, default
        ``"HS256"``) and ``expiration_hours`` (int, default 24).

    Note:
        If ``JWT_SECRET`` is unset *or empty*, a hard-coded fallback secret is
        used and a warning is logged in both modes. Anyone who knows that
        fallback can forge tokens, so real deployments must set JWT_SECRET.
        This function is called on every token creation and verification.
    """
    if not os.environ.get("SPACE_ID"):
        # On-premise mode: pull settings from a local .env file first.
        from dotenv import load_dotenv
        load_dotenv()

    jwt_secret = os.getenv("JWT_SECRET")
    if not jwt_secret:
        # Same handling for both modes: previously the on-premise path fell back
        # silently and would even sign tokens with an empty JWT_SECRET.
        log("β οΈ WARNING: JWT_SECRET not found in environment, using fallback")
        jwt_secret = "flare-admin-secret-key-change-in-production"  # Fallback

    raw_hours = os.getenv("JWT_EXPIRATION_HOURS", "24")
    try:
        expiration_hours = int(raw_hours)
    except ValueError:
        # A malformed value used to make every login/verify raise; fall back instead.
        log(f"WARNING: JWT_EXPIRATION_HOURS={raw_hours!r} is not an integer, using 24")
        expiration_hours = 24

    return {
        "secret": jwt_secret,
        "algorithm": os.getenv("JWT_ALGORITHM", "HS256"),
        "expiration_hours": expiration_hours,
    }
# ===================== Constants & Config ===================== | |
# Shared APIRouter; every route registered on it is served under "/api".
router = APIRouter(prefix="/api")
# Bearer-token extractor consumed by verify_token. auto_error=True is FastAPI's
# default, spelled out so it is clear the dependency itself rejects requests
# that carry no bearer credentials.
security = HTTPBearer(auto_error=True)
# ===================== Models ===================== | |
# Credentials posted to the login endpoint.
class LoginRequest(BaseModel):
    username: str
    password: str  # plaintext; checked with verify_password against the stored hash
# Successful login reply: the JWT produced by create_token and the username it was issued for.
class LoginResponse(BaseModel):
    token: str
    username: str
# Body for the change-password endpoint; the target user is taken from the JWT, not the body.
class ChangePasswordRequest(BaseModel):
    current_password: str  # must verify against the user's stored hash
    new_password: str  # re-hashed with hash_password before being stored
# Settings for one provider slot (LLM, TTS or STT) inside an EnvironmentUpdate.
class ProviderSettingsUpdate(BaseModel):
    name: str  # provider key looked up via global_config.get_provider_config(kind, name)
    api_key: Optional[str] = None  # required when the provider definition sets requires_api_key
    endpoint: Optional[str] = None  # required for LLM providers whose definition sets requires_endpoint
    settings: Dict[str, Any] = Field(default_factory=dict)  # provider-specific extras; schema not visible here
# Provider selection validated by update_environment; all three slots are required.
class EnvironmentUpdate(BaseModel):
    llm_provider: ProviderSettingsUpdate
    tts_provider: ProviderSettingsUpdate  # "no_tts" is accepted without a provider definition (presumably: TTS off)
    stt_provider: ProviderSettingsUpdate  # "no_stt" is accepted without a provider definition (presumably: STT off)
# Body for creating a project; forwarded to ConfigProvider.create_project via model_dump().
class ProjectCreate(BaseModel):
    name: str
    caption: Optional[str] = ""
    icon: Optional[str] = "folder"
    description: Optional[str] = ""
    default_locale: str = "tr"  # locale code ("tr" presumably Turkish)
    # NOTE(review): the list-literal default relies on pydantic copying defaults
    # per instance; Field(default_factory=lambda: ["tr"]) would make that explicit.
    supported_locales: List[str] = ["tr"]
# Partial project update. update_project forwards model_dump(exclude_unset=True),
# so only the fields the client actually sent reach ConfigProvider.update_project.
# The project name is not part of this model and cannot be changed here.
class ProjectUpdate(BaseModel):
    caption: Optional[str] = None
    icon: Optional[str] = None
    description: Optional[str] = None
    default_locale: Optional[str] = None
    supported_locales: Optional[List[str]] = None
# Body for creating a project version; forwarded to ConfigProvider.create_version.
class VersionCreate(BaseModel):
    caption: Optional[str] = ""
    general_prompt: str  # the only required field
    llm: Optional[Dict[str, Any]] = None  # LLM settings; schema not visible in this module
    intents: List[Dict[str, Any]] = []  # intent definitions; schema not visible in this module
# Partial version update. update_version forwards model_dump(exclude_unset=True),
# so omitted fields are not sent to ConfigProvider.update_version.
class VersionUpdate(BaseModel):
    caption: Optional[str] = None
    general_prompt: Optional[str] = None
    llm: Optional[Dict[str, Any]] = None
    intents: Optional[List[Dict[str, Any]]] = None
# Definition of an external API; forwarded to ConfigProvider.create_api.
class APICreate(BaseModel):
    name: str  # identifier that update_api/delete_api later receive as api_name
    url: str
    method: str = "GET"  # HTTP method name; not validated here
    headers: Dict[str, Any] = {}
    body_template: Dict[str, Any] = {}
    timeout_seconds: int = 10
    auth: Optional[Dict[str, Any]] = None  # auth settings; schema not visible in this module
    response_prompt: Optional[str] = None
# Partial API update (the API is identified by update_api's api_name argument,
# so there is no name field). update_api forwards model_dump(exclude_unset=True).
class APIUpdate(BaseModel):
    url: Optional[str] = None
    method: Optional[str] = None
    headers: Optional[Dict[str, Any]] = None
    body_template: Optional[Dict[str, Any]] = None
    timeout_seconds: Optional[int] = None
    auth: Optional[Dict[str, Any]] = None
    response_prompt: Optional[str] = None
# ===================== Auth Helpers ===================== | |
def create_token(username: str) -> str:
    """Issue a signed JWT for *username*.

    The payload carries ``sub`` (the username), ``iat`` and ``exp`` claims.
    Secret, algorithm and lifetime come from get_jwt_config().

    Returns:
        The encoded token.
    """
    jwt_config = get_jwt_config()
    # Take one timestamp so that exp - iat is exactly the configured lifetime
    # (two separate now() calls could differ by a few microseconds).
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": username,
        "exp": issued_at + timedelta(hours=jwt_config["expiration_hours"]),
        "iat": issued_at,
    }
    return jwt.encode(payload, jwt_config["secret"], algorithm=jwt_config["algorithm"])
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """FastAPI dependency: decode the bearer JWT and return its ``sub`` claim.

    Raises:
        HTTPException 401 "Token expired": the ``exp`` claim has passed.
        HTTPException 401 "Invalid token": any other decode failure, or a
            token whose ``sub`` claim is missing/empty.
    """
    settings = get_jwt_config()
    try:
        claims = jwt.decode(
            credentials.credentials,
            settings["secret"],
            algorithms=[settings["algorithm"]],
        )
    except jwt.ExpiredSignatureError:
        # Must precede InvalidTokenError, which is its base class.
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")

    subject = claims.get("sub")
    if subject:
        return subject
    raise HTTPException(status_code=401, detail="Invalid token")
def verify_password(plain_password: str, hashed_password: str, salt: str) -> bool:
    """Check *plain_password* against a stored password hash.

    Two storage formats are supported:

    * bcrypt hashes (``$2a$``, ``$2b$`` or ``$2y$`` prefix). The salt is
      embedded in the hash, so *salt* is ignored.
    * Legacy hex SHA-256 digest of ``plain_password + salt``.

    Args:
        plain_password: Password supplied by the client.
        hashed_password: Stored hash in one of the formats above.
        salt: Salt for the legacy format. ``None`` is treated as ``""``.

    Returns:
        True if the password matches, otherwise False.
    """
    import hmac

    # All bcrypt variants accepted by bcrypt.checkpw. Previously only "$2b$" was
    # recognised, so "$2a$"/"$2y$" hashes fell through to the SHA-256 path and
    # could never match.
    if hashed_password.startswith(("$2a$", "$2b$", "$2y$")):
        return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))

    # Legacy SHA-256 path; older records may carry no salt at all.
    legacy_digest = hashlib.sha256((plain_password + (salt or "")).encode("utf-8")).hexdigest()
    # Constant-time comparison so response timing doesn't leak matching prefixes.
    return hmac.compare_digest(legacy_digest.encode("utf-8"), hashed_password.encode("utf-8"))
def hash_password(password: str) -> tuple[str, str]:
    """Hash *password* with bcrypt using a freshly generated salt.

    Returns:
        ``(hash, salt)``. *salt* is always ``""`` because bcrypt stores its
        salt inside the hash string.
    """
    digest = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt())
    return digest.decode("utf-8"), ""
# ===================== Login Endpoints ===================== | |
async def login(request: LoginRequest):
    """Authenticate a user and return a freshly issued JWT.

    An unknown username and a wrong password both produce the same 401
    "Invalid credentials" response.
    """
    known_users = ConfigProvider.get().global_config.users
    account = next((u for u in known_users if u.username == request.username), None)
    # NOTE(review): verify_password runs bcrypt synchronously inside this async
    # handler, which blocks the event loop while hashing.
    if not account or not verify_password(request.password, account.password_hash, account.salt):
        raise HTTPException(status_code=401, detail="Invalid credentials")

    issued = create_token(request.username)
    log(f"β User '{request.username}' logged in")
    return LoginResponse(token=issued, username=request.username)
async def change_password(
    request: ChangePasswordRequest,
    username: str = Depends(verify_token),
):
    """Change the authenticated user's password.

    The user is identified by the JWT subject. The current password must
    verify before the new one is hashed and the configuration is saved.

    Raises:
        HTTPException 404: the token's user is not in the configuration.
        HTTPException 401: ``current_password`` is wrong.
        HTTPException 400: ``new_password`` is empty.
    """
    cfg = ConfigProvider.get()

    user = next((u for u in cfg.global_config.users if u.username == username), None)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    if not verify_password(request.current_password, user.password_hash, user.salt):
        raise HTTPException(status_code=401, detail="Current password is incorrect")

    # ChangePasswordRequest places no constraint on new_password. Without this
    # check an empty string would be hashed and stored as the new password.
    if not request.new_password:
        raise HTTPException(status_code=400, detail="New password must not be empty")

    new_hash, new_salt = hash_password(request.new_password)
    # The user object is mutated in place, then the whole config is persisted.
    user.password_hash = new_hash
    user.salt = new_salt
    # NOTE(review): relies on the private ConfigProvider._save(); consider a public API.
    ConfigProvider._save()

    log(f"β Password changed for user '{username}'")
    return {"success": True}
# ===================== Locales Endpoints ===================== | |
async def get_available_locales(username: str = Depends(verify_token)):
    """Return every system-supported locale together with the default locale."""
    from locale_manager import LocaleManager

    return {
        "locales": LocaleManager.get_available_locales_with_names(),
        "default": LocaleManager.get_default_locale(),
    }
async def get_locale_details(
    locale_code: str,
    username: str = Depends(verify_token),
):
    """Return LocaleManager's details for *locale_code*.

    Raises:
        HTTPException 404: the locale lookup returned nothing.
    """
    from locale_manager import LocaleManager

    details = LocaleManager.get_locale_details(locale_code)
    if details:
        return details
    raise HTTPException(status_code=404, detail=f"Locale '{locale_code}' not found")
# ===================== Environment Endpoints ===================== | |
async def get_environment(username: str = Depends(verify_token)):
    """Return the active LLM/TTS/STT provider settings plus all provider definitions."""
    global_cfg = ConfigProvider.get().global_config
    snapshot = {
        slot: getattr(global_cfg, slot).model_dump()
        for slot in ("llm_provider", "tts_provider", "stt_provider")
    }
    snapshot["providers"] = [definition.model_dump() for definition in global_cfg.providers]
    return snapshot
async def update_environment(
    update: EnvironmentUpdate,
    username: str = Depends(verify_token),
):
    """Validate and persist new LLM/TTS/STT provider settings.

    Each provider name must match a definition from
    ``global_config.get_provider_config``. TTS and STT may instead use the
    names ``"no_tts"`` / ``"no_stt"``. A definition with ``requires_api_key``
    demands an API key. For the LLM, ``requires_endpoint`` also demands an
    endpoint.

    Raises:
        HTTPException 400: unknown provider or a missing required field.
    """
    log(f"π Updating environment config by {username}")
    cfg = ConfigProvider.get()

    def _require_api_key(definition, provider):
        # Shared "API key is required" rule for all three provider slots.
        if definition and definition.requires_api_key and not provider.api_key:
            raise HTTPException(status_code=400, detail=f"API key is required for {provider.name}")

    llm = update.llm_provider
    llm_definition = cfg.global_config.get_provider_config("llm", llm.name)
    if not llm_definition:
        raise HTTPException(status_code=400, detail=f"Unknown LLM provider: {llm.name}")
    _require_api_key(llm_definition, llm)
    if llm_definition.requires_endpoint and not llm.endpoint:
        raise HTTPException(status_code=400, detail=f"Endpoint is required for {llm.name}")

    # TTS and STT follow identical rules, each with its own "none" sentinel name.
    for kind, label, provider, none_name in (
        ("tts", "TTS", update.tts_provider, "no_tts"),
        ("stt", "STT", update.stt_provider, "no_stt"),
    ):
        definition = cfg.global_config.get_provider_config(kind, provider.name)
        if not definition and provider.name != none_name:
            raise HTTPException(status_code=400, detail=f"Unknown {label} provider: {provider.name}")
        _require_api_key(definition, provider)

    ConfigProvider.update_environment(update.model_dump(), username)
    log(f"β Environment updated by {username}")
    return {"success": True}
# ===================== Project Endpoints ===================== | |
async def list_projects(
    include_deleted: bool = False,
    username: str = Depends(verify_token),
):
    """Return all projects as dicts; soft-deleted ones only when *include_deleted*."""
    all_projects = ConfigProvider.get().projects
    if include_deleted:
        selected = all_projects
    else:
        selected = [p for p in all_projects if not getattr(p, "deleted", False)]
    return [p.model_dump() for p in selected]
async def create_project(
    project: ProjectCreate,
    username: str = Depends(verify_token),
):
    """Create a project from the request body and return it as a dict."""
    payload = project.model_dump()
    created = ConfigProvider.create_project(payload, username)
    log(f"β Project '{project.name}' created by {username}")
    return created.model_dump()
async def get_project(
    project_id: int,
    username: str = Depends(verify_token),
):
    """Return one project by id (soft-deleted projects are not filtered out).

    Raises:
        HTTPException 404: no project has this id.
    """
    for candidate in ConfigProvider.get().projects:
        if candidate.id == project_id:
            return candidate.model_dump()
    raise HTTPException(status_code=404, detail="Project not found")
async def update_project(
    project_id: int,
    update: ProjectUpdate,
    username: str = Depends(verify_token),
):
    """Apply the fields the client actually sent to a project and return it."""
    changes = update.model_dump(exclude_unset=True)
    refreshed = ConfigProvider.update_project(project_id, changes, username)
    log(f"β Project {project_id} updated by {username}")
    return refreshed.model_dump()
async def delete_project(project_id: int, username: str = Depends(verify_token)):
    """Delete a project via ConfigProvider (documented as a soft delete)."""
    ConfigProvider.delete_project(project_id, username)
    log(f"β Project {project_id} deleted by {username}")
    return {"success": True}
async def toggle_project(project_id: int, username: str = Depends(verify_token)):
    """Flip a project's enabled flag and report the resulting state."""
    toggled = ConfigProvider.toggle_project(project_id, username)
    state = "enabled" if toggled.enabled else "disabled"
    log(f"β Project {project_id} {state} by {username}")
    return {"success": True, "enabled": toggled.enabled}
# ===================== Version Endpoints ===================== | |
async def list_versions(
    project_id: int,
    include_deleted: bool = False,
    username: str = Depends(verify_token),
):
    """Return a project's versions as dicts; soft-deleted ones only when *include_deleted*.

    Raises:
        HTTPException 404: no project has this id.
    """
    for project in ConfigProvider.get().projects:
        if project.id == project_id:
            break
    else:
        raise HTTPException(status_code=404, detail="Project not found")

    if include_deleted:
        selected = project.versions
    else:
        selected = [v for v in project.versions if not getattr(v, "deleted", False)]
    return [v.model_dump() for v in selected]
async def create_version(
    project_id: int,
    version_data: VersionCreate,
    username: str = Depends(verify_token),
):
    """Create a version under *project_id* and return it as a dict."""
    payload = version_data.model_dump()
    created = ConfigProvider.create_version(project_id, payload, username)
    log(f"β Version created for project {project_id} by {username}")
    return created.model_dump()
async def update_version(
    project_id: int,
    version_id: int,
    update: VersionUpdate,
    username: str = Depends(verify_token),
):
    """Apply the fields the client actually sent to a version and return it."""
    changes = update.model_dump(exclude_unset=True)
    refreshed = ConfigProvider.update_version(project_id, version_id, changes, username)
    log(f"β Version {version_id} updated for project {project_id} by {username}")
    return refreshed.model_dump()
async def publish_version(
    project_id: int,
    version_id: int,
    username: str = Depends(verify_token),
):
    """Mark a version as published via ConfigProvider.publish_version."""
    owning_project, _published = ConfigProvider.publish_version(project_id, version_id, username)
    log(f"β Version {version_id} published for project '{owning_project.name}' by {username}")
    # TODO: Notify LLM provider if needed (only for Spark variants)
    return {"success": True}
async def delete_version(
    project_id: int,
    version_id: int,
    username: str = Depends(verify_token),
):
    """Delete a version via ConfigProvider (documented as a soft delete)."""
    ConfigProvider.delete_version(project_id, version_id, username)
    log(f"β Version {version_id} deleted for project {project_id} by {username}")
    return {"success": True}
# ===================== Test Endpoints ===================== | |
async def test_connection(
    request: dict = Body(...),
    username: str = Depends(verify_token),
):
    """Probe an LLM endpoint's ``/health`` route from the server.

    Body keys:
        endpoint: Base URL (required). A trailing "/" is stripped.
        api_key: Optional bearer token sent with the probe.
        provider: Defaults to "spark". Only Spark variants are actually
            probed; other providers report success without any request.

    Returns:
        ``{"success": bool, "message": str}``. Network failures are reported
        in the payload rather than raised.

    Raises:
        HTTPException 400: ``endpoint`` is missing, null or empty.
    """
    # "or ''" also covers explicit JSON nulls, which previously crashed on .rstrip().
    endpoint = str(request.get("endpoint") or "").rstrip("/")
    api_key = request.get("api_key") or ""
    provider_name = request.get("provider", "spark")

    if not endpoint:
        raise HTTPException(status_code=400, detail="Endpoint is required")

    if provider_name not in ("spark", "spark_cloud", "spark_onpremise"):
        return {"success": True, "message": "Provider doesn't require connection test"}

    headers = {"Content-Type": "application/json"}
    if api_key:
        # Avoid sending a malformed "Bearer " / "Bearer None" header when no key was given.
        headers["Authorization"] = f"Bearer {api_key}"

    # NOTE(review): the server issues a GET to a client-supplied URL (SSRF
    # surface); this is only acceptable because the route requires a valid token.
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            response = await client.get(f"{endpoint}/health", headers=headers)
    except Exception as e:
        # Deliberately broad: any failure is reported to the caller, not raised.
        log(f"β Connection test failed: {e}")
        # Some exceptions (e.g. certain httpx timeouts) may stringify to "",
        # so fall back to the exception type name.
        return {"success": False, "message": str(e) or type(e).__name__}

    if response.status_code == 200:
        return {"success": True, "message": "Connection successful"}
    return {"success": False, "message": f"Health check failed: {response.status_code}"}
async def validate_regex(
    request: dict = Body(...),
    username: str = Depends(verify_token),
):
    """Compile ``pattern`` and report whether it matches ``test_value``.

    Uses ``re.match`` semantics: the match is anchored at the start, not a
    full match. Any failure, such as an invalid pattern or non-string input,
    yields ``valid: False`` with the error text instead of an HTTP error.
    """
    import re

    pattern = request.get("pattern", "")
    test_value = request.get("test_value", "")

    # NOTE(review): regex evaluation runs synchronously in this async handler;
    # a pathological pattern could stall the event loop.
    try:
        is_match = re.compile(pattern).match(test_value) is not None
    except Exception as err:
        return {
            "valid": False,
            "matches": False,
            "error": str(err),
            "pattern": pattern,
            "test_value": test_value,
        }
    return {
        "valid": True,
        "matches": is_match,
        "pattern": pattern,
        "test_value": test_value,
    }
# ===================== API Endpoints ===================== | |
async def list_apis(
    include_deleted: bool = False,
    username: str = Depends(verify_token),
):
    """Return every API definition as a dict; soft-deleted ones only when *include_deleted*."""
    all_apis = ConfigProvider.get().apis
    if include_deleted:
        selected = all_apis
    else:
        selected = [a for a in all_apis if not getattr(a, "deleted", False)]
    return [a.model_dump() for a in selected]
async def create_api(api: APICreate, username: str = Depends(verify_token)):
    """Register a new API definition and return it as a dict."""
    payload = api.model_dump()
    created = ConfigProvider.create_api(payload, username)
    log(f"β API '{api.name}' created by {username}")
    return created.model_dump()
async def update_api(
    api_name: str,
    update: APIUpdate,
    username: str = Depends(verify_token),
):
    """Apply the fields the client actually sent to the API named *api_name*."""
    changes = update.model_dump(exclude_unset=True)
    refreshed = ConfigProvider.update_api(api_name, changes, username)
    log(f"β API '{api_name}' updated by {username}")
    return refreshed.model_dump()
async def delete_api(api_name: str, username: str = Depends(verify_token)):
    """Delete the API named *api_name* via ConfigProvider (documented as a soft delete)."""
    ConfigProvider.delete_api(api_name, username)
    log(f"β API '{api_name}' deleted by {username}")
    return {"success": True}
async def test_api(
    request: dict = Body(...),
    username: str = Depends(verify_token),
):
    """Run a trial call of an API definition through api_executor.test_api_call.

    Body keys: ``api_config`` (required) and ``test_params`` (default ``{}``).
    Failures during the call come back as ``{"success": False, "error": ...}``
    rather than being raised.

    Raises:
        HTTPException 400: ``api_config`` is missing or empty.
    """
    from api_executor import test_api_call

    api_config = request.get("api_config")
    test_params = request.get("test_params", {})
    if not api_config:
        raise HTTPException(status_code=400, detail="API config is required")

    try:
        outcome = await test_api_call(api_config, test_params)
        reported = {
            field: outcome.get(field)
            for field in ("status_code", "response", "headers", "duration_ms")
        }
        return {"success": True, **reported}
    except Exception as exc:
        log(f"β API test failed: {exc}")
        return {"success": False, "error": str(exc)}
# ===================== Activity Log ===================== | |
async def get_activity_log(
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
    username: str = Depends(verify_token),
):
    """Return a page of activity-log entries (currently always an empty page)."""
    # TODO: Implement when activity logging is added
    return {"entries": [], "total": 0, "limit": limit, "offset": offset}
# ===================== Health Check ===================== | |
async def health_check():
    """Report liveness with the current UTC time (ISO 8601) and a fixed version string."""
    now_utc = datetime.now(timezone.utc)
    return {"status": "ok", "timestamp": now_utc.isoformat(), "version": "1.0.0"}
# ===================== Cleanup Task ===================== | |
# Handle to the background cleanup thread; stays None until start_cleanup_task() creates it.
_cleanup_thread = None


def cleanup_activity_logs():
    """Loop forever: every 24 hours run the (not yet implemented) log cleanup.

    Used as the target of a daemon thread. An exception raised during an
    iteration is logged and the loop carries on.
    """
    seconds_per_day = 24 * 60 * 60
    while True:
        try:
            time.sleep(seconds_per_day)
            # TODO: Implement cleanup logic when activity logging is added
            log("π§Ή Running activity log cleanup...")
        except Exception as exc:
            log(f"β Cleanup error: {exc}")
def start_cleanup_task():
    """Start the daemon cleanup thread if it has not been created yet.

    Once the module-level thread handle is set, further calls do nothing.
    The check-then-set is not guarded by a lock.
    """
    global _cleanup_thread
    if _cleanup_thread is not None:
        return
    _cleanup_thread = threading.Thread(target=cleanup_activity_logs, daemon=True)
    _cleanup_thread.start()
    log("π§Ή Started cleanup task")
# ===================== Helper Functions ===================== | |
def _get_iso_timestamp() -> str: | |
"""Get current timestamp in ISO format""" | |
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" |