ciyidogan committed on
Commit e0ae5ce · verified · 1 Parent(s): a71c268

Create llm_spark.py

Files changed (1)
  1. llm_spark.py +109 -0
llm_spark.py ADDED
@@ -0,0 +1,109 @@
+"""
+Spark LLM Implementation
+"""
+import httpx
+from typing import Dict, List, Any
+from llm_interface import LLMInterface
+from utils import log
+
+class SparkLLM(LLMInterface):
+    """Spark LLM integration"""
+
+    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Dict[str, Any] = None):
+        super().__init__(settings)
+        self.spark_endpoint = spark_endpoint.rstrip("/")
+        self.spark_token = spark_token
+        self.provider_variant = provider_variant
+        log(f"🔌 SparkLLM initialized with endpoint: {self.spark_endpoint}")
+
+    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
+        """Generate response from Spark LLM"""
+        headers = {
+            "Authorization": f"Bearer {self.spark_token}",
+            "Content-Type": "application/json"
+        }
+
+        # Build context messages
+        messages = []
+        for msg in context:
+            messages.append({
+                "role": msg.get("role", "user"),
+                "content": msg.get("content", "")
+            })
+
+        payload = {
+            "user_input": user_input,
+            "system_prompt": system_prompt,
+            "context": messages,
+            "mode": self.provider_variant
+        }
+
+        try:
+            async with httpx.AsyncClient(timeout=60) as client:
+                response = await client.post(
+                    f"{self.spark_endpoint}/generate",
+                    json=payload,
+                    headers=headers
+                )
+                response.raise_for_status()
+                result = response.json()
+                return result.get("model_answer", "")
+        except httpx.TimeoutException:
+            log("⏱️ Spark request timed out")
+            raise
+        except Exception as e:
+            log(f"❌ Spark error: {e}")
+            raise
+
+    async def startup(self, project_config: Dict) -> bool:
+        """Initialize Spark with project config"""
+        try:
+            headers = {
+                "Authorization": f"Bearer {self.spark_token}",
+                "Content-Type": "application/json"
+            }
+
+            # Extract version config
+            version = None
+            for v in project_config.get("versions", []):
+                if v.get("published"):
+                    version = v
+                    break
+
+            if not version:
+                log("❌ No published version found")
+                return False
+
+            llm_config = version.get("llm", {})
+            payload = {
+                "project_name": project_config.get("name"),
+                "repo_id": llm_config.get("repo_id", ""),
+                "use_fine_tune": llm_config.get("use_fine_tune", False),
+                "fine_tune_zip": llm_config.get("fine_tune_zip", ""),
+                "generation_config": llm_config.get("generation_config", {})
+            }
+
+            async with httpx.AsyncClient(timeout=30) as client:
+                response = await client.post(
+                    f"{self.spark_endpoint}/startup",
+                    json=payload,
+                    headers=headers
+                )
+                response.raise_for_status()
+                log("✅ Spark startup successful")
+                return True
+        except Exception as e:
+            log(f"❌ Spark startup failed: {e}")
+            return False
+
+    def get_provider_name(self) -> str:
+        """Get provider name"""
+        return f"spark-{self.provider_variant}"
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """Get model information"""
+        return {
+            "provider": "spark",
+            "variant": self.provider_variant,
+            "endpoint": self.spark_endpoint
+        }
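For reference, a minimal usage sketch of the new SparkLLM class (not part of the commit) is shown below. It assumes the repository's llm_interface and utils modules are importable and that a Spark service is reachable at the given endpoint; the endpoint, token, and project config values are placeholders, and the project config shape is inferred from what startup() reads.

import asyncio

from llm_spark import SparkLLM


async def main():
    # Placeholder endpoint and token; a running Spark service is assumed.
    llm = SparkLLM(
        spark_endpoint="https://spark.example.com/",
        spark_token="hf_xxx",
        provider_variant="cloud",
    )

    # Hypothetical project config: one published version with an "llm" block,
    # matching the fields startup() extracts.
    project_config = {
        "name": "demo-project",
        "versions": [
            {"published": True, "llm": {"repo_id": "org/model", "generation_config": {}}},
        ],
    }

    if await llm.startup(project_config):
        answer = await llm.generate(
            system_prompt="You are a helpful assistant.",
            user_input="Hello!",
            context=[{"role": "user", "content": "Hi"}],
        )
        print(llm.get_provider_name(), answer)


asyncio.run(main())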