""" LLM Provider Interface for Flare """ import os from abc import ABC, abstractmethod from typing import Dict, List, Optional, Any import httpx from openai import AsyncOpenAI from utils import log class LLMInterface(ABC): """Abstract base class for LLM providers""" def __init__(self, settings: Dict[str, Any] = None): """Initialize with provider settings""" self.settings = settings or {} self.internal_prompt = self.settings.get("internal_prompt", "") self.parameter_collection_config = self.settings.get("parameter_collection_config", {}) @abstractmethod async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str: """Generate response from LLM""" pass @abstractmethod async def startup(self, project_config: Dict) -> bool: """Initialize LLM with project config""" pass class SparkLLM(LLMInterface): """Spark LLM integration""" def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Dict[str, Any] = None): super().__init__(settings) self.spark_endpoint = spark_endpoint.rstrip("/") self.spark_token = spark_token self.provider_variant = provider_variant log(f"🔌 SparkLLM initialized with endpoint: {self.spark_endpoint}") async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str: """Generate response from Spark LLM""" headers = { "Authorization": f"Bearer {self.spark_token}", "Content-Type": "application/json" } # Build payload payload = { "system_prompt": system_prompt, "user_input": user_input, "context": context } try: async with httpx.AsyncClient(timeout=60) as client: response = await client.post( f"{self.spark_endpoint}/generate", json=payload, headers=headers ) response.raise_for_status() data = response.json() # Try different response fields raw = data.get("model_answer", "").strip() if not raw: raw = (data.get("assistant") or data.get("text", "")).strip() return raw except Exception as e: log(f"❌ Spark error: {e}") raise async def startup(self, project_config: Dict) -> bool: """Send startup request to Spark""" headers = { "Authorization": f"Bearer {self.spark_token}", "Content-Type": "application/json" } # Extract required fields from project config body = { "work_mode": self.provider_variant, "cloud_token": self.spark_token, "project_name": project_config.get("name"), "project_version": project_config.get("version_id"), "repo_id": project_config.get("repo_id"), "generation_config": project_config.get("generation_config", {}), "use_fine_tune": project_config.get("use_fine_tune", False), "fine_tune_zip": project_config.get("fine_tune_zip", "") } try: async with httpx.AsyncClient(timeout=10) as client: response = await client.post( f"{self.spark_endpoint}/startup", json=body, headers=headers ) if response.status_code >= 400: log(f"❌ Spark startup failed: {response.status_code} - {response.text}") return False log(f"✅ Spark acknowledged startup ({response.status_code})") return True except Exception as e: log(f"⚠️ Spark startup error: {e}") return False class GPT4oLLM(LLMInterface): """OpenAI GPT integration""" def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Dict[str, Any] = None): super().__init__(settings) self.api_key = api_key self.model = self._map_model_name(model) self.client = AsyncOpenAI(api_key=api_key) # Extract model-specific settings self.temperature = settings.get("temperature", 0.7) if settings else 0.7 self.max_tokens = settings.get("max_tokens", 4096) if settings else 4096 log(f"✅ Initialized GPT LLM with model: {self.model}") def 


class GPT4oLLM(LLMInterface):
    """OpenAI GPT integration"""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Optional[Dict[str, Any]] = None):
        super().__init__(settings)
        self.api_key = api_key
        self.model = self._map_model_name(model)
        self.client = AsyncOpenAI(api_key=api_key)

        # Extract model-specific settings (super().__init__ already
        # normalized self.settings to a dict)
        self.temperature = self.settings.get("temperature", 0.7)
        self.max_tokens = self.settings.get("max_tokens", 4096)
        log(f"✅ Initialized GPT LLM with model: {self.model}")

    def _map_model_name(self, model: str) -> str:
        """Map a provider alias to the actual OpenAI model name"""
        mappings = {
            "gpt4o": "gpt-4o",
            "gpt4o-mini": "gpt-4o-mini",
        }
        return mappings.get(model, model)

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate a response from OpenAI"""
        try:
            # Build the message list: system prompt first, then conversation
            # history, then the current user input.
            messages = [{"role": "system", "content": system_prompt}]

            for msg in context:
                messages.append({
                    "role": msg.get("role", "user"),
                    "content": msg.get("content", ""),
                })

            messages.append({"role": "user", "content": user_input})

            # Call OpenAI
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            log(f"❌ OpenAI error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """GPT needs no startup handshake; always report ready"""
        log("✅ GPT provider ready (no startup needed)")
        return True
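

# --- Hypothetical usage sketch, not part of the original module -----------
# A minimal end-to-end example, assuming an OPENAI_API_KEY environment
# variable is set. The conversation shape (role/content dicts) matches what
# generate() expects above.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = GPT4oLLM(api_key=os.environ["OPENAI_API_KEY"], model="gpt4o-mini")
        await llm.startup({})
        context = [
            {"role": "user", "content": "Hi!"},
            {"role": "assistant", "content": "Hello! How can I help?"},
        ]
        answer = await llm.generate(
            system_prompt="You are a concise assistant.",
            user_input="Summarize what this LLM interface module does.",
            context=context,
        )
        print(answer)

    asyncio.run(_demo())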