Spaces:
Building
Building
""" | |
Flare β ConfigProvider | |
""" | |
from __future__ import annotations

import json
import os
import threading
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional

import commentjson
from pydantic import BaseModel, ConfigDict, Field, HttpUrl, ValidationError

from utils import log
from encryption_utils import decrypt
# ============== Provider Configs ============== | |
class ProviderConfig(BaseModel):
    """Catalog entry describing one selectable provider.

    ``type`` is restricted to ``llm``, ``tts`` or ``stt``. The ``requires_*``
    booleans are per-provider flags (endpoint / API key / repo info).
    """

    type: str = Field(..., pattern=r"^(llm|tts|stt)$")
    name: str
    display_name: str
    requires_endpoint: bool = Field(default=False)
    requires_api_key: bool = Field(default=True)
    requires_repo_info: bool = Field(default=False)
    description: Optional[str] = Field(default=None)
class ProviderSettings(BaseModel):
    """Settings of the provider currently selected for one service (llm, tts or stt)."""

    name: str
    # May hold an "enc:"-prefixed value; GlobalConfig.get_plain_api_key decrypts those.
    api_key: Optional[str] = Field(default=None)
    endpoint: Optional[str] = Field(default=None)
    # Free-form, provider-specific options.
    settings: Dict[str, Any] = Field(default_factory=dict)
# ============== Parameter Collection ============== | |
class ParameterCollectionConfig(BaseModel):
    """Configuration for parameter collection behavior.

    Unknown keys are preserved (``extra="allow"``).
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated.
    model_config = ConfigDict(extra="allow")

    # Validated to the range 1..5.
    max_params_per_question: int = Field(2, ge=1, le=5)
    retry_unanswered: bool = True
    collection_prompt: str = ""
# ============== Global Config ============== | |
class GlobalConfig(BaseModel):
    """Global (non-project) settings: active providers, provider catalog and users.

    Unknown keys are preserved (``extra="allow"``).
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated.
    model_config = ConfigDict(extra="allow")

    # Active provider settings, one per service
    llm_provider: ProviderSettings
    tts_provider: ProviderSettings
    stt_provider: ProviderSettings
    # Catalog of available providers
    providers: List[ProviderConfig] = Field(default_factory=list)
    # Users (forward reference: UserConfig is defined further down this module)
    users: List["UserConfig"] = Field(default_factory=list)

    def get_provider_config(self, provider_type: str, provider_name: str) -> Optional[ProviderConfig]:
        """Return the catalog entry with the given type and name, or ``None``."""
        return next(
            (p for p in self.providers if p.type == provider_type and p.name == provider_name),
            None,
        )

    def get_providers_by_type(self, provider_type: str) -> List[ProviderConfig]:
        """Return every catalog entry whose ``type`` equals ``provider_type``."""
        return [p for p in self.providers if p.type == provider_type]

    def get_plain_api_key(self, provider_type: str) -> Optional[str]:
        """Return the usable API key for ``provider_type`` ("llm", "tts" or "stt").

        Values starting with ``enc:`` are passed through ``decrypt``; any other
        value is returned unchanged. Returns ``None`` when the provider section or
        its ``api_key`` is missing/empty.
        """

        def masked(secret: Optional[str]) -> str:
            # Only reveal a 4-char suffix of reasonably long values; for short
            # values the suffix would expose most or all of the secret in logs.
            if not secret:
                return 'None'
            return '***' + secret[-4:] if len(secret) > 8 else '***'

        provider = getattr(self, f"{provider_type}_provider", None)
        if not provider or not provider.api_key:
            return None

        raw_key = provider.api_key
        if raw_key.startswith("enc:"):
            log(f"π Decrypting {provider_type} API key...")
            decrypted = decrypt(raw_key)
            log(f"π {provider_type} key decrypted: {masked(decrypted)}")
            return decrypted

        log(f"π {provider_type} key/path: {masked(raw_key)}")
        return raw_key
