"""
LLM Provider Interface for Flare
"""
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any
import httpx
from openai import AsyncOpenAI
from utils import log


class LLMInterface(ABC):
"""Abstract base class for LLM providers"""

    def __init__(self, settings: Optional[Dict[str, Any]] = None):
        """Initialize with provider-specific settings"""
        self.settings = settings or {}
        # Provider-agnostic options shared by every implementation
        self.internal_prompt = self.settings.get("internal_prompt", "")
        self.parameter_collection_config = self.settings.get("parameter_collection_config", {})

    @abstractmethod
async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
"""Generate response from LLM"""
pass

    @abstractmethod
async def startup(self, project_config: Dict) -> bool:
"""Initialize LLM with project config"""
pass
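
    # Note (inferred from the concrete providers below): `context` is expected
    # to be a list of {"role": "user" | "assistant", "content": str} dicts,
    # ordered oldest-first.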


class SparkLLM(LLMInterface):
    """Spark LLM integration"""

    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "spark", settings: Optional[Dict] = None):
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        log(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}, variant: {self.provider_variant}")

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
"""Generate response from Spark"""
headers = {
"Authorization": f"Bearer {self.spark_token}",
"Content-Type": "application/json"
}
# Build payload
payload = {
"system_prompt": system_prompt,
"user_input": user_input,
"context": context
}
try:
log(f"πŸ“€ Spark request to {self.spark_endpoint}/generate")
async with httpx.AsyncClient(timeout=60) as client:
response = await client.post(
f"{self.spark_endpoint}/generate",
json=payload,
headers=headers
)
response.raise_for_status()
data = response.json()
                # Spark variants answer in different response fields; try each
                # in turn, guarding against explicit null values
                raw = (data.get("model_answer") or "").strip()
                if not raw:
                    raw = (data.get("assistant") or data.get("text") or "").strip()
                return raw
except httpx.TimeoutException:
log("⏱️ Spark timeout")
raise
except Exception as e:
log(f"❌ Spark error: {e}")
raise

    async def startup(self, project_config: Dict) -> bool:
"""Send startup request to Spark"""
headers = {
"Authorization": f"Bearer {self.spark_token}",
"Content-Type": "application/json"
}
try:
log(f"πŸš€ Sending startup to Spark for project: {project_config.get('project_name')}")
async with httpx.AsyncClient(timeout=30) as client:
response = await client.post(
f"{self.spark_endpoint}/startup",
json=project_config,
headers=headers
)
response.raise_for_status()
log("βœ… Spark startup successful")
return True
except Exception as e:
log(f"❌ Spark startup failed: {e}")
return False
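
    # Illustrative construction (the endpoint URL and token below are
    # assumptions, shown only to document the call shape):
    #   llm = SparkLLM("https://spark.example.com", spark_token=token)
    #   answer = await llm.generate(system_prompt, user_input, context=[])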


class GPT4oLLM(LLMInterface):
    """OpenAI GPT integration"""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Optional[Dict] = None):
        super().__init__(settings)
        self.api_key = api_key
        self.model = model
        self.client = AsyncOpenAI(api_key=api_key)
        # Default sampling settings (self.settings is always a dict after super().__init__)
        self.temperature = self.settings.get("temperature", 0.3)
        self.max_tokens = self.settings.get("max_tokens", 512)
        log(f"βœ… Initialized GPT LLM with model: {model}")

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
"""Generate response from GPT"""
try:
# Convert context to OpenAI format
messages = []
# Add system prompt
messages.append({"role": "system", "content": system_prompt})
# Add conversation history
for msg in context[-10:]: # Last 10 messages
role = "user" if msg["role"] == "user" else "assistant"
messages.append({"role": role, "content": msg["content"]})
# Add current user input
messages.append({"role": "user", "content": user_input})
log(f"πŸ“€ GPT request with {len(messages)} messages")
# Call OpenAI API
response = await self.client.chat.completions.create(
model=self.model,
messages=messages,
temperature=self.temperature,
max_tokens=self.max_tokens
)
            # message.content may be None in edge cases; normalize to ""
            content = response.choices[0].message.content or ""
            log(f"βœ… GPT response received: {len(content)} chars")
            return content
except Exception as e:
log(f"❌ GPT error: {e}")
raise

    async def startup(self, project_config: Dict) -> bool:
        """GPT needs no startup call; just validate the API key"""
        try:
            # Test the API key with a minimal request; the reply is discarded
            await self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=5
            )
log("βœ… GPT API key validated")
return True
except Exception as e:
log(f"❌ GPT API key validation failed: {e}")
return False
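

# Illustrative usage sketch: the environment variable name and prompts below
# are assumptions for demonstration only; SparkLLM is driven the same way once
# given its endpoint and token.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = GPT4oLLM(api_key=os.environ["OPENAI_API_KEY"])
        if await llm.startup({}):  # validates the key before first use
            reply = await llm.generate(
                system_prompt="You are a helpful assistant.",
                user_input="Hello!",
                context=[],  # prior turns as {"role": ..., "content": ...} dicts
            )
            print(reply)

    asyncio.run(_demo())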