ciyidogan commited on
Commit
a71c268
Β·
verified Β·
1 Parent(s): b72fa83

Update llm_interface.py

Browse files
Files changed (1) hide show
  1. llm_interface.py +8 -138
llm_interface.py CHANGED
@@ -4,8 +4,6 @@ LLM Provider Interface for Flare
4
  import os
5
  from abc import ABC, abstractmethod
6
  from typing import Dict, List, Optional, Any
7
- import httpx
8
- from openai import AsyncOpenAI
9
  from utils import log
10
 
11
  class LLMInterface(ABC):
@@ -26,141 +24,13 @@ class LLMInterface(ABC):
26
  async def startup(self, project_config: Dict) -> bool:
27
  """Initialize LLM with project config"""
28
  pass
29
-
30
- class SparkLLM(LLMInterface):
31
- """Spark LLM integration"""
32
-
33
- def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Dict[str, Any] = None):
34
- super().__init__(settings)
35
- self.spark_endpoint = spark_endpoint.rstrip("/")
36
- self.spark_token = spark_token
37
- self.provider_variant = provider_variant
38
- log(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}")
39
-
40
- async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
41
- """Generate response from Spark LLM"""
42
- headers = {
43
- "Authorization": f"Bearer {self.spark_token}",
44
- "Content-Type": "application/json"
45
- }
46
-
47
- # Build payload
48
- payload = {
49
- "system_prompt": system_prompt,
50
- "user_input": user_input,
51
- "context": context
52
- }
53
-
54
- try:
55
- async with httpx.AsyncClient(timeout=60) as client:
56
- response = await client.post(
57
- f"{self.spark_endpoint}/generate",
58
- json=payload,
59
- headers=headers
60
- )
61
- response.raise_for_status()
62
- data = response.json()
63
-
64
- # Try different response fields
65
- raw = data.get("model_answer", "").strip()
66
- if not raw:
67
- raw = (data.get("assistant") or data.get("text", "")).strip()
68
-
69
- return raw
70
- except Exception as e:
71
- log(f"❌ Spark error: {e}")
72
- raise
73
-
74
- async def startup(self, project_config: Dict) -> bool:
75
- """Send startup request to Spark"""
76
- headers = {
77
- "Authorization": f"Bearer {self.spark_token}",
78
- "Content-Type": "application/json"
79
- }
80
-
81
- # Extract required fields from project config
82
- body = {
83
- "work_mode": self.provider_variant,
84
- "cloud_token": self.spark_token,
85
- "project_name": project_config.get("name"),
86
- "project_version": project_config.get("version_id"),
87
- "repo_id": project_config.get("repo_id"),
88
- "generation_config": project_config.get("generation_config", {}),
89
- "use_fine_tune": project_config.get("use_fine_tune", False),
90
- "fine_tune_zip": project_config.get("fine_tune_zip", "")
91
- }
92
-
93
- try:
94
- async with httpx.AsyncClient(timeout=10) as client:
95
- response = await client.post(
96
- f"{self.spark_endpoint}/startup",
97
- json=body,
98
- headers=headers
99
- )
100
-
101
- if response.status_code >= 400:
102
- log(f"❌ Spark startup failed: {response.status_code} - {response.text}")
103
- return False
104
-
105
- log(f"βœ… Spark acknowledged startup ({response.status_code})")
106
- return True
107
- except Exception as e:
108
- log(f"⚠️ Spark startup error: {e}")
109
- return False
110
-
111
- class GPT4oLLM(LLMInterface):
112
- """OpenAI GPT integration"""
113
 
114
- def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Dict[str, Any] = None):
115
- super().__init__(settings)
116
- self.api_key = api_key
117
- self.model = self._map_model_name(model)
118
- self.client = AsyncOpenAI(api_key=api_key)
119
-
120
- # Extract model-specific settings
121
- self.temperature = settings.get("temperature", 0.7) if settings else 0.7
122
- self.max_tokens = settings.get("max_tokens", 4096) if settings else 4096
123
-
124
- log(f"βœ… Initialized GPT LLM with model: {self.model}")
125
-
126
- def _map_model_name(self, model: str) -> str:
127
- """Map provider name to actual model name"""
128
- mappings = {
129
- "gpt4o": "gpt-4",
130
- "gpt4o-mini": "gpt-4o-mini"
131
- }
132
- return mappings.get(model, model)
133
-
134
- async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
135
- """Generate response from OpenAI"""
136
- try:
137
- # Build messages
138
- messages = [{"role": "system", "content": system_prompt}]
139
-
140
- # Add context
141
- for msg in context:
142
- messages.append({
143
- "role": msg.get("role", "user"),
144
- "content": msg.get("content", "")
145
- })
146
-
147
- # Add current user input
148
- messages.append({"role": "user", "content": user_input})
149
-
150
- # Call OpenAI
151
- response = await self.client.chat.completions.create(
152
- model=self.model,
153
- messages=messages,
154
- temperature=self.temperature,
155
- max_tokens=self.max_tokens
156
- )
157
-
158
- return response.choices[0].message.content.strip()
159
- except Exception as e:
160
- log(f"❌ OpenAI error: {e}")
161
- raise
162
 
163
- async def startup(self, project_config: Dict) -> bool:
164
- """GPT doesn't need startup, always return True"""
165
- log("βœ… GPT provider ready (no startup needed)")
166
- return True
 
4
  import os
5
  from abc import ABC, abstractmethod
6
  from typing import Dict, List, Optional, Any
 
 
7
  from utils import log
8
 
9
  class LLMInterface(ABC):
 
24
  async def startup(self, project_config: Dict) -> bool:
25
  """Initialize LLM with project config"""
26
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ @abstractmethod
29
+ def get_provider_name(self) -> str:
30
+ """Get provider name for logging"""
31
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ @abstractmethod
34
+ def get_model_info(self) -> Dict[str, Any]:
35
+ """Get model information"""
36
+ pass