# ============== User Config ============== | |
class UserConfig(BaseModel):
    """A user account entry, as listed in ``GlobalConfig.users``."""

    username: str = Field(...)
    password_hash: str = Field(...)
    salt: str = Field(...)
# ============== Retry / Proxy ============== | |
class RetryConfig(BaseModel):
    """Retry policy attached to an APIConfig.

    ``retry_count`` is read from the ``max_attempts`` key. The field name is also
    accepted on input (``populate_by_name``), like APIConfig/APIAuthConfig.
    Without it, ``retry_count=...`` was silently ignored and the default kept.
    """

    model_config = ConfigDict(populate_by_name=True)

    retry_count: int = Field(3, alias="max_attempts")
    backoff_seconds: int = 2
    # Either "static" or "exponential"
    strategy: str = Field("static", pattern=r"^(static|exponential)$")
class ProxyConfig(BaseModel):
    """Structured proxy setting (APIConfig.proxy may alternatively be a plain string)."""

    enabled: bool = Field(default=True)
    url: HttpUrl
# ============== API & Auth ============== | |
class APIAuthConfig(BaseModel):
    """Token authentication settings for an API.

    ``token_request_body`` uses the alias ``body_template``. Both names are
    accepted on input (``populate_by_name``). Unknown keys are preserved.
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated.
    model_config = ConfigDict(extra="allow", populate_by_name=True)

    enabled: bool = False
    token_endpoint: Optional[HttpUrl] = None
    # Presumably the location of the token in the token endpoint's response — confirm with the caller
    response_token_path: str = "access_token"
    token_request_body: Dict[str, Any] = Field(default_factory=dict, alias="body_template")
    token_refresh_endpoint: Optional[HttpUrl] = None
    token_refresh_body: Dict[str, Any] = Field(default_factory=dict)
class APIConfig(BaseModel):
    """A named external HTTP API definition (looked up by ``name`` via ServiceConfig.get_api).

    Unknown keys are preserved. Aliased nested fields may be given by field name.
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated.
    model_config = ConfigDict(extra="allow", populate_by_name=True)

    name: str
    url: HttpUrl
    # Upper-case HTTP verb only
    method: str = Field("GET", pattern=r"^(GET|POST|PUT|PATCH|DELETE)$")
    headers: Dict[str, Any] = Field(default_factory=dict)
    body_template: Dict[str, Any] = Field(default_factory=dict)
    timeout_seconds: int = 10
    retry: RetryConfig = Field(default_factory=RetryConfig)
    # Either a plain string or a structured ProxyConfig
    proxy: Optional[str | ProxyConfig] = None
    auth: Optional[APIAuthConfig] = None
    response_prompt: Optional[str] = None
# ============== Localized Content ============== | |
class LocalizedCaption(BaseModel):
    """A caption string tagged with its locale code."""

    locale_code: str = Field(...)
    caption: str = Field(...)
class LocalizedExample(BaseModel):
    """An example sentence tagged with its locale code."""

    locale_code: str = Field(...)
    example: str = Field(...)
# ============== Intent / Param ============== | |
class ParameterConfig(BaseModel):
    """A parameter of an intent, with per-locale captions and optional prompts."""

    name: str
    caption: List[LocalizedCaption]  # Multi-language captions
    type: str = Field(..., pattern=r"^(int|float|bool|str|string|date)$")
    required: bool = True
    variable_name: str
    extraction_prompt: Optional[str] = None
    validation_regex: Optional[str] = None
    invalid_prompt: Optional[str] = None
    type_error_prompt: Optional[str] = None

    def get_caption_for_locale(self, locale_code: str, default_locale: str = "tr") -> str:
        """Return the caption for ``locale_code``, falling back step by step.

        Lookup order:
        1. a caption for ``locale_code``;
        2. a caption for ``default_locale``;
        3. the first non-empty caption;
        4. the parameter ``name``.

        Empty captions count as missing at every step. Previously an empty first
        caption was returned instead of the name.
        """
        for wanted in (locale_code, default_locale):
            for entry in self.caption:
                if entry.locale_code == wanted and entry.caption:
                    return entry.caption
        return next((c.caption for c in self.caption if c.caption), self.name)

    def canonical_type(self) -> str:
        """Normalise the declared type.

        ``"string"`` maps to ``"str"``. ``"date"`` also maps to ``"str"``, because
        dates are stored as ISO-format strings. Other types are returned unchanged.
        """
        if self.type in ("string", "date"):
            return "str"
        return self.type
