""" | |
LLM Provider Interface for Flare | |
""" | |
import os | |
from abc import ABC, abstractmethod | |
from typing import Dict, List, Optional, Any | |
import httpx | |
from openai import AsyncOpenAI | |
from utils import log | |


class LLMInterface(ABC):
    """Abstract base class for LLM providers"""

    def __init__(self, settings: Optional[Dict[str, Any]] = None):
        """Initialize with provider-specific settings"""
        self.settings = settings or {}
        self.internal_prompt = self.settings.get("internal_prompt", "")
        self.parameter_collection_config = self.settings.get("parameter_collection_config", {})

    @abstractmethod
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from LLM"""
        pass

    @abstractmethod
    async def startup(self, project_config: Dict) -> bool:
        """Initialize LLM with project config"""
        pass
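

# Illustrative only, not part of Flare: the smallest possible concrete
# provider, handy as a test double when wiring up code that consumes
# LLMInterface.
class EchoLLM(LLMInterface):
    """Minimal example provider that echoes the user input back"""

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        return f"echo: {user_input}"

    async def startup(self, project_config: Dict) -> bool:
        return True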


class SparkLLM(LLMInterface):
    """Spark LLM integration"""

    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "spark", settings: Optional[Dict] = None):
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        log(f"🔌 SparkLLM initialized with endpoint: {self.spark_endpoint}, variant: {self.provider_variant}")

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from Spark"""
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }
        # Build the request payload
        payload = {
            "system_prompt": system_prompt,
            "user_input": user_input,
            "context": context
        }
        try:
            log(f"📤 Spark request to {self.spark_endpoint}/generate")
            async with httpx.AsyncClient(timeout=60) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=payload,
                    headers=headers
                )
                response.raise_for_status()
                data = response.json()
            # The answer may arrive under different field names depending on
            # the Spark variant; try them in order of preference.
            raw = data.get("model_answer", "").strip()
            if not raw:
                raw = (data.get("assistant") or data.get("text", "")).strip()
            return raw
        except httpx.TimeoutException:
            log("⏱️ Spark timeout")
            raise
        except Exception as e:
            log(f"❌ Spark error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """Send startup request to Spark"""
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }
        try:
            log(f"🚀 Sending startup to Spark for project: {project_config.get('project_name')}")
            async with httpx.AsyncClient(timeout=30) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/startup",
                    json=project_config,
                    headers=headers
                )
                response.raise_for_status()
            log("✅ Spark startup successful")
            return True
        except Exception as e:
            log(f"❌ Spark startup failed: {e}")
            return False
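

# Example usage (a minimal sketch; the endpoint, token, and project name are
# hypothetical placeholders, and the awaits must run inside an event loop):
#
#   spark = SparkLLM("https://spark.example.com", "spark-token")
#   if await spark.startup({"project_name": "demo"}):
#       answer = await spark.generate("You are a helpful assistant.", "Hello!", context=[])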


class GPT4oLLM(LLMInterface):
    """OpenAI GPT integration"""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Optional[Dict] = None):
        super().__init__(settings)
        self.api_key = api_key
        self.model = model
        self.client = AsyncOpenAI(api_key=api_key)
        # Default sampling settings; super().__init__ guarantees self.settings is a dict
        self.temperature = self.settings.get("temperature", 0.3)
        self.max_tokens = self.settings.get("max_tokens", 512)
        log(f"✅ Initialized GPT LLM with model: {model}")

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from GPT"""
        try:
            # Convert context to the OpenAI chat format, starting with the system prompt
            messages = [{"role": "system", "content": system_prompt}]
            # Add the last 10 messages of conversation history
            for msg in context[-10:]:
                role = "user" if msg["role"] == "user" else "assistant"
                messages.append({"role": role, "content": msg["content"]})
            # Add the current user input
            messages.append({"role": "user", "content": user_input})
            log(f"📤 GPT request with {len(messages)} messages")
            # Call the OpenAI API
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )
            # content can be None (e.g. on a refusal), so fall back to ""
            content = response.choices[0].message.content or ""
            log(f"✅ GPT response received: {len(content)} chars")
            return content
        except Exception as e:
            log(f"❌ GPT error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """GPT doesn't need startup - just validate the API key"""
        try:
            # Test the API key with a minimal request
            await self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=5
            )
            log("✅ GPT API key validated")
            return True
        except Exception as e:
            log(f"❌ GPT API key validation failed: {e}")
            return False
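

# Illustrative smoke test, not part of Flare: exercises GPT4oLLM end to end.
# Assumes OPENAI_API_KEY is set in the environment; the prompts are placeholders.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        llm = GPT4oLLM(api_key=os.environ["OPENAI_API_KEY"])
        if await llm.startup({}):
            reply = await llm.generate(
                system_prompt="You are a helpful assistant.",
                user_input="Say hello in one sentence.",
                context=[],
            )
            print(reply)

    asyncio.run(_demo())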