ciyidogan commited on
Commit
394611c
·
verified ·
1 Parent(s): c86c3d6

Update llm_factory.py

Browse files
Files changed (1) hide show
  1. llm_factory.py +31 -25
llm_factory.py CHANGED
@@ -2,7 +2,7 @@
2
  LLM Provider Factory for Flare
3
  """
4
  import os
5
- from typing import Optional
6
  from dotenv import load_dotenv
7
 
8
  from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
@@ -10,58 +10,64 @@ from config_provider import ConfigProvider
10
  from utils import log
11
 
12
  class LLMFactory:
13
- """Factory class to create appropriate LLM provider based on configuration"""
14
 
15
  @staticmethod
16
  def create_provider() -> LLMInterface:
17
  """Create and return appropriate LLM provider based on config"""
18
  cfg = ConfigProvider.get()
19
- llm_config = cfg.global_config.llm_provider
20
 
21
- if not llm_config:
22
  raise ValueError("No LLM provider configured")
23
 
24
- provider_name = llm_config.name
25
  log(f"🏭 Creating LLM provider: {provider_name}")
26
 
27
- # Get provider definition
28
- provider_def = cfg.global_config.get_provider_config("llm", provider_name)
29
- if not provider_def:
30
  raise ValueError(f"Unknown LLM provider: {provider_name}")
31
 
32
  # Get API key
33
  api_key = LLMFactory._get_api_key(provider_name)
34
- if not api_key and provider_def.requires_api_key:
35
  raise ValueError(f"API key required for {provider_name} but not configured")
36
 
37
- # Get endpoint
38
- endpoint = llm_config.endpoint
39
- if not endpoint and provider_def.requires_endpoint:
40
- raise ValueError(f"Endpoint required for {provider_name} but not configured")
41
 
42
  # Create appropriate provider
43
- if provider_name in ("spark", "spark_cloud", "spark_onpremise"):
44
- return LLMFactory._create_spark_provider(provider_name, api_key, endpoint, llm_config.settings)
45
  elif provider_name in ("gpt4o", "gpt4o-mini"):
46
- return LLMFactory._create_gpt_provider(provider_name, api_key, llm_config.settings)
47
  else:
48
  raise ValueError(f"Unsupported LLM provider: {provider_name}")
49
 
50
  @staticmethod
51
- def _create_spark_provider(provider_name: str, api_key: str, endpoint: str, settings: dict) -> SparkLLM:
52
  """Create Spark LLM provider"""
53
- log(f"🚀 Creating SparkLLM provider: {provider_name}")
 
 
 
54
  log(f"📍 Endpoint: {endpoint}")
55
 
 
 
 
 
 
56
  return SparkLLM(
57
- spark_endpoint=endpoint,
58
  spark_token=api_key,
59
- provider_variant=provider_name,
60
  settings=settings
61
  )
62
 
63
  @staticmethod
64
- def _create_gpt_provider(model_type: str, api_key: str, settings: dict) -> GPT4oLLM:
65
  """Create GPT-4o LLM provider"""
66
  # Determine model
67
  model = "gpt-4o-mini" if model_type == "gpt4o-mini" else "gpt-4o"
@@ -88,8 +94,6 @@ class LLMFactory:
88
  # Then check environment based on provider
89
  env_var_map = {
90
  "spark": "SPARK_TOKEN",
91
- "spark_cloud": "SPARK_TOKEN",
92
- "spark_onpremise": "SPARK_TOKEN",
93
  "gpt4o": "OPENAI_API_KEY",
94
  "gpt4o-mini": "OPENAI_API_KEY",
95
  }
@@ -101,11 +105,13 @@ class LLMFactory:
101
  api_key = os.environ.get(env_var)
102
  if api_key:
103
  log(f"🔑 Using {env_var} from HuggingFace secrets")
 
104
  else:
105
- # Local/on-premise deployment
106
  load_dotenv()
107
  api_key = os.getenv(env_var)
108
  if api_key:
109
  log(f"🔑 Using {env_var} from .env file")
 
110
 
111
- return api_key
 
2
  LLM Provider Factory for Flare
3
  """
