Spaces:
Building
Building
Update llm_factory.py
Browse files- llm_factory.py +58 -53
llm_factory.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
LLM Provider Factory for Flare
|
3 |
"""
|
4 |
import os
|
5 |
-
from typing import Optional
|
6 |
from dotenv import load_dotenv
|
7 |
|
8 |
from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
|
@@ -10,69 +10,69 @@ from config_provider import ConfigProvider
|
|
10 |
from utils import log
|
11 |
|
12 |
class LLMFactory:
|
13 |
-
"""Factory class to create appropriate LLM provider based on llm_provider config"""
|
14 |
-
|
15 |
@staticmethod
|
16 |
def create_provider() -> LLMInterface:
|
17 |
-
"""Create
|
18 |
cfg = ConfigProvider.get()
|
19 |
-
|
20 |
|
21 |
-
if not
|
22 |
raise ValueError("No LLM provider configured")
|
23 |
|
24 |
-
provider_name =
|
25 |
log(f"π Creating LLM provider: {provider_name}")
|
26 |
|
27 |
-
# Get provider
|
28 |
-
|
29 |
-
if not
|
30 |
raise ValueError(f"Unknown LLM provider: {provider_name}")
|
31 |
|
32 |
# Get API key
|
33 |
-
api_key = LLMFactory._get_api_key(provider_name)
|
34 |
-
if not api_key and provider_config.requires_api_key:
|
35 |
-
raise ValueError(f"API key required for {provider_name} but not configured")
|
36 |
-
|
37 |
-
# Get settings
|
38 |
-
settings = llm_provider.settings or {}
|
39 |
|
40 |
-
# Create
|
41 |
if provider_name == "spark":
|
42 |
-
return LLMFactory._create_spark_provider(
|
43 |
-
elif provider_name in
|
44 |
-
return LLMFactory._create_gpt_provider(
|
45 |
else:
|
46 |
raise ValueError(f"Unsupported LLM provider: {provider_name}")
|
47 |
|
48 |
@staticmethod
|
49 |
-
def _create_spark_provider(
|
50 |
"""Create Spark LLM provider"""
|
51 |
-
if not endpoint:
|
52 |
-
raise ValueError("Spark
|
53 |
|
54 |
-
|
55 |
-
|
56 |
|
57 |
-
#
|
58 |
-
provider_variant = "
|
59 |
-
if
|
60 |
-
provider_variant = "
|
|
|
|
|
|
|
61 |
|
62 |
return SparkLLM(
|
63 |
-
spark_endpoint=
|
64 |
spark_token=api_key,
|
65 |
provider_variant=provider_variant,
|
66 |
-
settings=settings
|
67 |
)
|
68 |
|
69 |
@staticmethod
|
70 |
-
def _create_gpt_provider(
|
71 |
-
"""Create GPT-
|
72 |
-
|
73 |
-
|
74 |
|
75 |
-
|
|
|
|
|
|
|
|
|
76 |
|
77 |
return GPT4oLLM(
|
78 |
api_key=api_key,
|
@@ -81,37 +81,42 @@ class LLMFactory:
|
|
81 |
)
|
82 |
|
83 |
@staticmethod
|
84 |
-
def _get_api_key(provider_name: str) -> Optional[str]:
|
85 |
"""Get API key from config or environment"""
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
93 |
|
94 |
-
# Then check environment
|
95 |
-
|
96 |
"spark": "SPARK_TOKEN",
|
97 |
"gpt4o": "OPENAI_API_KEY",
|
98 |
-
"gpt4o-mini": "OPENAI_API_KEY"
|
99 |
}
|
100 |
|
101 |
-
env_var =
|
102 |
if env_var:
|
103 |
-
# Check if
|
104 |
-
if os.
|
|
|
105 |
api_key = os.environ.get(env_var)
|
106 |
if api_key:
|
107 |
-
log(f"π Using
|
108 |
return api_key
|
109 |
else:
|
110 |
-
# Local
|
111 |
load_dotenv()
|
112 |
api_key = os.getenv(env_var)
|
113 |
if api_key:
|
114 |
-
log(f"π Using
|
115 |
return api_key
|
116 |
|
117 |
return None
|
|
|
2 |
LLM Provider Factory for Flare
|
3 |
"""
|
4 |
import os
|
5 |
+
from typing import Optional
|
6 |
from dotenv import load_dotenv
|
7 |
|
8 |
from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
|
|
|
10 |
from utils import log
|
11 |
|
12 |
class LLMFactory:
    """Factory that instantiates the configured LLM provider for Flare.

    Reads the active provider from ``ConfigProvider`` global config and returns
    a concrete ``LLMInterface`` implementation (``SparkLLM`` or ``GPT4oLLM``).

    NOTE(review): this block was reconstructed from a corrupted diff view.
    The leading glyphs in log strings (shown as "π") look like mojibake of
    emoji in the original source — confirm against the repository before
    relying on exact log text.
    """

    @staticmethod
    def create_provider() -> LLMInterface:
        """Create and return the LLM provider selected in configuration.

        Raises:
            ValueError: if no provider is configured, the provider name is
                unknown, or the concrete factory rejects its inputs.
        """
        cfg = ConfigProvider.get()
        llm_config = cfg.global_config.llm_provider

        if not llm_config:
            raise ValueError("No LLM provider configured")

        provider_name = llm_config.name
        log(f"π Creating LLM provider: {provider_name}")

        # Get provider definition (validates the name against known providers)
        provider_def = cfg.global_config.get_provider_config("llm", provider_name)
        if not provider_def:
            raise ValueError(f"Unknown LLM provider: {provider_name}")

        # Get API key (config first, then environment — see _get_api_key)
        api_key = LLMFactory._get_api_key(provider_name, llm_config.api_key)

        # Create provider based on name; each _create_* validates its own inputs
        if provider_name == "spark":
            return LLMFactory._create_spark_provider(llm_config, api_key, provider_def)
        elif provider_name in ["gpt4o", "gpt4o-mini"]:
            return LLMFactory._create_gpt_provider(llm_config, api_key, provider_def)
        else:
            raise ValueError(f"Unsupported LLM provider: {provider_name}")

    @staticmethod
    def _create_spark_provider(llm_config, api_key: str, provider_def) -> SparkLLM:
        """Create Spark LLM provider.

        Raises:
            ValueError: if the Spark endpoint or API token is missing.
        """
        if not llm_config.endpoint:
            raise ValueError("Spark endpoint is required")

        if not api_key:
            raise ValueError("Spark API token is required")

        # Extract work mode variant (for backward compatibility).
        # SPACE_ID is set by the HuggingFace Space runtime.
        provider_variant = "cloud"  # Default
        if os.getenv("SPACE_ID"):  # HuggingFace Space
            provider_variant = "hfcloud"

        log(f"π Initializing SparkLLM: {llm_config.endpoint}")
        log(f"π§ Provider variant: {provider_variant}")

        return SparkLLM(
            spark_endpoint=llm_config.endpoint,
            spark_token=api_key,
            provider_variant=provider_variant,
            settings=llm_config.settings
        )

    @staticmethod
    def _create_gpt_provider(llm_config, api_key: str, provider_def) -> GPT4oLLM:
        """Create GPT-4o / GPT-4o-mini provider.

        Raises:
            ValueError: if the OpenAI API key is missing.
        """
        if not api_key:
            raise ValueError("OpenAI API key is required")

        # Get model-specific settings
        settings = llm_config.settings or {}
        model = provider_def.name  # gpt4o or gpt4o-mini

        log(f"π€ Initializing GPT4oLLM with model: {model}")

        # NOTE(review): two lines of this call were hidden in the diff context
        # (original lines 79-80); the keyword arguments below are the natural
        # reconstruction from the locals prepared above — confirm against the
        # repository.
        return GPT4oLLM(
            api_key=api_key,
            model=model,
            settings=settings
        )

    @staticmethod
    def _get_api_key(provider_name: str, config_key: Optional[str]) -> Optional[str]:
        """Get API key from config or environment.

        Resolution order:
          1. ``config_key`` from configuration, decrypting values prefixed
             with ``enc:`` via ``encryption_utils.decrypt``.
          2. The provider's environment variable — read directly when running
             in a HuggingFace Space (SPACE_ID set), otherwise via ``.env``
             loaded with ``load_dotenv()``.

        Returns:
            The resolved key, or ``None`` if nothing is configured.
        """
        # First check config
        if config_key:
            if config_key.startswith("enc:"):
                # Decrypt if encrypted; imported lazily so the module loads
                # even when encryption support is unused.
                from encryption_utils import decrypt
                decrypted = decrypt(config_key)
                log(f"π Using encrypted API key from config: ***{decrypted[-4:]}")
                return decrypted
            else:
                log(f"π Using plain API key from config: ***{config_key[-4:]}")
                return config_key

        # Then check environment
        env_mappings = {
            "spark": "SPARK_TOKEN",
            "gpt4o": "OPENAI_API_KEY",
            "gpt4o-mini": "OPENAI_API_KEY"
        }

        env_var = env_mappings.get(provider_name)
        if env_var:
            # Check if we're in HuggingFace Space
            if os.getenv("SPACE_ID"):
                # HuggingFace mode - direct environment (Space secrets)
                api_key = os.environ.get(env_var)
                if api_key:
                    log(f"π Using API key from HuggingFace secrets: {env_var}")
                    return api_key
            else:
                # Local mode - use dotenv
                load_dotenv()
                api_key = os.getenv(env_var)
                if api_key:
                    log(f"π Using API key from .env: {env_var}")
                    return api_key

        return None
|