ciyidogan commited on
Commit
2d2ab61
·
verified ·
1 Parent(s): 394611c

Update llm_interface.py

Browse files
Files changed (1) hide show
  1. llm_interface.py +29 -59
llm_interface.py CHANGED
@@ -12,7 +12,7 @@ class LLMInterface(ABC):
12
  """Abstract base class for LLM providers"""
13
 
14
  def __init__(self, settings: Dict[str, Any] = None):
15
- """Initialize with provider-specific settings"""
16
  self.settings = settings or {}
17
  self.internal_prompt = self.settings.get("internal_prompt", "")
18
  self.parameter_collection_config = self.settings.get("parameter_collection_config", {})
@@ -28,23 +28,22 @@ class LLMInterface(ABC):
28
  pass
29
 
30
  class SparkLLM(LLMInterface):
31
- """Spark LLM integration"""
32
 
33
- def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "spark", settings: Dict = None):
34
  super().__init__(settings)
35
  self.spark_endpoint = spark_endpoint.rstrip("/")
36
  self.spark_token = spark_token
37
  self.provider_variant = provider_variant
38
- log(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}, variant: {self.provider_variant}")
39
 
40
  async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
41
- """Generate response from Spark"""
42
  headers = {
43
  "Authorization": f"Bearer {self.spark_token}",
44
  "Content-Type": "application/json"
45
  }
46
 
47
- # Build payload
48
  payload = {
49
  "system_prompt": system_prompt,
50
  "user_input": user_input,
@@ -52,7 +51,6 @@ class SparkLLM(LLMInterface):
52
  }
53
 
54
  try:
55
- log(f"πŸ“€ Spark request to {self.spark_endpoint}/generate")
56
  async with httpx.AsyncClient(timeout=60) as client:
57
  response = await client.post(
58
  f"{self.spark_endpoint}/generate",
@@ -68,68 +66,47 @@ class SparkLLM(LLMInterface):
68
  raw = (data.get("assistant") or data.get("text", "")).strip()
69
 
70
  return raw
71
- except httpx.TimeoutException:
72
- log("⏱️ Spark timeout")
73
- raise
74
  except Exception as e:
75
  log(f"❌ Spark error: {e}")
76
  raise
77
 
78
  async def startup(self, project_config: Dict) -> bool:
79
  """Send startup request to Spark"""
80
- headers = {
81
- "Authorization": f"Bearer {self.spark_token}",
82
- "Content-Type": "application/json"
83
- }
84
-
85
- try:
86
- log(f"πŸš€ Sending startup to Spark for project: {project_config.get('project_name')}")
87
- async with httpx.AsyncClient(timeout=30) as client:
88
- response = await client.post(
89
- f"{self.spark_endpoint}/startup",
90
- json=project_config,
91
- headers=headers
92
- )
93
- response.raise_for_status()
94
- log("βœ… Spark startup successful")
95
- return True
96
- except Exception as e:
97
- log(f"❌ Spark startup failed: {e}")
98
- return False
99
 
100
  class GPT4oLLM(LLMInterface):
101
  """OpenAI GPT integration"""
102
 
103
- def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Dict = None):
104
  super().__init__(settings)
105
  self.api_key = api_key
106
  self.model = model
107
  self.client = AsyncOpenAI(api_key=api_key)
108
- # Default GPT settings
109
- self.temperature = settings.get("temperature", 0.3) if settings else 0.3
110
- self.max_tokens = settings.get("max_tokens", 512) if settings else 512
 
 
111
  log(f"βœ… Initialized GPT LLM with model: {model}")
112
 
113
  async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
114
- """Generate response from GPT"""
115
  try:
116
- # Convert context to OpenAI format
117
- messages = []
118
-
119
- # Add system prompt
120
- messages.append({"role": "system", "content": system_prompt})
121
 
122
- # Add conversation history
123
- for msg in context[-10:]: # Last 10 messages
124
- role = "user" if msg["role"] == "user" else "assistant"
125
- messages.append({"role": role, "content": msg["content"]})
 
 
126
 
127
  # Add current user input
128
  messages.append({"role": "user", "content": user_input})
129
 
130
- log(f"πŸ“€ GPT request with {len(messages)} messages")
131
-
132
- # Call OpenAI API
133
  response = await self.client.chat.completions.create(
134
  model=self.model,
135
  messages=messages,
@@ -137,26 +114,19 @@ class GPT4oLLM(LLMInterface):
137
  max_tokens=self.max_tokens
138
  )
