Update llm_interface.py

llm_interface.py (+56 / -22) CHANGED
@@ -12,7 +12,7 @@ class LLMInterface(ABC):
     """Abstract base class for LLM providers"""

     def __init__(self, settings: Dict[str, Any] = None):
-        """Initialize with settings"""
+        """Initialize with provider settings"""
         self.settings = settings or {}
         self.internal_prompt = self.settings.get("internal_prompt", "")
         self.parameter_collection_config = self.settings.get("parameter_collection_config", {})
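Note: the base class reads exactly two keys out of the settings dict. An illustrative shape, as a sketch; only the two key names are grounded in the diff, the values (and any further keys) are placeholders.

    # Illustrative settings dict; only "internal_prompt" and
    # "parameter_collection_config" are read by LLMInterface.__init__.
    settings = {
        "internal_prompt": "You are a helpful assistant.",   # placeholder
        "parameter_collection_config": {},                   # placeholder
    }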
@@ -28,9 +28,9 @@ class SparkLLM(LLMInterface):
         pass

 class SparkLLM(LLMInterface):
-    """Spark integration"""
+    """Spark LLM integration"""

-    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "
+    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Dict[str, Any] = None):
         super().__init__(settings)
         self.spark_endpoint = spark_endpoint.rstrip("/")
         self.spark_token = spark_token
@@ -38,12 +38,13 @@ class SparkLLM(LLMInterface):
         log(f"🚀 SparkLLM initialized with endpoint: {self.spark_endpoint}")

     async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
-        """Generate response"""
+        """Generate response from Spark LLM"""
         headers = {
             "Authorization": f"Bearer {self.spark_token}",
             "Content-Type": "application/json"
         }

+        # Build payload
         payload = {
             "system_prompt": system_prompt,
             "user_input": user_input,
@@ -72,8 +73,40 @@ class SparkLLM(LLMInterface):

     async def startup(self, project_config: Dict) -> bool:
         """Send startup request to Spark"""
-
-
+        headers = {
+            "Authorization": f"Bearer {self.spark_token}",
+            "Content-Type": "application/json"
+        }
+
+        # Extract required fields from project config
+        body = {
+            "work_mode": self.provider_variant,
+            "cloud_token": self.spark_token,
+            "project_name": project_config.get("name"),
+            "project_version": project_config.get("version_id"),
+            "repo_id": project_config.get("repo_id"),
+            "generation_config": project_config.get("generation_config", {}),
+            "use_fine_tune": project_config.get("use_fine_tune", False),
+            "fine_tune_zip": project_config.get("fine_tune_zip", "")
+        }
+
+        try:
+            async with httpx.AsyncClient(timeout=10) as client:
+                response = await client.post(
+                    f"{self.spark_endpoint}/startup",
+                    json=body,
+                    headers=headers
+                )
+
+                if response.status_code >= 400:
+                    log(f"❌ Spark startup failed: {response.status_code} - {response.text}")
+                    return False
+
+                log(f"✅ Spark acknowledged startup ({response.status_code})")
+                return True
+        except Exception as e:
+            log(f"⚠️ Spark startup error: {e}")
+            return False

 class GPT4oLLM(LLMInterface):
     """OpenAI GPT integration"""
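Note: the new startup() posts the project config to {spark_endpoint}/startup with a 10-second timeout and treats any 4xx/5xx status as failure. A minimal driver sketch; the endpoint, token, and config values below are invented placeholders, not from the commit.

import asyncio

async def demo():
    llm = SparkLLM(
        spark_endpoint="https://spark.example.com",   # placeholder endpoint
        spark_token="dummy-token",                    # placeholder token
        provider_variant="cloud",
    )
    ok = await llm.startup({
        "name": "demo-project",                  # -> body["project_name"]
        "version_id": 3,                         # -> body["project_version"]
        "repo_id": "acme/demo-llm",              # -> body["repo_id"]
        "generation_config": {"max_new_tokens": 128},
        "use_fine_tune": False,                  # fine_tune_zip stays ""
    })
    print("spark ready:", ok)

asyncio.run(demo())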
@@ -81,17 +114,25 @@ class GPT4oLLM(LLMInterface):
     def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Dict[str, Any] = None):
         super().__init__(settings)
         self.api_key = api_key
-        self.model = model
+        self.model = self._map_model_name(model)
         self.client = AsyncOpenAI(api_key=api_key)

-        # Extract settings
+        # Extract model-specific settings
         self.temperature = settings.get("temperature", 0.7) if settings else 0.7
         self.max_tokens = settings.get("max_tokens", 4096) if settings else 4096

-        log(f"✅ Initialized GPT LLM with model: {model}")
+        log(f"✅ Initialized GPT LLM with model: {self.model}")
+
+    def _map_model_name(self, model: str) -> str:
+        """Map provider name to actual model name"""
+        mappings = {
+            "gpt4o": "gpt-4",
+            "gpt4o-mini": "gpt-4o-mini"
+        }
+        return mappings.get(model, model)

     async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
-        """Generate response"""
+        """Generate response from OpenAI"""
         try:
             # Build messages
             messages = [{"role": "system", "content": system_prompt}]
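Note: _map_model_name() is a plain alias table that passes unknown names through unchanged, so configs already using real OpenAI model ids keep working. Restated standalone for illustration:

MODEL_ALIASES = {
    "gpt4o": "gpt-4",
    "gpt4o-mini": "gpt-4o-mini",
}

def map_model_name(model: str) -> str:
    # Unknown names fall through untouched
    return MODEL_ALIASES.get(model, model)

assert map_model_name("gpt4o") == "gpt-4"
assert map_model_name("gpt-4o") == "gpt-4o"   # passthrough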
@@ -106,7 +147,7 @@ class GPT4oLLM(LLMInterface):
             # Add current user input
             messages.append({"role": "user", "content": user_input})

-            #
+            # Call OpenAI
             response = await self.client.chat.completions.create(
                 model=self.model,
                 messages=messages,
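Note: generate() assembles a standard OpenAI chat payload: system prompt first, then the conversation history, then the current turn. The hunks above elide how context is folded in (old lines 98-105 are not shown), so the sketch below assumes context entries are already role/content dicts.

def build_messages(system_prompt, context, user_input):
    # Assumption: each context entry is already an OpenAI-style
    # {"role": ..., "content": ...} dict; the elided lines may differ.
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(context)
    messages.append({"role": "user", "content": user_input})
    return messages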
@@ -115,18 +156,11 @@ class GPT4oLLM(LLMInterface):
             )

             return response.choices[0].message.content.strip()
-
         except Exception as e:
-            log(f"❌
+            log(f"❌ OpenAI error: {e}")
             raise

     async def startup(self, project_config: Dict) -> bool:
-        """
-
-
-            response = await self.client.models.list()
-            log(f"✅ OpenAI API key validated, available models: {len(response.data)}")
-            return True
-        except Exception as e:
-            log(f"❌ Invalid OpenAI API key: {e}")
-            return False
+        """GPT doesn't need startup, always return True"""
+        log("✅ GPT provider ready (no startup needed)")
+        return True
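Note: after this commit both providers satisfy the same LLMInterface contract, generate() plus a startup() that reports readiness (Spark's pings the service, GPT's is now a no-op that returns True), so callers can stay provider-agnostic. A hypothetical factory; the function name and config keys are invented for illustration, only the constructor signatures come from the diff.

def create_llm(cfg: dict) -> LLMInterface:
    # cfg schema is hypothetical
    if cfg["provider"] == "spark":
        return SparkLLM(cfg["endpoint"], cfg["token"],
                        cfg.get("variant", "cloud"), cfg.get("settings"))
    return GPT4oLLM(cfg["api_key"], cfg.get("model", "gpt-4o-mini"),
                    cfg.get("settings"))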