"""
Spark LLM Implementation
"""
from typing import Any, Dict, List, Optional

import httpx

from llm_interface import LLMInterface
from utils import log
class SparkLLM(LLMInterface):
    """Spark LLM integration.

    Talks to a Spark inference server over HTTP, authenticating every
    request with a bearer token. Endpoint paths used: /generate, /startup.
    """

    def __init__(
        self,
        spark_endpoint: str,
        spark_token: str,
        provider_variant: str = "cloud",
        settings: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            spark_endpoint: Base URL of the Spark server; a trailing slash
                is stripped so paths can be appended safely.
            spark_token: Bearer token placed in the Authorization header.
            provider_variant: Deployment variant (e.g. "cloud"); sent to
                the server as the "mode" field and echoed in the provider name.
            settings: Optional provider settings forwarded to LLMInterface.
        """
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        log(f"🚀 SparkLLM initialized with endpoint: {self.spark_endpoint}")

    def _auth_headers(self) -> Dict[str, str]:
        """Build the JSON + bearer-auth headers shared by all Spark requests."""
        return {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json",
        }

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate a response from the Spark LLM.

        Args:
            system_prompt: System prompt forwarded verbatim.
            user_input: The user's current message.
            context: Prior conversation turns as dicts; "role" defaults to
                "user" and "content" to "" when keys are missing.

        Returns:
            The "model_answer" field of the JSON response ("" if absent).

        Raises:
            httpx.TimeoutException: If the request exceeds 60 seconds.
            httpx.HTTPStatusError: On a non-2xx response.
            Exception: Any other transport/parse error is logged and re-raised.
        """
        # Normalize context entries to plain {role, content} messages.
        messages = [
            {"role": msg.get("role", "user"), "content": msg.get("content", "")}
            for msg in context
        ]
        payload = {
            "user_input": user_input,
            "system_prompt": system_prompt,
            "context": messages,
            "mode": self.provider_variant,
        }
        try:
            async with httpx.AsyncClient(timeout=60) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=payload,
                    headers=self._auth_headers(),
                )
                response.raise_for_status()
                result = response.json()
                return result.get("model_answer", "")
        except httpx.TimeoutException:
            log("⏱️ Spark request timed out")
            raise
        except Exception as e:
            log(f"❌ Spark error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """Initialize Spark with a project's published version config.

        Picks the first version flagged "published", extracts its "llm"
        section, and POSTs the resulting config to /startup.

        Args:
            project_config: Project dict with a "name" and a "versions" list.

        Returns:
            True on HTTP success; False when no published version exists or
            any error occurs (errors are logged, never raised).
        """
        try:
            # Only a published version may be started; drafts are skipped.
            version = next(
                (v for v in project_config.get("versions", []) if v.get("published")),
                None,
            )
            if not version:
                log("❌ No published version found")
                return False
            llm_config = version.get("llm", {})
            payload = {
                "project_name": project_config.get("name"),
                "repo_id": llm_config.get("repo_id", ""),
                "use_fine_tune": llm_config.get("use_fine_tune", False),
                "fine_tune_zip": llm_config.get("fine_tune_zip", ""),
                "generation_config": llm_config.get("generation_config", {}),
            }
            async with httpx.AsyncClient(timeout=30) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/startup",
                    json=payload,
                    headers=self._auth_headers(),
                )
                response.raise_for_status()
                log("✅ Spark startup successful")
                return True
        except Exception as e:
            log(f"❌ Spark startup failed: {e}")
            return False

    def get_provider_name(self) -> str:
        """Return the provider identifier, e.g. "spark-cloud"."""
        return f"spark-{self.provider_variant}"

    def get_model_info(self) -> Dict[str, Any]:
        """Return provider/variant/endpoint metadata for diagnostics."""
        return {
            "provider": "spark",
            "variant": self.provider_variant,
            "endpoint": self.spark_endpoint,
        }