ciyidogan commited on
Commit
1c091c0
Β·
verified Β·
1 Parent(s): e2a364d

Update llm_interface.py

Browse files
Files changed (1) hide show
  1. llm_interface.py +56 -22
llm_interface.py CHANGED
@@ -12,7 +12,7 @@ class LLMInterface(ABC):
12
  """Abstract base class for LLM providers"""
13
 
14
  def __init__(self, settings: Dict[str, Any] = None):
15
- """Initialize with settings"""
16
  self.settings = settings or {}
17
  self.internal_prompt = self.settings.get("internal_prompt", "")
18
  self.parameter_collection_config = self.settings.get("parameter_collection_config", {})
@@ -28,9 +28,9 @@ class LLMInterface(ABC):
28
  pass
29
 
30
  class SparkLLM(LLMInterface):
31
- """Spark integration for HuggingFace"""
32
 
33
- def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "spark-cloud", settings: Dict[str, Any] = None):
34
  super().__init__(settings)
35
  self.spark_endpoint = spark_endpoint.rstrip("/")
36
  self.spark_token = spark_token
@@ -38,12 +38,13 @@ class SparkLLM(LLMInterface):
38
  log(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}")
39
 
40
  async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
41
- """Generate response using Spark"""
42
  headers = {
43
  "Authorization": f"Bearer {self.spark_token}",
44
  "Content-Type": "application/json"
45
  }
46
 
 
47
  payload = {
48
  "system_prompt": system_prompt,
49
  "user_input": user_input,
@@ -72,8 +73,40 @@ class SparkLLM(LLMInterface):
72
 
73
  async def startup(self, project_config: Dict) -> bool:
74
  """Send startup request to Spark"""
75
- # Implement if needed for Spark startup notification
76
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  class GPT4oLLM(LLMInterface):
79
  """OpenAI GPT integration"""
@@ -81,17 +114,25 @@ class GPT4oLLM(LLMInterface):
81
  def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Dict[str, Any] = None):
82
  super().__init__(settings)
83
  self.api_key = api_key
84
- self.model = model
85
  self.client = AsyncOpenAI(api_key=api_key)
86
 
87
- # Extract settings
88
  self.temperature = settings.get("temperature", 0.7) if settings else 0.7
89
  self.max_tokens = settings.get("max_tokens", 4096) if settings else 4096
90
 
91
- log(f"βœ… Initialized GPT LLM with model: {model}")
 
 
 
 
 
 
 
 
92
 
93
  async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
94
- """Generate response using OpenAI GPT"""
95
  try:
96
  # Build messages
97
  messages = [{"role": "system", "content": system_prompt}]
@@ -106,7 +147,7 @@ class GPT4oLLM(LLMInterface):
106
  # Add current user input
107
  messages.append({"role": "user", "content": user_input})
108
 
109
- # Generate response
110
  response = await self.client.chat.completions.create(
111
  model=self.model,
112
  messages=messages,
@@ -115,18 +156,11 @@ class GPT4oLLM(LLMInterface):
115
  )
116
 
117
  return response.choices[0].message.content.strip()
118
-
119
  except Exception as e:
120
- log(f"❌ GPT error: {e}")
121
  raise
122
 
123
  async def startup(self, project_config: Dict) -> bool:
124
- """Validate API key"""
125
- try:
126
- # Test API key with a simple request
127
- response = await self.client.models.list()
128
- log(f"βœ… OpenAI API key validated, available models: {len(response.data)}")
129
- return True
130
- except Exception as e:
131
- log(f"❌ Invalid OpenAI API key: {e}")
132
- return False
 
12
  """Abstract base class for LLM providers"""
13
 
14
  def __init__(self, settings: Dict[str, Any] = None):
15
+ """Initialize with provider settings"""
16
  self.settings = settings or {}
17
  self.internal_prompt = self.settings.get("internal_prompt", "")
