ciyidogan commited on
Commit
fc0299f
·
verified ·
1 Parent(s): b9b2b1e

Update llm_interface.py

Browse files
Files changed (1) hide show
  1. llm_interface.py +83 -56
llm_interface.py CHANGED
@@ -3,7 +3,7 @@ LLM Provider Interface for Flare
3
  """
4
  import os
5
  from abc import ABC, abstractmethod
6
- from typing import Dict, List, Optional
7
  import httpx
8
  from openai import AsyncOpenAI
9
  from utils import log
@@ -11,6 +11,12 @@ from utils import log
11
  class LLMInterface(ABC):
12
  """Abstract base class for LLM providers"""
13
 
 
 
 
 
 
 
14
  @abstractmethod
15
  async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
16
  """Generate response from LLM"""
@@ -22,21 +28,23 @@ class LLMInterface(ABC):
22
  pass
23
 
24
  class SparkLLM(LLMInterface):
25
- """Existing Spark integration"""
26
 
27
-
28
- def __init__(self, spark_endpoint: str, spark_token: str, work_mode: str = "cloud"):
29
  self.spark_endpoint = spark_endpoint.rstrip("/")
30
  self.spark_token = spark_token
31
- self.work_mode = work_mode
32
- log(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}")
33
 
34
  async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
 
35
  headers = {
36
  "Authorization": f"Bearer {self.spark_token}",
37
  "Content-Type": "application/json"
38
  }
39
 
 
40
  payload = {
41
  "system_prompt": system_prompt,
42
  "user_input": user_input,
@@ -44,6 +52,7 @@ class SparkLLM(LLMInterface):
44
  }
45
 
46
  try:
 
47
  async with httpx.AsyncClient(timeout=60) as client:
48
  response = await client.post(
49
  f"{self.spark_endpoint}/generate",
@@ -59,77 +68,95 @@ class SparkLLM(LLMInterface):
59
  raw = (data.get("assistant") or data.get("text", "")).strip()
60
 
61
  return raw
 
 
 
62
  except Exception as e:
63
  log(f"❌ Spark error: {e}")
64
  raise
65
 
66
  async def startup(self, project_config: Dict) -> bool:
67
  """Send startup request to Spark"""
68
- # Existing Spark startup logic
69
- return True
70
-
71
- class GPT4oLLM(LLMInterface):
72
- """OpenAI GPT integration"""
73
-
74
- def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
75
- self.api_key = api_key
76
- self.model = model
77
- self.client = AsyncOpenAI(api_key=api_key)
78
- log(f"βœ… Initialized GPT LLM with model: {model}")
79
-
80
- async def generate(self, project_name: str, user_input: str, system_prompt: str, context: List[Dict], version_config: Dict = None) -> str:
81
- """Generate response from LLM with project context"""
82
  headers = {
83
  "Authorization": f"Bearer {self.spark_token}",
84
  "Content-Type": "application/json"
85
  }
86
 
87
- # Build payload with all required fields for Spark
88
- payload = {
89
- "work_mode": self.work_mode,
90
- "cloud_token": self.spark_token,
91
- "project_name": project_name,
92
- "system_prompt": system_prompt,
93
- "user_input": user_input,
94
- "context": context
95
- }
96
-
97
- # Add version-specific config if available
98
- if version_config:
99
- llm_config = version_config.get("llm", {})
100
- payload.update({
101
- "project_version": version_config.get("version_id"),
102
- "repo_id": llm_config.get("repo_id"),
103
- "generation_config": llm_config.get("generation_config"),
104
- "use_fine_tune": llm_config.get("use_fine_tune"),
105
- "fine_tune_zip": llm_config.get("fine_tune_zip")
106
- })
107
-
108
  try:
109
- log(f"πŸ“€ Spark request payload keys: {list(payload.keys())}")
110
- async with httpx.AsyncClient(timeout=60) as client:
111
  response = await client.post(
112
- f"{self.spark_endpoint}/generate",
113
- json=payload,
114
  headers=headers
115
  )
116
  response.raise_for_status()
117
- data = response.json()
118
- return data.get("model_answer", data.get("assistant", data.get("text", "")))
119
- except httpx.TimeoutException:
120
- log("⏱️ Spark timeout")
121
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  except Exception as e:
123
- log(f"❌ Spark error: {str(e)}")
124
  raise
125
 
126
  async def startup(self, project_config: Dict) -> bool:
127
- """Validate API key"""
128
  try:
129
- # Test API key with a simple request
130
- test_response = await self.client.models.list()
131
- log(f"βœ… OpenAI API key validated, available models: {len(test_response.data)}")
 
 
 
 
132
  return True
133
  except Exception as e:
134
- log(f"❌ Invalid OpenAI API key: {e}")
135
  return False
 
3
  """
4
  import os
5
  from abc import ABC, abstractmethod
6
+ from typing import Dict, List, Optional, Any
7
  import httpx
8
  from openai import AsyncOpenAI
9
  from utils import log
 
11
  class LLMInterface(ABC):