139
 
140
- content = response.choices[0].message.content
141
- log(f"βœ… GPT response received: {len(content)} chars")
142
-
143
- return content
144
 
145
  except Exception as e:
146
  log(f"❌ GPT error: {e}")
147
  raise
148
 
149
  async def startup(self, project_config: Dict) -> bool:
150
- """GPT doesn't need startup - just validate API key"""
151
  try:
152
- # Test API key with a minimal request
153
- response = await self.client.chat.completions.create(
154
- model=self.model,
155
- messages=[{"role": "user", "content": "test"}],
156
- max_tokens=5
157
- )
158
- log("βœ… GPT API key validated")
159
  return True
160
  except Exception as e:
161
- log(f"❌ GPT API key validation failed: {e}")
162
  return False
 
12
  """Abstract base class for LLM providers"""
13
 
14
  def __init__(self, settings: Dict[str, Any] = None):
15
+ """Initialize with settings"""
16
  self.settings = settings or {}
17
  self.internal_prompt = self.settings.get("internal_prompt", "")
18
  self.parameter_collection_config = self.settings.get("parameter_collection_config", {})
 
28
  pass
29
 
30
  class SparkLLM(LLMInterface):
31
+ """Spark integration for HuggingFace"""
32
 
33
+ def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "spark-cloud", settings: Dict[str, Any] = None):
34
  super().__init__(settings)
35
  self.spark_endpoint = spark_endpoint.rstrip("/")
36
  self.spark_token = spark_token
37
  self.provider_variant = provider_variant
38
+ log(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}")
39
 
40
  async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
41
+ """Generate response using Spark"""
42
  headers = {
43
  "Authorization": f"Bearer {self.spark_token}",
44
  "Content-Type": "application/json"
45
  }
46
 
 
47
  payload = {
48
  "system_prompt": system_prompt,
49
  "user_input": user_input,
 
51
  }
52
 
53
  try:
 
54
  async with httpx.AsyncClient(timeout=60) as client:
55
  response = await client.post(
56
  f"{self.spark_endpoint}/generate",
 
66
  raw = (data.get("assistant") or data.get("text", "")).strip()
67
 
68
  return raw
 
 
 
69
  except Exception as e:
70
  log(f"❌ Spark error: {e}")
71
  raise
72
 
73
  async def startup(self, project_config: Dict) -> bool:
74
  """Send startup request to Spark"""
75
+ # Implement if needed for Spark startup notification
76
+ return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  class GPT4oLLM(LLMInterface):
79
  """OpenAI GPT integration"""
80
 
81
+ def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Dict[str, Any] = None):
82
  super().__init__(settings)
83
  self.api_key = api_key
84
  self.model = model
85
  self.client = AsyncOpenAI(api_key=api_key)
86
+
87
+ # Extract settings
88
+ self.temperature = settings.get("temperature", 0.7) if settings else 0.7
89
+ self.max_tokens = settings.get("max_tokens", 4096) if settings else 4096
90
+
91
  log(f"βœ… Initialized GPT LLM with model: {model}")
92
 
93
  async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
94
+ """Generate response using OpenAI GPT"""
95
  try:
96
+ # Build messages
97
+ messages = [{"role": "system", "content": system_prompt}]
 
 
 
98
 
99
+ # Add context
100
+ for msg in context:
101
+ messages.append({
102
+ "role": msg.get("role", "user"),
103
+ "content": msg.get("content", "")
104
+ })
105
 
106
  # Add current user input
107
  messages.append({"role": "user", "content": user_input})
108
 
109
+ # Generate response
 
 
110
  response = await self.client.chat.completions.create(
111
  model=self.model,
112
  messages=messages,
 
114
  max_tokens=self.max_tokens
115
  )
116
 
117
+ return response.choices[0].message.content.strip()
 
 
 
118
 
119
  except Exception as e:
120
  log(f"❌ GPT error: {e}")
121
  raise
122
 
123
  async def startup(self, project_config: Dict) -> bool:
124
+ """Validate API key"""
125
  try:
126
+ # Test API key with a simple request
127
+ response = await self.client.models.list()
128
+ log(f"βœ… OpenAI API key validated, available models: {len(response.data)}")
 
 
 
 
129
  return True
130
  except Exception as e:
131
+ log(f"❌ Invalid OpenAI API key: {e}")
132
  return False