"""
Spark LLM Implementation
"""
import os
import httpx
import json
from typing import Dict, List, Any, AsyncIterator
from .llm_interface import LLMInterface
from utils.logger import log_info, log_error, log_warning, log_debug
from utils.logger import LogTimer  # assumed to be exported by utils.logger; generate() uses it below

# Read timeout and response-length limits from environment
DEFAULT_LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT_SECONDS", "60"))
MAX_RESPONSE_LENGTH = int(os.getenv("LLM_MAX_RESPONSE_LENGTH", "4096"))
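# Example environment configuration (illustrative values only; both variables
# are optional and fall back to the defaults above):
#   export LLM_TIMEOUT_SECONDS=120
#   export LLM_MAX_RESPONSE_LENGTH=8192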
class SparkLLM(LLMInterface):
    """Spark LLM integration with improved error handling"""

    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Dict[str, Any] = None):
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        self.timeout = self.settings.get("timeout", DEFAULT_LLM_TIMEOUT)
        log_info("SparkLLM initialized", endpoint=self.spark_endpoint, timeout=self.timeout)
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response with improved error handling and streaming support"""
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }

        # Build context messages
        messages = []
        if system_prompt:
            messages.append({
                "role": "system",
                "content": system_prompt
            })

        for msg in context[-10:]:  # Last 10 messages for context
            messages.append({
                "role": msg.get("role", "user"),
                "content": msg.get("content", "")
            })

        messages.append({
            "role": "user",
            "content": user_input
        })

        payload = {
            "messages": messages,
            "mode": self.provider_variant,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7),
            "stream": False  # For now, no streaming
        }
        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                with LogTimer("Spark LLM request"):
                    response = await client.post(
                        f"{self.spark_endpoint}/generate",
                        json=payload,
                        headers=headers
                    )

                # Check for rate limiting
                if response.status_code == 429:
                    retry_after = response.headers.get("Retry-After", "60")
                    log_warning("Rate limited by Spark", retry_after=retry_after)
                    raise httpx.HTTPStatusError(
                        f"Rate limited. Retry after {retry_after}s",
                        request=response.request,
                        response=response
                    )

                response.raise_for_status()
                result = response.json()

                # Extract response content
                content = result.get("model_answer", "")

                # Truncate overly long responses
                if len(content) > MAX_RESPONSE_LENGTH:
                    log_warning("Response exceeded max length, truncating",
                                original_length=len(content),
                                max_length=MAX_RESPONSE_LENGTH)
                    content = content[:MAX_RESPONSE_LENGTH] + "..."

                return content

        except httpx.TimeoutException:
            log_error("Spark request timed out", timeout=self.timeout)
            raise
        except httpx.HTTPStatusError as e:
            log_error("Spark HTTP error",
                      status_code=e.response.status_code,
                      response=e.response.text[:500])
            raise
        except Exception as e:
            log_error("Spark unexpected error", error=str(e))
            raise
    def get_provider_name(self) -> str:
        return f"spark-{self.provider_variant}"

    def get_model_info(self) -> Dict[str, Any]:
        return {
            "provider": "spark",
            "variant": self.provider_variant,
            "endpoint": self.spark_endpoint,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7)
        }
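

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It shows how a caller might
# construct SparkLLM and await generate(). The endpoint URL, SPARK_TOKEN
# variable, and settings values below are placeholders, not part of the
# implementation above; the context list simply mirrors the role/content
# dicts that generate() expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo():
        llm = SparkLLM(
            spark_endpoint="https://spark.example.com",  # placeholder endpoint
            spark_token=os.getenv("SPARK_TOKEN", ""),    # placeholder token source
            provider_variant="cloud",
            settings={"max_tokens": 1024, "temperature": 0.5, "timeout": 30},
        )
        context = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi! How can I help?"},
        ]
        try:
            answer = await llm.generate(
                system_prompt="You are a helpful assistant.",
                user_input="Summarize our conversation so far.",
                context=context,
            )
            print(answer)
        except httpx.HTTPStatusError as e:
            # 429s are re-raised by generate(); a real caller could back off and retry here
            print(f"Request failed with status {e.response.status_code}")

    asyncio.run(_demo())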