class IntentConfig(BaseModel):
    """An intent definition with localized examples and parameters.

    Unknown keys are preserved (``extra="allow"``).
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated.
    model_config = ConfigDict(extra="allow")

    name: str
    caption: Optional[str] = ""
    dependencies: List[str] = Field(default_factory=list)
    examples: List[LocalizedExample] = Field(default_factory=list)  # Multi-language examples
    detection_prompt: Optional[str] = None
    parameters: List[ParameterConfig] = Field(default_factory=list)
    action: str
    fallback_timeout_prompt: Optional[str] = None
    fallback_error_prompt: Optional[str] = None

    def get_examples_for_locale(self, locale_code: str) -> List[str]:
        """Return the example texts whose locale matches ``locale_code`` exactly (no fallback)."""
        return [ex.example for ex in self.examples if ex.locale_code == locale_code]
# ============== Version / Project ============== | |
class LLMConfig(BaseModel):
    """Per-version LLM settings; VersionConfig.llm is optional depending on the provider."""

    repo_id: str = Field(...)
    generation_config: Dict[str, Any] = Field(default={})
    use_fine_tune: bool = Field(default=False)
    fine_tune_zip: str = Field(default="")
class VersionConfig(BaseModel):
    """One version of a project's configuration.

    ``id`` has the alias ``version_number``. Both names are accepted on input.
    Unknown keys are preserved.
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated.
    model_config = ConfigDict(extra="allow", populate_by_name=True)

    id: int = Field(..., alias="version_number")
    caption: Optional[str] = ""
    published: bool = False
    general_prompt: str
    llm: Optional[LLMConfig] = None  # Optional based on provider
    intents: List[IntentConfig]
class ProjectConfig(BaseModel):
    """A project with its locale settings and versions.

    Unknown keys are preserved (``extra="allow"``).
    """

    # pydantic v2 style config; the class-based ``Config`` is deprecated.
    model_config = ConfigDict(extra="allow")

    id: Optional[int] = None
    name: str
    caption: Optional[str] = ""
    enabled: bool = True
    default_locale: str = "tr"  # formerly default_language
    supported_locales: List[str] = Field(default_factory=lambda: ["tr"])  # formerly supported_languages
    last_version_number: Optional[int] = None
    versions: List[VersionConfig]
# ============== Service Config ============== | |
class ServiceConfig(BaseModel):
    """Root model of service_config.jsonc: global config, projects and API definitions."""

    global_config: GlobalConfig = Field(..., alias="config")
    projects: List[ProjectConfig]
    apis: List[APIConfig]

    # Runtime lookup table. The leading underscore makes it a pydantic private
    # attribute, so it is neither validated nor serialised.
    _api_by_name: Dict[str, APIConfig] = {}

    def model_post_init(self, context: Any, /) -> None:
        """Build the API index right after validation.

        Without this, ``get_api`` returns ``None`` for every name until someone
        remembers to call ``build_index()``.
        """
        self.build_index()

    def build_index(self) -> None:
        """(Re)build the name -> APIConfig lookup.

        Call this again after mutating ``apis``. If several APIs share a name,
        the last one wins.
        """
        self._api_by_name = {a.name: a for a in self.apis}

    def get_api(self, name: str) -> APIConfig | None:
        """Return the API definition named ``name``, or ``None`` if it is not indexed."""
        return self._api_by_name.get(name)
