"""
Spark LLM Implementation
"""
import os
import httpx
import json
from typing import Dict, List, Any, AsyncIterator
from llm_interface import LLMInterface
from utils.logger import log_info, log_error, log_warning, log_debug
# Get timeout from environment
DEFAULT_LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT_SECONDS", "60"))
MAX_RESPONSE_LENGTH = int(os.getenv("LLM_MAX_RESPONSE_LENGTH", "4096"))


class SparkLLM(LLMInterface):
    """Spark LLM integration with improved error handling"""

    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Optional[Dict[str, Any]] = None):
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        self.timeout = self.settings.get("timeout", DEFAULT_LLM_TIMEOUT)
        log_info("SparkLLM initialized", endpoint=self.spark_endpoint, timeout=self.timeout)

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate a response with improved error handling (streaming not yet supported)"""
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }

        # Build the message list: system prompt, recent context, then the new user input
        messages = []
        if system_prompt:
            messages.append({
                "role": "system",
                "content": system_prompt
            })

        for msg in context[-10:]:  # Keep only the last 10 messages for context
            messages.append({
                "role": msg.get("role", "user"),
                "content": msg.get("content", "")
            })

        messages.append({
            "role": "user",
            "content": user_input
        })

        payload = {
            "messages": messages,
            "mode": self.provider_variant,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7),
            "stream": False  # For now, no streaming
        }

        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                with LogTimer("Spark LLM request"):
                    response = await client.post(
                        f"{self.spark_endpoint}/generate",
                        json=payload,
                        headers=headers
                    )

                # Check for rate limiting
                if response.status_code == 429:
                    retry_after = response.headers.get("Retry-After", "60")
                    log_warning("Rate limited by Spark", retry_after=retry_after)
                    raise httpx.HTTPStatusError(
                        f"Rate limited. Retry after {retry_after}s",
                        request=response.request,
                        response=response
                    )

                response.raise_for_status()
                result = response.json()

                # Extract the answer text from the response body
                content = result.get("model_answer", "")

                # Truncate responses that exceed the configured length limit
                if len(content) > MAX_RESPONSE_LENGTH:
                    log_warning("Response exceeded max length, truncating",
                                original_length=len(content),
                                max_length=MAX_RESPONSE_LENGTH)
                    content = content[:MAX_RESPONSE_LENGTH] + "..."

                return content

        except httpx.TimeoutException:
            log_error("Spark request timed out", timeout=self.timeout)
            raise
        except httpx.HTTPStatusError as e:
            log_error("Spark HTTP error",
                      status_code=e.response.status_code,
                      response=e.response.text[:500])
            raise
        except Exception as e:
            log_error("Spark unexpected error", error=str(e))
            raise

    def get_provider_name(self) -> str:
        return f"spark-{self.provider_variant}"

    def get_model_info(self) -> Dict[str, Any]:
        return {
            "provider": "spark",
            "variant": self.provider_variant,
            "endpoint": self.spark_endpoint,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7)
        }
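

# A minimal usage sketch, not part of the provider itself: the endpoint URL and
# token below are hypothetical placeholders, and it assumes LLMInterface accepts
# the settings dict as shown. It also demonstrates one way a caller might honor
# the Retry-After hint surfaced when generate() raises on HTTP 429.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = SparkLLM(
            spark_endpoint="https://spark.example.com/api",  # hypothetical endpoint
            spark_token="sk-placeholder",                    # hypothetical token
            settings={"temperature": 0.3, "max_tokens": 1024},
        )
        try:
            answer = await llm.generate(
                system_prompt="You are a helpful assistant.",
                user_input="Say hello in one sentence.",
                context=[],
            )
            print(answer)
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 429:
                # Back off for the server-requested interval before retrying
                wait_s = int(e.response.headers.get("Retry-After", "60"))
                await asyncio.sleep(wait_s)

    asyncio.run(_demo())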