flare / llm_spark.py
ciyidogan's picture
Create llm_spark.py
e0ae5ce verified
raw
history blame
3.85 kB
"""
Spark LLM Implementation
"""
from typing import Any, Dict, List, Optional

import httpx

from llm_interface import LLMInterface
from utils import log
class SparkLLM(LLMInterface):
    """Spark LLM integration over HTTP.

    Talks to a Spark inference service through its ``/generate`` and
    ``/startup`` endpoints, authenticating every request with a bearer token.
    """

    def __init__(
        self,
        spark_endpoint: str,
        spark_token: str,
        provider_variant: str = "cloud",
        settings: Optional[Dict[str, Any]] = None,
    ):
        """Store connection details and forward *settings* to the base class.

        Args:
            spark_endpoint: Base URL of the Spark service; a trailing slash
                is stripped so path joins below stay well-formed.
            spark_token: Bearer token sent in the ``Authorization`` header.
            provider_variant: Deployment mode forwarded as ``mode`` in the
                generate payload (e.g. ``"cloud"``).
            settings: Optional provider settings consumed by ``LLMInterface``.
        """
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        log(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}")

    def _auth_headers(self) -> Dict[str, str]:
        """Build the JSON + bearer-auth headers shared by all Spark calls."""
        return {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate a response from the Spark ``/generate`` endpoint.

        Args:
            system_prompt: System prompt passed through to the service.
            user_input: The user's current message.
            context: Prior conversation turns; each dict may carry ``role``
                and ``content`` keys.

        Returns:
            The service's ``model_answer`` field, or ``""`` if absent.

        Raises:
            httpx.TimeoutException: When the request exceeds 60 seconds.
            httpx.HTTPStatusError: On non-2xx responses (via raise_for_status).
        """
        # Normalize context into role/content message dicts; missing fields
        # fall back to a "user" role with empty content.
        messages = [
            {"role": msg.get("role", "user"), "content": msg.get("content", "")}
            for msg in context
        ]
        payload = {
            "user_input": user_input,
            "system_prompt": system_prompt,
            "context": messages,
            "mode": self.provider_variant
        }
        try:
            async with httpx.AsyncClient(timeout=60) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=payload,
                    headers=self._auth_headers()
                )
                response.raise_for_status()
                result = response.json()
                return result.get("model_answer", "")
        except httpx.TimeoutException:
            # Log timeouts distinctly, then let the caller decide on retries.
            log("⏱️ Spark request timed out")
            raise
        except Exception as e:
            log(f"❌ Spark error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """Initialize Spark with the published version of *project_config*.

        Returns:
            True on success; False when no published version exists or the
            startup call fails. Errors are logged, never propagated.
        """
        try:
            # Deploy the first published version; without one there is
            # nothing to start up.
            version = next(
                (v for v in project_config.get("versions", []) if v.get("published")),
                None,
            )
            if not version:
                log("❌ No published version found")
                return False

            llm_config = version.get("llm", {})
            payload = {
                "project_name": project_config.get("name"),
                "repo_id": llm_config.get("repo_id", ""),
                "use_fine_tune": llm_config.get("use_fine_tune", False),
                "fine_tune_zip": llm_config.get("fine_tune_zip", ""),
                "generation_config": llm_config.get("generation_config", {})
            }
            async with httpx.AsyncClient(timeout=30) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/startup",
                    json=payload,
                    headers=self._auth_headers()
                )
                response.raise_for_status()
                log("βœ… Spark startup successful")
                return True
        except Exception as e:
            # Best-effort contract: startup failures surface as False.
            log(f"❌ Spark startup failed: {e}")
            return False

    def get_provider_name(self) -> str:
        """Return the provider identifier, e.g. ``spark-cloud``."""
        return f"spark-{self.provider_variant}"

    def get_model_info(self) -> Dict[str, Any]:
        """Return metadata describing this provider instance."""
        return {
            "provider": "spark",
            "variant": self.provider_variant,
            "endpoint": self.spark_endpoint
        }