ciyidogan committed on
Commit
e2a364d
·
verified ·
1 Parent(s): 53db15d

Update llm_factory.py

Browse files
Files changed (1) hide show
  1. llm_factory.py +58 -53
llm_factory.py CHANGED
@@ -2,7 +2,7 @@
2
  LLM Provider Factory for Flare
3
  """
4
  import os
5
- from typing import Optional, Dict, Any
6
  from dotenv import load_dotenv
7
 
8
  from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
@@ -10,69 +10,69 @@ from config_provider import ConfigProvider
10
  from utils import log
11
 
12
  class LLMFactory:
13
- """Factory class to create appropriate LLM provider based on llm_provider config"""
14
-
15
  @staticmethod
16
  def create_provider() -> LLMInterface:
17
- """Create and return appropriate LLM provider based on config"""
18
  cfg = ConfigProvider.get()
19
- llm_provider = cfg.global_config.llm_provider
20
 
21
- if not llm_provider or not llm_provider.name:
22
  raise ValueError("No LLM provider configured")
23
 
24
- provider_name = llm_provider.name
25
  log(f"🏭 Creating LLM provider: {provider_name}")
26
 
27
- # Get provider config
28
- provider_config = cfg.global_config.get_provider_config("llm", provider_name)
29
- if not provider_config:
30
  raise ValueError(f"Unknown LLM provider: {provider_name}")
31
 
32
  # Get API key
33
- api_key = LLMFactory._get_api_key(provider_name)
34
- if not api_key and provider_config.requires_api_key:
35
- raise ValueError(f"API key required for {provider_name} but not configured")
36
-
37
- # Get settings
38
- settings = llm_provider.settings or {}
39
 
40
- # Create appropriate provider
41
  if provider_name == "spark":
42
- return LLMFactory._create_spark_provider(api_key, llm_provider.endpoint, settings)
43
- elif provider_name in ("gpt4o", "gpt4o-mini"):
44
- return LLMFactory._create_gpt_provider(provider_name, api_key, settings)
45
  else:
46
  raise ValueError(f"Unsupported LLM provider: {provider_name}")
47
 
48
  @staticmethod
49
- def _create_spark_provider(api_key: str, endpoint: Optional[str], settings: Dict[str, Any]) -> SparkLLM:
50
  """Create Spark LLM provider"""
51
- if not endpoint:
52
- raise ValueError("Spark requires endpoint to be configured")
53
 
54
- log(f"πŸš€ Creating SparkLLM provider")
55
- log(f"πŸ“ Endpoint: {endpoint}")
56
 
57
- # Determine provider variant for backward compatibility
58
- provider_variant = "spark-cloud"
59
- if not ConfigProvider.get().global_config.is_cloud_mode():
60
- provider_variant = "spark-onpremise"
 
 
 
61
 
62
  return SparkLLM(
63
- spark_endpoint=str(endpoint),
64
  spark_token=api_key,
65
  provider_variant=provider_variant,
66
- settings=settings
67
  )
68
 
69
  @staticmethod
70
- def _create_gpt_provider(model_type: str, api_key: str, settings: Dict[str, Any]) -> GPT4oLLM:
71
- """Create GPT-4o LLM provider"""
72
- # Determine model
73
- model = "gpt-4o-mini" if model_type == "gpt4o-mini" else "gpt-4o"
74
 
75
- log(f"πŸ€– Creating GPT4oLLM provider with model: {model}")
 
 
 
 
76
 
77
  return GPT4oLLM(
78
  api_key=api_key,
@@ -81,37 +81,42 @@ class LLMFactory:
81
  )
82
 
83
  @staticmethod
84
- def _get_api_key(provider_name: str) -> Optional[str]:
85
  """Get API key from config or environment"""
86
- cfg = ConfigProvider.get()
87
-
88
- # First check encrypted config
89
- api_key = cfg.global_config.get_plain_api_key("llm")
90
- if api_key:
91
- log("πŸ”‘ Using decrypted API key from config")
92
- return api_key
 
 
 
 
93
 
94
- # Then check environment based on provider
95
- env_var_map = {
96
  "spark": "SPARK_TOKEN",
97
  "gpt4o": "OPENAI_API_KEY",
98
- "gpt4o-mini": "OPENAI_API_KEY",
99
  }
100
 
101
- env_var = env_var_map.get(provider_name)
102
  if env_var:
103
- # Check if running in HuggingFace Space
104
- if os.environ.get("SPACE_ID"):
 
105
  api_key = os.environ.get(env_var)
106
  if api_key:
107
- log(f"πŸ”‘ Using {env_var} from HuggingFace secrets")
108
  return api_key
109
  else:
110
- # Local development
111
  load_dotenv()
112
  api_key = os.getenv(env_var)
113
  if api_key:
114
- log(f"πŸ”‘ Using {env_var} from .env file")
115
  return api_key
116
 
117
  return None
 
2
  LLM Provider Factory for Flare
3
  """
