# Retrieved from a Hugging Face Spaces file view (Space status: Building).
# File size: 3,854 bytes; commit e0ae5ce.
"""
Spark LLM Implementation
"""
from typing import Any, Dict, List, Optional

import httpx

from llm_interface import LLMInterface
from utils import log
class SparkLLM(LLMInterface):
    """LLM provider that delegates generation to a remote Spark HTTP service.

    Requests go to ``{spark_endpoint}/generate`` and ``{spark_endpoint}/startup``
    and are authenticated with a bearer token.
    """

    # Per-request timeouts in seconds, passed to ``httpx.AsyncClient``.
    # Subclasses or instances may override them.
    GENERATE_TIMEOUT = 60
    STARTUP_TIMEOUT = 30

    def __init__(
        self,
        spark_endpoint: str,
        spark_token: str,
        provider_variant: str = "cloud",
        settings: Optional[Dict[str, Any]] = None,
    ):
        """Store the connection details for the Spark service.

        Args:
            spark_endpoint: Base URL of the Spark service. Trailing slashes are
                stripped so request paths can be appended with a single "/".
            spark_token: Bearer token sent in the ``Authorization`` header.
            provider_variant: Sent as ``mode`` in generate requests and used to
                build the provider name (``spark-<variant>``).
            settings: Passed through unchanged to the base-class constructor.
        """
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        # NOTE(review): the prefix character of this message looks mis-encoded
        # in the source; the intended emoji could not be recovered reliably.
        log(f"π SparkLLM initialized with endpoint: {self.spark_endpoint}")

    def _build_headers(self) -> Dict[str, str]:
        """Return the JSON content-type and bearer-auth headers for Spark calls."""
        return {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json",
        }

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Send one generation request to Spark and return its answer.

        Args:
            system_prompt: System prompt, forwarded verbatim.
            user_input: The latest user message.
            context: Prior conversation turns. Only each item's ``role``
                (default ``"user"``) and ``content`` (default ``""``) are
                forwarded. ``None`` is treated as an empty history.

        Returns:
            The ``model_answer`` field of the JSON response, or ``""`` if absent.

        Raises:
            httpx.TimeoutException: The request exceeded ``GENERATE_TIMEOUT``.
            Exception: Any other failure (e.g. ``httpx.HTTPStatusError`` from a
                4xx/5xx response) is logged and re-raised unchanged.
        """
        # Forward only the fields Spark understands, with safe defaults.
        messages = [
            {"role": msg.get("role", "user"), "content": msg.get("content", "")}
            for msg in (context or [])
        ]
        payload = {
            "user_input": user_input,
            "system_prompt": system_prompt,
            "context": messages,
            "mode": self.provider_variant,
        }
        try:
            async with httpx.AsyncClient(timeout=self.GENERATE_TIMEOUT) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=payload,
                    headers=self._build_headers(),
                )
                response.raise_for_status()
                result = response.json()
                return result.get("model_answer", "")
        except httpx.TimeoutException:
            log("⏱️ Spark request timed out")
            raise
        except Exception as e:
            log(f"❌ Spark error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """Ask Spark to load the model for the project's published version.

        The first entry in ``project_config["versions"]`` with a truthy
        ``published`` flag is selected. Its ``llm`` section supplies the
        repository, fine-tune, and generation settings sent to Spark.

        Returns:
            True if Spark accepted the request. False if there is no published
            version or any error occurred; errors are logged, not raised.
        """
        try:
            version = next(
                (v for v in project_config.get("versions", []) if v.get("published")),
                None,
            )
            if version is None:
                log("❌ No published version found")
                return False

            llm_config = version.get("llm", {})
            payload = {
                "project_name": project_config.get("name"),
                "repo_id": llm_config.get("repo_id", ""),
                "use_fine_tune": llm_config.get("use_fine_tune", False),
                "fine_tune_zip": llm_config.get("fine_tune_zip", ""),
                "generation_config": llm_config.get("generation_config", {}),
            }

            async with httpx.AsyncClient(timeout=self.STARTUP_TIMEOUT) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/startup",
                    json=payload,
                    headers=self._build_headers(),
                )
                response.raise_for_status()

            log("✅ Spark startup successful")
            return True
        except Exception as e:
            # Best-effort: report failure to the caller instead of raising.
            log(f"❌ Spark startup failed: {e}")
            return False

    def get_provider_name(self) -> str:
        """Return the provider identifier, e.g. ``"spark-cloud"``."""
        return f"spark-{self.provider_variant}"

    def get_model_info(self) -> Dict[str, Any]:
        """Return the provider name, variant, and normalized endpoint."""
        return {
            "provider": "spark",
            "variant": self.provider_variant,
            "endpoint": self.spark_endpoint,
        }