# ============== Provider Singleton ============== | |
class ConfigProvider:
    """Process-wide cache of the parsed ``ServiceConfig``.

    The configuration is read lazily from ``service_config.jsonc`` next to this
    module and cached on the class. All state is class-level, so every method is
    a classmethod/staticmethod and no instance is needed.
    """

    _instance: Optional[ServiceConfig] = None
    _CONFIG_PATH = Path(__file__).parent / "service_config.jsonc"
    # Re-entrant: reload()/update_*() call get()/_save() while already holding the
    # lock. A plain Lock made reload() deadlock on itself.
    _lock = threading.RLock()

    @classmethod
    def get(cls) -> ServiceConfig:
        """Return the cached config, loading it from disk on first use (thread-safe)."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have loaded it meanwhile.
                if cls._instance is None:
                    cfg = cls._load()
                    # Index before publishing, so lock-free readers never see a
                    # half-initialised instance.
                    cfg.build_index()
                    cls._instance = cfg
        return cls._instance

    @classmethod
    def reload(cls) -> ServiceConfig:
        """Force a re-read of the configuration file (used after UI saves)."""
        with cls._lock:
            log("π Reloading configuration...")
            cls._instance = None
            return cls.get()

    @classmethod
    def update_config(cls, config_dict: dict):
        """Merge ``config_dict['config']`` into the global config, then save the file.

        Only keys already present on the global config are applied; other keys
        are ignored. "Present" means declared fields or previously stored extra
        keys. The merged result is re-validated, so an invalid value raises
        ``ValidationError`` and leaves the live config untouched.
        """
        updates = config_dict.get('config') or {}
        with cls._lock:
            cfg = cls.get()
            current = cfg.global_config
            known_keys = set(type(current).model_fields) | set(current.model_extra or {})
            merged = current.model_dump()
            merged.update({key: value for key, value in updates.items() if key in known_keys})
            cfg.global_config = GlobalConfig.model_validate(merged)
            cls._save()

    @classmethod
    def _load(cls) -> ServiceConfig:
        """Read, normalise and validate ``service_config.jsonc``.

        Raises:
            FileNotFoundError: the config file does not exist.
            commentjson.JSONLibraryException, ValidationError: the file cannot be
                parsed or does not match the schema.
        """
        try:
            log(f"π Loading config from: {cls._CONFIG_PATH}")
            if not cls._CONFIG_PATH.exists():
                raise FileNotFoundError(f"Config file not found: {cls._CONFIG_PATH}")

            with open(cls._CONFIG_PATH, 'r', encoding='utf-8') as f:
                config_data = commentjson.load(f)

            cls._apply_global_defaults(config_data)
            cls._normalize_api_json_fields(config_data.get('apis') or [])

            cfg = ServiceConfig.model_validate(config_data)
            log("β Configuration loaded successfully")
            return cfg

        except FileNotFoundError:
            log(f"β Config file not found: {cls._CONFIG_PATH}")
            raise
        except (commentjson.JSONLibraryException, ValidationError) as e:
            log(f"β Config parsing/validation error: {e}")
            raise
        except Exception as e:
            log(f"β Unexpected error loading config: {e}")
            raise

    @staticmethod
    def _apply_global_defaults(config_data: Dict[str, Any]) -> None:
        """Insert default sections missing from ``config_data['config']``.

        Covers the provider sections, ``providers`` and ``users``. Mutates
        ``config_data`` in place.
        """
        config_section = config_data.setdefault('config', {})
        # Fresh literals on every call, so no default dict is shared between loads.
        config_section.setdefault('llm_provider', {
            "name": "spark",
            "api_key": None,
            "endpoint": "http://localhost:7861",
            "settings": {},
        })
        config_section.setdefault('tts_provider', {
            "name": "no_tts",
            "api_key": None,
            "endpoint": None,
            "settings": {},
        })
        config_section.setdefault('stt_provider', {
            "name": "no_stt",
            "api_key": None,
            "endpoint": None,
            "settings": {},
        })
        config_section.setdefault('providers', [])
        config_section.setdefault('users', [])

    @classmethod
    def _normalize_api_json_fields(cls, apis: List[Dict[str, Any]]) -> None:
        """Decode JSON-string headers and body fields (API and auth level) in place."""
        for api in apis:
            cls._parse_json_field(api, 'headers')
            cls._parse_json_field(api, 'body_template')
            auth = api.get('auth')
            if auth:
                # 'body_template' is the alias APIAuthConfig.token_request_body is
                # saved under, so it needs the same treatment.
                for key in ('token_request_body', 'body_template', 'token_refresh_body'):
                    cls._parse_json_field(auth, key)

    @staticmethod
    def _parse_json_field(container: Dict[str, Any], key: str) -> None:
        """Replace a JSON-encoded string at ``container[key]`` with its decoded value.

        Invalid JSON is replaced with ``{}`` and a log line is written. Non-string
        values are left unchanged.
        """
        raw = container.get(key)
        if not isinstance(raw, str):
            return
        try:
            container[key] = json.loads(raw)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            log(f"Invalid JSON in field '{key}'; replacing with empty object")
            container[key] = {}

    # ============== Environment Methods ==============
    @classmethod
    def update_environment(cls, env_data: dict, username: str):
        """Replace the llm/tts/stt provider settings present in ``env_data``, save, and log.

        Every section is validated before anything is assigned, so a bad payload
        cannot leave the config half-updated.
        """
        updates = {
            key: ProviderSettings(**env_data[key])
            for key in ('llm_provider', 'tts_provider', 'stt_provider')
            if key in env_data
        }
        with cls._lock:
            cfg = cls.get()
            for key, settings in updates.items():
                setattr(cfg.global_config, key, settings)
            cls._save()

        cls._log_activity("UPDATE_ENVIRONMENT", "environment", username=username)

    # ============== Save Method ==============
    @classmethod
    def _save(cls):
        """Write the cached configuration back to the config file. No-op if nothing is loaded."""
        if cls._instance is None:
            return

        with cls._lock:
            # mode="json" converts HttpUrl and similar types to plain strings; the
            # default python mode yields Url objects that json cannot serialise.
            config_dict = {
                "config": cls._instance.global_config.model_dump(mode="json", exclude_none=True),
                "projects": [p.model_dump(mode="json", by_alias=True) for p in cls._instance.projects],
                "apis": [a.model_dump(mode="json", by_alias=True) for a in cls._instance.apis],
            }
            payload = commentjson.dumps(config_dict, indent=2, ensure_ascii=False)

            # Write to a sibling temp file, then atomically swap it in, so a crash
            # mid-write cannot leave a truncated config behind.
            tmp_path = cls._CONFIG_PATH.with_name(cls._CONFIG_PATH.name + ".tmp")
            with open(tmp_path, 'w', encoding='utf-8') as f:
                f.write(payload)
            os.replace(tmp_path, cls._CONFIG_PATH)

            log("πΎ Configuration saved")

    # ============== Activity Logging ==============
    @classmethod
    def _log_activity(cls, action: str, entity_type: str, entity_id: Any = None,
                      entity_name: Optional[str] = None, username: str = "system",
                      details: Optional[str] = None):
        """Record an activity (placeholder: currently writes only a log line).

        ``entity_id``, ``entity_name`` and ``details`` are accepted but not yet used.
        """
        log(f"π Activity: {action} on {entity_type} by {username}")
# ============== Helper Functions ============== | |
def _get_iso_timestamp() -> str: | |
"""Get current timestamp in ISO format""" | |
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" |