""" Spark LLM Implementation """ import httpx from typing import Dict, List, Any from llm_interface import LLMInterface from utils import log class SparkLLM(LLMInterface): """Spark LLM integration""" def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Dict[str, Any] = None): super().__init__(settings) self.spark_endpoint = spark_endpoint.rstrip("/") self.spark_token = spark_token self.provider_variant = provider_variant log(f"🔌 SparkLLM initialized with endpoint: {self.spark_endpoint}") async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str: """Generate response from Spark LLM""" headers = { "Authorization": f"Bearer {self.spark_token}", "Content-Type": "application/json" } # Build context messages messages = [] for msg in context: messages.append({ "role": msg.get("role", "user"), "content": msg.get("content", "") }) payload = { "user_input": user_input, "system_prompt": system_prompt, "context": messages, "mode": self.provider_variant } try: async with httpx.AsyncClient(timeout=60) as client: response = await client.post( f"{self.spark_endpoint}/generate", json=payload, headers=headers ) response.raise_for_status() result = response.json() return result.get("model_answer", "") except httpx.TimeoutException: log("⏱️ Spark request timed out") raise except Exception as e: log(f"❌ Spark error: {e}") raise async def startup(self, project_config: Dict) -> bool: """Initialize Spark with project config""" try: headers = { "Authorization": f"Bearer {self.spark_token}", "Content-Type": "application/json" } # Extract version config version = None for v in project_config.get("versions", []): if v.get("published"): version = v break if not version: log("❌ No published version found") return False llm_config = version.get("llm", {}) payload = { "project_name": project_config.get("name"), "repo_id": llm_config.get("repo_id", ""), "use_fine_tune": llm_config.get("use_fine_tune", False), "fine_tune_zip": llm_config.get("fine_tune_zip", ""), "generation_config": llm_config.get("generation_config", {}) } async with httpx.AsyncClient(timeout=30) as client: response = await client.post( f"{self.spark_endpoint}/startup", json=payload, headers=headers ) response.raise_for_status() log("✅ Spark startup successful") return True except Exception as e: log(f"❌ Spark startup failed: {e}") return False def get_provider_name(self) -> str: """Get provider name""" return f"spark-{self.provider_variant}" def get_model_info(self) -> Dict[str, Any]: """Get model information""" return { "provider": "spark", "variant": self.provider_variant, "endpoint": self.spark_endpoint }