4
  import os
5
+ from typing import Optional
6
  from dotenv import load_dotenv
7
 
8
  from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
 
10
  from utils import log
11
 
12
  class LLMFactory:
 
 
13
  @staticmethod
14
  def create_provider() -> LLMInterface:
15
+ """Create LLM provider based on configuration"""
16
  cfg = ConfigProvider.get()
17
+ llm_config = cfg.global_config.llm_provider
18
 
19
+ if not llm_config:
20
  raise ValueError("No LLM provider configured")
21
 
22
+ provider_name = llm_config.name
23
  log(f"🏭 Creating LLM provider: {provider_name}")
24
 
25
+ # Get provider definition
26
+ provider_def = cfg.global_config.get_provider_config("llm", provider_name)
27
+ if not provider_def:
28
  raise ValueError(f"Unknown LLM provider: {provider_name}")
29
 
30
  # Get API key
31
+ api_key = LLMFactory._get_api_key(provider_name, llm_config.api_key)
 
 
 
 
 
32
 
33
+ # Create provider based on name
34
  if provider_name == "spark":
35
+ return LLMFactory._create_spark_provider(llm_config, api_key, provider_def)
36
+ elif provider_name in ["gpt4o", "gpt4o-mini"]:
37
+ return LLMFactory._create_gpt_provider(llm_config, api_key, provider_def)
38
  else:
39
  raise ValueError(f"Unsupported LLM provider: {provider_name}")
40
 
41
  @staticmethod
42
+ def _create_spark_provider(llm_config, api_key: str, provider_def) -> SparkLLM:
43
  """Create Spark LLM provider"""
44
+ if not llm_config.endpoint:
45
+ raise ValueError("Spark endpoint is required")
46
 
47
+ if not api_key:
48
+ raise ValueError("Spark API token is required")
49
 
50
+ # Extract work mode variant (for backward compatibility)
51
+ provider_variant = "cloud" # Default
52
+ if os.getenv("SPACE_ID"): # HuggingFace Space
53
+ provider_variant = "hfcloud"
54
+
55
+ log(f"πŸš€ Initializing SparkLLM: {llm_config.endpoint}")
56
+ log(f"πŸ”§ Provider variant: {provider_variant}")
57
 
58
  return SparkLLM(
59
+ spark_endpoint=llm_config.endpoint,
60
  spark_token=api_key,
61
  provider_variant=provider_variant,
62
+ settings=llm_config.settings
63
  )
64
 
65
  @staticmethod
66
+ def _create_gpt_provider(llm_config, api_key: str, provider_def) -> GPT4oLLM:
67
+ """Create GPT-4 provider"""
68
+ if not api_key:
69
+ raise ValueError("OpenAI API key is required")
70
 
71
+ # Get model-specific settings
72
+ settings = llm_config.settings or {}
73
+ model = provider_def.name # gpt4o or gpt4o-mini
74
+
75
+ log(f"πŸ€– Initializing GPT4oLLM with model: {model}")
76
 
77
  return GPT4oLLM(
78
  api_key=api_key,
 
81
  )
82
 
83
  @staticmethod
84
+ def _get_api_key(provider_name: str, config_key: Optional[str]) -> Optional[str]:
85
  """Get API key from config or environment"""
86
+ # First check config
87
+ if config_key:
88
+ if config_key.startswith("enc:"):
89
+ # Decrypt if encrypted
90
+ from encryption_utils import decrypt
91
+ decrypted = decrypt(config_key)
92
+ log(f"πŸ”“ Using encrypted API key from config: ***{decrypted[-4:]}")
93
+ return decrypted
94
+ else:
95
+ log(f"πŸ”‘ Using plain API key from config: ***{config_key[-4:]}")
96
+ return config_key
97
 
98
+ # Then check environment
99
+ env_mappings = {
100
  "spark": "SPARK_TOKEN",
101
  "gpt4o": "OPENAI_API_KEY",
102
+ "gpt4o-mini": "OPENAI_API_KEY"
103
  }
104
 
105
+ env_var = env_mappings.get(provider_name)
106
  if env_var:
107
+ # Check if we're in HuggingFace Space
108
+ if os.getenv("SPACE_ID"):
109
+ # HuggingFace mode - direct environment
110
  api_key = os.environ.get(env_var)
111
  if api_key:
112
+ log(f"πŸ”‘ Using API key from HuggingFace secrets: {env_var}")
113
  return api_key
114
  else:
115
+ # Local mode - use dotenv
116
  load_dotenv()
117
  api_key = os.getenv(env_var)
118
  if api_key:
119
+ log(f"πŸ”‘ Using API key from .env: {env_var}")
120
  return api_key
121
 
122
  return None