4
  import os
5
+ from typing import Optional, Dict, Any
6
  from dotenv import load_dotenv
7
 
8
  from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
 
10
  from utils import log
11
 
12
  class LLMFactory:
13
+ """Factory class to create appropriate LLM provider based on llm_provider config"""
14
 
15
  @staticmethod
16
  def create_provider() -> LLMInterface:
17
  """Create and return appropriate LLM provider based on config"""
18
  cfg = ConfigProvider.get()
19
+ llm_provider = cfg.global_config.llm_provider
20
 
21
+ if not llm_provider or not llm_provider.name:
22
  raise ValueError("No LLM provider configured")
23
 
24
+ provider_name = llm_provider.name
25
  log(f"🏭 Creating LLM provider: {provider_name}")
26
 
27
+ # Get provider config
28
+ provider_config = cfg.global_config.get_provider_config("llm", provider_name)
29
+ if not provider_config:
30
  raise ValueError(f"Unknown LLM provider: {provider_name}")
31
 
32
  # Get API key
33
  api_key = LLMFactory._get_api_key(provider_name)
34
+ if not api_key and provider_config.requires_api_key:
35
  raise ValueError(f"API key required for {provider_name} but not configured")
36
 
37
+ # Get settings
38
+ settings = llm_provider.settings or {}
 
 
39
 
40
  # Create appropriate provider
41
+ if provider_name == "spark":
42
+ return LLMFactory._create_spark_provider(api_key, llm_provider.endpoint, settings)
43
  elif provider_name in ("gpt4o", "gpt4o-mini"):
44
+ return LLMFactory._create_gpt_provider(provider_name, api_key, settings)
45
  else:
46
  raise ValueError(f"Unsupported LLM provider: {provider_name}")
47
 
48
  @staticmethod
49
+ def _create_spark_provider(api_key: str, endpoint: Optional[str], settings: Dict[str, Any]) -> SparkLLM:
50
  """Create Spark LLM provider"""
51
+ if not endpoint:
52
+ raise ValueError("Spark requires endpoint to be configured")
53
+
54
+ log(f"🚀 Creating SparkLLM provider")
55
  log(f"📍 Endpoint: {endpoint}")
56
 
57
+ # Determine provider variant for backward compatibility
58
+ provider_variant = "spark-cloud"
59
+ if not ConfigProvider.get().global_config.is_cloud_mode():
60
+ provider_variant = "spark-onpremise"
61
+
62
  return SparkLLM(
63
+ spark_endpoint=str(endpoint),
64
  spark_token=api_key,
65
+ provider_variant=provider_variant,
66
  settings=settings
67
  )
68
 
69
  @staticmethod
70
+ def _create_gpt_provider(model_type: str, api_key: str, settings: Dict[str, Any]) -> GPT4oLLM:
71
  """Create GPT-4o LLM provider"""
72
  # Determine model
73
  model = "gpt-4o-mini" if model_type == "gpt4o-mini" else "gpt-4o"
 
94
  # Then check environment based on provider
95
  env_var_map = {
96
  "spark": "SPARK_TOKEN",
 
 
97
  "gpt4o": "OPENAI_API_KEY",
98
  "gpt4o-mini": "OPENAI_API_KEY",
99
  }
 
105
  api_key = os.environ.get(env_var)
106
  if api_key:
107
  log(f"🔑 Using {env_var} from HuggingFace secrets")
108
+ return api_key
109
  else:
110
+ # Local development
111
  load_dotenv()
112
  api_key = os.getenv(env_var)
113
  if api_key:
114
  log(f"🔑 Using {env_var} from .env file")
115
+ return api_key
116
 
117
+ return None