ciyidogan committed
Commit b9b2b1e · verified · 1 Parent(s): 9db24d7

Update llm_factory.py

Files changed (1)
  1. llm_factory.py +41 -40
llm_factory.py CHANGED
@@ -10,59 +10,58 @@ from config_provider import ConfigProvider
 from utils import log
 
 class LLMFactory:
-    """Factory class to create appropriate LLM provider based on llm_provider config"""
+    """Factory class to create appropriate LLM provider based on configuration"""
 
     @staticmethod
     def create_provider() -> LLMInterface:
         """Create and return appropriate LLM provider based on config"""
         cfg = ConfigProvider.get()
-        llm_provider = cfg.global_config.llm_provider
+        llm_config = cfg.global_config.llm_provider
 
-        log(f"🏭 Creating LLM provider: {llm_provider}")
+        if not llm_config:
+            raise ValueError("No LLM provider configured")
 
-        # Get provider config
-        provider_config = cfg.global_config.get_llm_provider_config()
-        if not provider_config:
-            raise ValueError(f"Unknown LLM provider: {llm_provider}")
+        provider_name = llm_config.name
+        log(f"🏭 Creating LLM provider: {provider_name}")
+
+        # Get provider definition
+        provider_def = cfg.global_config.get_provider_config("llm", provider_name)
+        if not provider_def:
+            raise ValueError(f"Unknown LLM provider: {provider_name}")
 
         # Get API key
-        api_key = LLMFactory._get_api_key()
-        if not api_key and provider_config.requires_api_key:
-            raise ValueError(f"API key required for {llm_provider} but not configured")
+        api_key = LLMFactory._get_api_key(provider_name)
+        if not api_key and provider_def.requires_api_key:
+            raise ValueError(f"API key required for {provider_name} but not configured")
+
+        # Get endpoint
+        endpoint = llm_config.endpoint
+        if not endpoint and provider_def.requires_endpoint:
+            raise ValueError(f"Endpoint required for {provider_name} but not configured")
 
         # Create appropriate provider
-        if llm_provider == "spark":
-            return LLMFactory._create_spark_provider(api_key)
-        elif llm_provider in ("gpt4o", "gpt4o-mini"):
-            return LLMFactory._create_gpt_provider(llm_provider, api_key)
+        if provider_name in ("spark", "spark_cloud", "spark_onpremise"):
+            return LLMFactory._create_spark_provider(provider_name, api_key, endpoint, llm_config.settings)
+        elif provider_name in ("gpt4o", "gpt4o-mini"):
+            return LLMFactory._create_gpt_provider(provider_name, api_key, llm_config.settings)
         else:
-            raise ValueError(f"Unsupported LLM provider: {llm_provider}")
+            raise ValueError(f"Unsupported LLM provider: {provider_name}")
 
     @staticmethod
-    def _create_spark_provider(api_key: str) -> SparkLLM:
+    def _create_spark_provider(provider_name: str, api_key: str, endpoint: str, settings: dict) -> SparkLLM:
         """Create Spark LLM provider"""
-        cfg = ConfigProvider.get()
-
-        endpoint = cfg.global_config.llm_provider_endpoint
-        if not endpoint:
-            raise ValueError("Spark requires llm_provider_endpoint to be configured")
-
-        log(f"🚀 Creating SparkLLM provider")
+        log(f"🚀 Creating SparkLLM provider: {provider_name}")
         log(f"📍 Endpoint: {endpoint}")
 
-        # Determine work mode for Spark (backward compatibility)
-        work_mode = "cloud"  # Default
-        if not cfg.global_config.is_cloud_mode():
-            work_mode = "on-premise"
-
         return SparkLLM(
-            spark_endpoint=str(endpoint),
+            spark_endpoint=endpoint,
             spark_token=api_key,
-            work_mode=work_mode
+            provider_variant=provider_name,
+            settings=settings
         )
 
     @staticmethod
-    def _create_gpt_provider(model_type: str, api_key: str) -> GPT4oLLM:
+    def _create_gpt_provider(model_type: str, api_key: str, settings: dict) -> GPT4oLLM:
         """Create GPT-4o LLM provider"""
         # Determine model
         model = "gpt-4o-mini" if model_type == "gpt4o-mini" else "gpt-4o"
@@ -71,37 +70,39 @@ class LLMFactory:
 
         return GPT4oLLM(
             api_key=api_key,
-            model=model
+            model=model,
+            settings=settings
         )
 
     @staticmethod
-    def _get_api_key() -> Optional[str]:
+    def _get_api_key(provider_name: str) -> Optional[str]:
         """Get API key from config or environment"""
         cfg = ConfigProvider.get()
 
         # First check encrypted config
-        api_key = cfg.global_config.get_plain_api_key()
+        api_key = cfg.global_config.get_plain_api_key("llm")
         if api_key:
             log("🔑 Using decrypted API key from config")
             return api_key
 
         # Then check environment based on provider
-        llm_provider = cfg.global_config.llm_provider
-
         env_var_map = {
             "spark": "SPARK_TOKEN",
+            "spark_cloud": "SPARK_TOKEN",
+            "spark_onpremise": "SPARK_TOKEN",
             "gpt4o": "OPENAI_API_KEY",
             "gpt4o-mini": "OPENAI_API_KEY",
-            # Add more mappings as needed
         }
 
-        env_var = env_var_map.get(llm_provider)
+        env_var = env_var_map.get(provider_name)
         if env_var:
-            if cfg.global_config.is_cloud_mode():
+            # Check if running in HuggingFace Space
+            if os.environ.get("SPACE_ID"):
                 api_key = os.environ.get(env_var)
                 if api_key:
-                    log(f"🔑 Using {env_var} from environment")
+                    log(f"🔑 Using {env_var} from HuggingFace secrets")
             else:
+                # Local/on-premise deployment
                 load_dotenv()
                 api_key = os.getenv(env_var)
                 if api_key:
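
For context, the least obvious part of this change is the key-resolution order introduced in _get_api_key: decrypted config key first, then HuggingFace Space secrets (detected via SPACE_ID), then a local .env file. The snippet below is a minimal standalone sketch of that order under those assumptions; resolve_api_key and its config_key parameter are illustrative names that do not appear in the commit, and only the behaviour mirrors the diff above.

import os
from typing import Optional

from dotenv import load_dotenv

# Mirrors env_var_map in the diff; all spark variants share SPARK_TOKEN.
ENV_VAR_MAP = {
    "spark": "SPARK_TOKEN",
    "spark_cloud": "SPARK_TOKEN",
    "spark_onpremise": "SPARK_TOKEN",
    "gpt4o": "OPENAI_API_KEY",
    "gpt4o-mini": "OPENAI_API_KEY",
}

def resolve_api_key(provider_name: str, config_key: Optional[str] = None) -> Optional[str]:
    """Sketch of _get_api_key's resolution order, without the ConfigProvider dependency."""
    if config_key:                          # decrypted key from config wins
        return config_key
    env_var = ENV_VAR_MAP.get(provider_name)
    if not env_var:
        return None
    if os.environ.get("SPACE_ID"):          # running inside a HuggingFace Space
        return os.environ.get(env_var)      # Space secrets arrive as env vars
    load_dotenv()                           # local/on-premise: read .env
    return os.getenv(env_var)

The SPACE_ID check is what distinguishes a deployed Space, where secrets are injected as environment variables, from a local or on-premise run, where the .env fallback applies.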