18
  self.parameter_collection_config = self.settings.get("parameter_collection_config", {})
 
28
  pass
29
 
30
  class SparkLLM(LLMInterface):
31
+ """Spark LLM integration"""
32
 
33
+ def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Dict[str, Any] = None):
34
  super().__init__(settings)
35
  self.spark_endpoint = spark_endpoint.rstrip("/")
36
  self.spark_token = spark_token
 
38
  log(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}")
39
 
40
  async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
41
+ """Generate response from Spark LLM"""
42
  headers = {
43
  "Authorization": f"Bearer {self.spark_token}",
44
  "Content-Type": "application/json"
45
  }
46
 
47
+ # Build payload
48
  payload = {
49
  "system_prompt": system_prompt,
50
  "user_input": user_input,
 
73
 
74
  async def startup(self, project_config: Dict) -> bool:
75
  """Send startup request to Spark"""
76
+ headers = {
77
+ "Authorization": f"Bearer {self.spark_token}",
78
+ "Content-Type": "application/json"
79
+ }
80
+
81
+ # Extract required fields from project config
82
+ body = {
83
+ "work_mode": self.provider_variant,
84
+ "cloud_token": self.spark_token,
85
+ "project_name": project_config.get("name"),
86
+ "project_version": project_config.get("version_id"),
87
+ "repo_id": project_config.get("repo_id"),
88
+ "generation_config": project_config.get("generation_config", {}),
89
+ "use_fine_tune": project_config.get("use_fine_tune", False),
90
+ "fine_tune_zip": project_config.get("fine_tune_zip", "")
91
+ }
92
+
93
+ try:
94
+ async with httpx.AsyncClient(timeout=10) as client:
95
+ response = await client.post(
96
+ f"{self.spark_endpoint}/startup",
97
+ json=body,
98
+ headers=headers
99
+ )
100
+
101
+ if response.status_code >= 400:
102
+ log(f"❌ Spark startup failed: {response.status_code} - {response.text}")
103
+ return False
104
+
105
+ log(f"βœ… Spark acknowledged startup ({response.status_code})")
106
+ return True
107
+ except Exception as e:
108
+ log(f"⚠️ Spark startup error: {e}")
109
+ return False
110
 
111
  class GPT4oLLM(LLMInterface):
112
  """OpenAI GPT integration"""
 
114
  def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Dict[str, Any] = None):
115
  super().__init__(settings)
116
  self.api_key = api_key
117
+ self.model = self._map_model_name(model)
118
  self.client = AsyncOpenAI(api_key=api_key)
119
 
120
+ # Extract model-specific settings
121
  self.temperature = settings.get("temperature", 0.7) if settings else 0.7
122
  self.max_tokens = settings.get("max_tokens", 4096) if settings else 4096
123
 
124
+ log(f"βœ… Initialized GPT LLM with model: {self.model}")
125
+
126
+ def _map_model_name(self, model: str) -> str:
127
+ """Map provider name to actual model name"""
128
+ mappings = {
129
+ "gpt4o": "gpt-4",
130
+ "gpt4o-mini": "gpt-4o-mini"
131
+ }
132
+ return mappings.get(model, model)
133
 
134
  async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
135
+ """Generate response from OpenAI"""
136
  try:
137
  # Build messages
138
  messages = [{"role": "system", "content": system_prompt}]
 
147
  # Add current user input
148
  messages.append({"role": "user", "content": user_input})
149
 
150
+ # Call OpenAI
151
  response = await self.client.chat.completions.create(
152
  model=self.model,
153
  messages=messages,
 
156
  )
157
 
158
  return response.choices[0].message.content.strip()
 
159
  except Exception as e:
160
+ log(f"❌ OpenAI error: {e}")
161
  raise
162
 
163
  async def startup(self, project_config: Dict) -> bool:
164
+ """GPT doesn't need startup, always return True"""
165
+ log("βœ… GPT provider ready (no startup needed)")
166
+ return True