12
  """Abstract base class for LLM providers"""
13
 
14
+ def __init__(self, settings: Dict[str, Any] = None):
15
+ """Initialize with provider-specific settings"""
16
+ self.settings = settings or {}
17
+ self.internal_prompt = self.settings.get("internal_prompt", "")
18
+ self.parameter_collection_config = self.settings.get("parameter_collection_config", {})
19
+
20
  @abstractmethod
21
  async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
22
  """Generate response from LLM"""
 
28
  pass
29
 
30
  class SparkLLM(LLMInterface):
31
+ """Spark LLM integration"""
32
 
33
+ def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "spark", settings: Dict = None):
34
+ super().__init__(settings)
35
  self.spark_endpoint = spark_endpoint.rstrip("/")
36
  self.spark_token = spark_token
37
+ self.provider_variant = provider_variant
38
+ log(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}, variant: {self.provider_variant}")
39
 
40
  async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
41
+ """Generate response from Spark"""
42
  headers = {
43
  "Authorization": f"Bearer {self.spark_token}",
44
  "Content-Type": "application/json"
45
  }
46
 
47
+ # Build payload
48
  payload = {
49
  "system_prompt": system_prompt,
50
  "user_input": user_input,
 
52
  }
53
 
54
  try:
55
+ log(f"πŸ“€ Spark request to {self.spark_endpoint}/generate")
56
  async with httpx.AsyncClient(timeout=60) as client:
57
  response = await client.post(
58
  f"{self.spark_endpoint}/generate",
 
68
  raw = (data.get("assistant") or data.get("text", "")).strip()
69
 
70
  return raw
71
+ except httpx.TimeoutException:
72
+ log("⏱️ Spark timeout")
73
+ raise
74
  except Exception as e:
75
  log(f"❌ Spark error: {e}")
76
  raise
77
 
78
  async def startup(self, project_config: Dict) -> bool:
79
  """Send startup request to Spark"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  headers = {
81
  "Authorization": f"Bearer {self.spark_token}",
82
  "Content-Type": "application/json"
83
  }
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  try:
86
+ log(f"πŸš€ Sending startup to Spark for project: {project_config.get('project_name')}")
87
+ async with httpx.AsyncClient(timeout=30) as client:
88
  response = await client.post(
89
+ f"{self.spark_endpoint}/startup",
90
+ json=project_config,
91
  headers=headers
92
  )
93
  response.raise_for_status()
94
+ log("βœ… Spark startup successful")
95
+ return True
96
+ except Exception as e:
97
+ log(f"❌ Spark startup failed: {e}")
98
+ return False
99
+
100
class GPT4oLLM(LLMInterface):
    """OpenAI GPT integration."""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Optional[Dict] = None):
        """Initialize the OpenAI-backed provider.

        Args:
            api_key: OpenAI API key.
            model: Chat model name; defaults to the low-cost "gpt-4o-mini".
            settings: Optional overrides for "temperature" and "max_tokens".
        """
        super().__init__(settings)
        self.api_key = api_key
        self.model = model
        self.client = AsyncOpenAI(api_key=api_key)
        # super().__init__ has already normalized self.settings to a dict,
        # so read generation overrides from there instead of re-checking
        # `settings` for None.
        self.temperature = self.settings.get("temperature", 0.3)
        self.max_tokens = self.settings.get("max_tokens", 512)
        log(f"✅ Initialized GPT LLM with model: {model}")

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate a response from GPT.

        Args:
            system_prompt: System instruction placed first in the message list.
            user_input: Current user utterance, appended last.
            context: Prior turns as dicts with "role"/"content" keys; only
                the last 10 are forwarded to bound token usage.

        Returns:
            The assistant's reply text (empty string when the API returns
            no content).

        Raises:
            Exception: Re-raises any OpenAI client error after logging it.
        """
        try:
            # Convert context into OpenAI chat format.
            messages = [{"role": "system", "content": system_prompt}]

            # Forward only the most recent history.
            for msg in context[-10:]:
                # Anything that isn't a user turn is mapped to "assistant";
                # .get guards against malformed history entries.
                role = "user" if msg.get("role") == "user" else "assistant"
                messages.append({"role": role, "content": msg.get("content", "")})

            # Current user input goes last.
            messages.append({"role": "user", "content": user_input})

            log(f"📤 GPT request with {len(messages)} messages")

            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )

            # message.content may be None (e.g. tool-call responses); fall
            # back to "" so len() and callers never see None.
            content = response.choices[0].message.content or ""
            log(f"✅ GPT response received: {len(content)} chars")

            return content

        except Exception as e:
            log(f"❌ GPT error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """GPT doesn't need startup - just validate API key.

        Args:
            project_config: Unused; accepted for interface compatibility.

        Returns:
            True when a minimal test completion succeeds, False otherwise.
        """
        try:
            # Minimal request to confirm the key and model are usable.
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=5
            )
            log("✅ GPT API key validated")
            return True
        except Exception as e:
            log(f"❌ GPT API key validation failed: {e}")
            return False