File size: 3,345 Bytes
2ebef02
 
 
 
e2a364d
2ebef02
 
71406b0
 
 
2ebef02
99c3903
2ebef02
 
 
 
e2a364d
2ebef02
e2a364d
2ebef02
e2a364d
b9b2b1e
2ebef02
e2a364d
99c3903
b9b2b1e
e2a364d
 
 
b9b2b1e
2ebef02
 
e2a364d
2ebef02
e2a364d
394611c
e2a364d
 
 
2ebef02
b9b2b1e
2ebef02
 
71406b0
2ebef02
71406b0
 
 
394611c
71406b0
 
 
394611c
2ebef02
71406b0
2ebef02
71406b0
e2a364d
2ebef02
 
 
71406b0
 
 
2ebef02
71406b0
 
2ebef02
 
 
71406b0
2ebef02
71406b0
 
 
 
e2a364d
71406b0
 
2ebef02
71406b0
e2a364d
2ebef02
 
e2a364d
2ebef02
 
e2a364d
2ebef02
71406b0
 
99c3903
71406b0
2ebef02
71406b0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
LLM Provider Factory for Flare
"""
import os
from typing import Optional
from dotenv import load_dotenv

from llm_interface import LLMInterface
from llm_spark import SparkLLM
from llm_openai import OpenAILLM
from config_provider import ConfigProvider
from logger import log_info, log_error, log_warning, log_debug

class LLMFactory:
    """Build the LLM provider instance selected in the global configuration.

    The provider is chosen from ``ConfigProvider.get().global_config.llm_provider``.
    Its name must have a provider definition (``get_provider_config("llm", name)``)
    and be one of the names this factory knows how to build.
    """

    # Environment variable consulted for each provider when the configuration
    # does not carry an API key of its own.
    _ENV_KEY_VARS = {
        "spark": "SPARK_TOKEN",
        "gpt4o": "OPENAI_API_KEY",
        "gpt4o-mini": "OPENAI_API_KEY",
    }

    # Provider names served by the OpenAI-backed implementation.
    _GPT_PROVIDERS = frozenset({"gpt4o", "gpt4o-mini"})

    @staticmethod
    def create_provider() -> LLMInterface:
        """Create the LLM provider described by the current configuration.

        Returns:
            A ``SparkLLM`` for the ``"spark"`` provider, or an ``OpenAILLM``
            for ``"gpt4o"`` / ``"gpt4o-mini"``.

        Raises:
            ValueError: if no LLM provider is configured, the configured name
                has no provider definition, the name is not supported by this
                factory, or no API key can be resolved for it.
        """
        cfg = ConfigProvider.get()
        llm_config = cfg.global_config.llm_provider

        if not llm_config:
            raise ValueError("No LLM provider configured")

        provider_name = llm_config.name
        log_info(f"🏭 Creating LLM provider: {provider_name}")

        # The provider must be declared in the global provider definitions.
        provider_def = cfg.global_config.get_provider_config("llm", provider_name)
        if not provider_def:
            raise ValueError(f"Unknown LLM provider: {provider_name}")

        # Pick the builder before resolving credentials, so an unsupported
        # provider is reported as such rather than as a missing API key, and
        # no key decryption is attempted for it.
        if provider_name == "spark":
            builder = LLMFactory._create_spark_provider
        elif provider_name in LLMFactory._GPT_PROVIDERS:
            builder = LLMFactory._create_gpt_provider
        else:
            raise ValueError(f"Unsupported LLM provider: {provider_name}")

        api_key = LLMFactory._get_api_key(provider_name, llm_config.api_key)
        return builder(llm_config, api_key, provider_def)

    @staticmethod
    def _create_spark_provider(llm_config, api_key, provider_def):
        """Create a Spark LLM provider.

        Args:
            llm_config: configured provider entry; its ``endpoint`` and
                ``settings`` are passed to ``SparkLLM``.
            api_key: resolved token passed as ``spark_token``.
            provider_def: provider definition; currently unused, accepted so
                every builder shares the same signature.

        Raises:
            ValueError: if ``llm_config.endpoint`` is empty.
        """
        endpoint = llm_config.endpoint
        if not endpoint:
            raise ValueError("Spark endpoint not configured")

        # A non-empty SPACE_ID env var selects the "hfcloud" variant;
        # presumably it is set when running on Hugging Face Spaces — confirm.
        is_cloud = bool(os.environ.get("SPACE_ID"))
        variant = "hfcloud" if is_cloud else "on-premise"

        return SparkLLM(
            spark_endpoint=endpoint,
            spark_token=api_key,
            provider_variant=variant,
            settings=llm_config.settings
        )

    @staticmethod
    def _create_gpt_provider(llm_config, api_key, provider_def):
        """Create an OpenAI GPT provider.

        Args:
            llm_config: configured provider entry; its ``name`` is used as the
                model and its ``settings`` are passed through.
            api_key: resolved OpenAI API key.
            provider_def: provider definition; currently unused, accepted so
                every builder shares the same signature.
        """
        # NOTE(review): the provider name ("gpt4o" / "gpt4o-mini") is passed
        # as the model id; confirm OpenAILLM maps it to the real OpenAI
        # model name.
        return OpenAILLM(
            api_key=api_key,
            model=llm_config.name,
            settings=llm_config.settings
        )

    @staticmethod
    def _get_api_key(provider_name: str, configured_key: Optional[str]) -> str:
        """Resolve the API key for ``provider_name``.

        A non-empty configured key wins; if it starts with ``"enc:"`` it is
        passed to ``encryption_utils.decrypt`` and the result is returned.
        Otherwise the provider's environment variable from ``_ENV_KEY_VARS``
        is consulted.

        Raises:
            ValueError: if neither source yields a non-empty key.
        """
        if configured_key:
            if configured_key.startswith("enc:"):
                # Imported here so encryption_utils is only loaded when an
                # encrypted key is actually configured.
                from encryption_utils import decrypt
                return decrypt(configured_key)
            return configured_key

        env_var = LLMFactory._ENV_KEY_VARS.get(provider_name)
        if env_var:
            key = os.environ.get(env_var)
            if key:
                log_info(f"📌 Using API key from environment: {env_var}")
                return key

        raise ValueError(f"No API key found for provider: {provider_name}")