"""
LLM Provider Factory for Flare
"""
import os
from typing import Optional

from dotenv import load_dotenv

from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
from config_provider import ConfigProvider
from utils import log
class LLMFactory:
    """Factory that builds the configured LLM provider (Spark variants or GPT-4o).

    All methods are stateless, hence static: callers use
    ``LLMFactory.create_provider()`` without instantiating the class.
    """

    @staticmethod
    def create_provider() -> LLMInterface:
        """Create and return the LLM provider selected in global config.

        Returns:
            A concrete ``LLMInterface`` implementation (``SparkLLM`` or
            ``GPT4oLLM``) built from ``ConfigProvider`` settings.

        Raises:
            ValueError: if no provider is configured, the provider name is
                unknown or unsupported, or a required API key / endpoint
                is missing.
        """
        cfg = ConfigProvider.get()
        llm_config = cfg.global_config.llm_provider
        if not llm_config:
            raise ValueError("No LLM provider configured")

        provider_name = llm_config.name
        # NOTE(review): emoji prefixes in these log strings appear
        # mojibake-garbled ("π"); preserved as-is to keep log output stable.
        log(f"π Creating LLM provider: {provider_name}")

        # Look up the provider definition (capabilities / requirements)
        provider_def = cfg.global_config.get_provider_config("llm", provider_name)
        if not provider_def:
            raise ValueError(f"Unknown LLM provider: {provider_name}")

        # Resolve API key (encrypted config first, then environment)
        api_key = LLMFactory._get_api_key(provider_name)
        if not api_key and provider_def.requires_api_key:
            raise ValueError(f"API key required for {provider_name} but not configured")

        # Resolve endpoint (only mandatory for providers that declare it)
        endpoint = llm_config.endpoint
        if not endpoint and provider_def.requires_endpoint:
            raise ValueError(f"Endpoint required for {provider_name} but not configured")

        # Dispatch to the concrete provider family
        if provider_name in ("spark", "spark_cloud", "spark_onpremise"):
            return LLMFactory._create_spark_provider(
                provider_name, api_key, endpoint, llm_config.settings
            )
        elif provider_name in ("gpt4o", "gpt4o-mini"):
            return LLMFactory._create_gpt_provider(
                provider_name, api_key, llm_config.settings
            )
        else:
            raise ValueError(f"Unsupported LLM provider: {provider_name}")

    @staticmethod
    def _create_spark_provider(
        provider_name: str, api_key: str, endpoint: str, settings: dict
    ) -> SparkLLM:
        """Create a Spark LLM provider for the given variant (cloud/on-premise)."""
        log(f"π Creating SparkLLM provider: {provider_name}")
        log(f"π Endpoint: {endpoint}")
        return SparkLLM(
            spark_endpoint=endpoint,
            spark_token=api_key,
            provider_variant=provider_name,
            settings=settings,
        )

    @staticmethod
    def _create_gpt_provider(model_type: str, api_key: str, settings: dict) -> GPT4oLLM:
        """Create a GPT-4o LLM provider; ``model_type`` selects the model size."""
        # "gpt4o-mini" maps to the mini model; any other accepted value → full gpt-4o
        model = "gpt-4o-mini" if model_type == "gpt4o-mini" else "gpt-4o"
        log(f"π€ Creating GPT4oLLM provider with model: {model}")
        return GPT4oLLM(
            api_key=api_key,
            model=model,
            settings=settings,
        )

    @staticmethod
    def _get_api_key(provider_name: str) -> Optional[str]:
        """Resolve the API key for *provider_name*.

        Resolution order: decrypted key from config, then the provider's
        environment variable (HuggingFace secrets when ``SPACE_ID`` is set,
        otherwise a local ``.env`` file). Returns ``None`` if nothing found.
        """
        cfg = ConfigProvider.get()

        # First check encrypted config
        api_key = cfg.global_config.get_plain_api_key("llm")
        if api_key:
            log("π Using decrypted API key from config")
            return api_key

        # Then check environment based on provider
        env_var_map = {
            "spark": "SPARK_TOKEN",
            "spark_cloud": "SPARK_TOKEN",
            "spark_onpremise": "SPARK_TOKEN",
            "gpt4o": "OPENAI_API_KEY",
            "gpt4o-mini": "OPENAI_API_KEY",
        }

        env_var = env_var_map.get(provider_name)
        if env_var:
            # SPACE_ID marks a HuggingFace Space deployment: secrets are
            # injected directly into the process environment.
            if os.environ.get("SPACE_ID"):
                api_key = os.environ.get(env_var)
                if api_key:
                    log(f"π Using {env_var} from HuggingFace secrets")
            else:
                # Local/on-premise deployment: load .env before reading
                load_dotenv()
                api_key = os.getenv(env_var)
                if api_key:
                    log(f"π Using {env_var} from .env file")

        return api_key