"""Spark LLM implementation."""
import json
import os
import time
from typing import Dict, List, Any, AsyncIterator, Optional

import httpx

from .llm_interface import LLMInterface
from utils.logger import log_info, log_error, log_warning, log_debug

# Get timeout from environment
DEFAULT_LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT_SECONDS", "60"))
MAX_RESPONSE_LENGTH = int(os.getenv("LLM_MAX_RESPONSE_LENGTH", "4096"))

# Number of most recent context messages forwarded with each request.
MAX_CONTEXT_MESSAGES = 10


class SparkLLM(LLMInterface):
    """Spark LLM integration with improved error handling.

    Sends a chat-style request to ``{spark_endpoint}/generate`` and returns the
    ``model_answer`` field of the JSON response.
    """

    def __init__(
        self,
        spark_endpoint: str,
        spark_token: str,
        provider_variant: str = "cloud",
        settings: Optional[Dict[str, Any]] = None,
    ):
        """Configure the Spark client.

        Args:
            spark_endpoint: Base URL of the Spark service; a trailing "/" is stripped.
            spark_token: Token sent as ``Authorization: Bearer <token>``.
            provider_variant: Sent as the request "mode" and used in the provider name.
            settings: Passed to ``LLMInterface``. The keys "timeout", "max_tokens" and
                "temperature" are read from ``self.settings``.
        """
        # NOTE(review): assumes LLMInterface.__init__ stores a dict on self.settings
        # (even when settings is None) -- confirm in llm_interface.
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        self.timeout = self.settings.get("timeout", DEFAULT_LLM_TIMEOUT)
        log_info("🔌 SparkLLM initialized", endpoint=self.spark_endpoint, timeout=self.timeout)

    @staticmethod
    def _build_messages(system_prompt: str, user_input: str, context: Optional[List[Dict]]) -> List[Dict[str, Any]]:
        """Assemble messages in order: optional system prompt, recent context, then the user turn."""
        messages: List[Dict[str, Any]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # Only the most recent messages are sent, to bound the request size.
        for msg in (context or [])[-MAX_CONTEXT_MESSAGES:]:
            messages.append({
                "role": msg.get("role", "user"),
                "content": msg.get("content", ""),
            })
        messages.append({"role": "user", "content": user_input})
        return messages

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Send a non-streaming generation request to Spark and return the answer text.

        Args:
            system_prompt: Optional system message; it is skipped when empty.
            user_input: The current user turn.
            context: Prior messages as dicts with "role"/"content" keys. Only the last
                ``MAX_CONTEXT_MESSAGES`` are sent. ``None`` is treated as empty.

        Returns:
            The response's "model_answer" text ("" if missing or null). If it is longer
            than ``MAX_RESPONSE_LENGTH``, it is truncated and "..." is appended.

        Raises:
            httpx.TimeoutException: The request exceeded ``self.timeout``.
            httpx.HTTPStatusError: Spark returned an error status (including 429).
            Exception: Any other error is logged and re-raised.
        """
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json",
        }

        payload = {
            "messages": self._build_messages(system_prompt, user_input, context),
            "mode": self.provider_variant,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7),
            "stream": False,  # For now, no streaming
        }

        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                started = time.perf_counter()
                try:
                    response = await client.post(
                        f"{self.spark_endpoint}/generate",
                        json=payload,
                        headers=headers,
                    )
                finally:
                    # Log the request duration even when the POST itself raises.
                    log_debug(
                        "Spark LLM request finished",
                        elapsed_ms=round((time.perf_counter() - started) * 1000, 1),
                    )

                # Check for rate limiting
                if response.status_code == 429:
                    retry_after = response.headers.get("Retry-After", "60")
                    log_warning("Rate limited by Spark", retry_after=retry_after)
                    raise httpx.HTTPStatusError(
                        f"Rate limited. Retry after {retry_after}s",
                        request=response.request,
                        response=response,
                    )

                response.raise_for_status()
                result = response.json()

            # Extract response; a null/missing answer becomes "" instead of crashing len().
            content = result.get("model_answer") or ""

            # Check response length
            if len(content) > MAX_RESPONSE_LENGTH:
                log_warning(
                    "Response exceeded max length, truncating",
                    original_length=len(content),
                    max_length=MAX_RESPONSE_LENGTH,
                )
                content = content[:MAX_RESPONSE_LENGTH] + "..."

            return content

        except httpx.TimeoutException:
            log_error("Spark request timed out", timeout=self.timeout)
            raise
        except httpx.HTTPStatusError as e:
            log_error(
                "Spark HTTP error",
                status_code=e.response.status_code,
                response=e.response.text[:500],
            )
            raise
        except Exception as e:
            log_error("Spark unexpected error", error=str(e))
            raise

    def get_provider_name(self) -> str:
        """Return the provider identifier, e.g. "spark-cloud"."""
        return f"spark-{self.provider_variant}"

    def get_model_info(self) -> Dict[str, Any]:
        """Return a summary of the provider, endpoint and generation settings."""
        return {
            "provider": "spark",
            "variant": self.provider_variant,
            "endpoint": self.spark_endpoint,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7),
        }