ciyidogan commited on
Commit
5dba816
·
verified ·
1 Parent(s): e5d2ec0

Update llm_spark.py

Browse files
Files changed (1) hide show
  1. llm_spark.py +71 -64
llm_spark.py CHANGED
@@ -1,23 +1,30 @@
1
  """
2
  Spark LLM Implementation
3
  """
 
4
  import httpx
5
- from typing import Dict, List, Any
 
6
  from llm_interface import LLMInterface
7
  from logger import log_info, log_error, log_warning, log_debug
8
 
 
 
 
 
9
class SparkLLM(LLMInterface):
    """Spark LLM integration"""

    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Dict[str, Any] = None):
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        log_info(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}")

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from Spark LLM"""
        request_headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json",
        }

        # Normalise the caller-supplied history into role/content pairs.
        history = [
            {"role": entry.get("role", "user"), "content": entry.get("content", "")}
            for entry in context
        ]

        request_body = {
            "user_input": user_input,
            "system_prompt": system_prompt,
            "context": history,
            "mode": self.provider_variant,
        }

        try:
            async with httpx.AsyncClient(timeout=60) as client:
                reply = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=request_body,
                    headers=request_headers,
                )
                reply.raise_for_status()
                # Spark returns the generated text under "model_answer".
                return reply.json().get("model_answer", "")
        except httpx.TimeoutException:
            log_warning("⏱️ Spark request timed out")
            raise
        except Exception as e:
            log_error("❌ Spark error", e)
            raise

    async def startup(self, project_config: Dict) -> bool:
        """Initialize Spark with project config"""
        try:
            request_headers = {
                "Authorization": f"Bearer {self.spark_token}",
                "Content-Type": "application/json",
            }

            # The first published version drives the Spark configuration.
            version = next(
                (v for v in project_config.get("versions", []) if v.get("published")),
                None,
            )
            if version is None:
                log_info("❌ No published version found")
                return False

            llm_config = version.get("llm", {})
            request_body = {
                "project_name": project_config.get("name"),
                "repo_id": llm_config.get("repo_id", ""),
                "use_fine_tune": llm_config.get("use_fine_tune", False),
                "fine_tune_zip": llm_config.get("fine_tune_zip", ""),
                "generation_config": llm_config.get("generation_config", {}),
            }

            async with httpx.AsyncClient(timeout=30) as client:
                reply = await client.post(
                    f"{self.spark_endpoint}/startup",
                    json=request_body,
                    headers=request_headers,
                )
                reply.raise_for_status()
                log_info("βœ… Spark startup successful")
                return True
        except Exception as e:
            # Best-effort startup: report failure instead of propagating.
            log_error("❌ Spark startup failed", e)
            return False

    def get_provider_name(self) -> str:
        """Get provider name"""
        return f"spark-{self.provider_variant}"

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information"""
        return {
            "provider": "spark",
            "variant": self.provider_variant,
            "endpoint": self.spark_endpoint,
        }
 
1
  """
2
  Spark LLM Implementation
3
  """
4
import json
import os
import time

import httpx

from llm_interface import LLMInterface
from logger import log_info, log_error, log_warning, log_debug
10
 
11
# Operational limits, overridable via environment variables.
DEFAULT_LLM_TIMEOUT = int(os.environ.get("LLM_TIMEOUT_SECONDS", "60"))
MAX_RESPONSE_LENGTH = int(os.environ.get("LLM_MAX_RESPONSE_LENGTH", "4096"))
15
class SparkLLM(LLMInterface):
    """Spark LLM integration with improved error handling"""

    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Dict[str, Any] = None):
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        # Per-instance timeout: settings override, else env-derived default.
        # NOTE(review): assumes LLMInterface.__init__ stores a dict on
        # self.settings even when settings is None — confirm in llm_interface.
        self.timeout = self.settings.get("timeout", DEFAULT_LLM_TIMEOUT)
        log_info(f"πŸ”Œ SparkLLM initialized", endpoint=self.spark_endpoint, timeout=self.timeout)

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response with improved error handling and streaming support

        Builds an OpenAI-style messages list (optional system prompt, last 10
        context turns, then the new user input), POSTs it to the Spark
        /generate endpoint, and returns the "model_answer" field, truncated
        to MAX_RESPONSE_LENGTH characters.

        Raises:
            httpx.TimeoutException: request exceeded self.timeout.
            httpx.HTTPStatusError: non-2xx response, including explicit 429
                rate-limit handling.
            Exception: anything else, logged and re-raised.
        """
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }

        # Build context messages
        messages = []
        if system_prompt:
            messages.append({
                "role": "system",
                "content": system_prompt
            })

        for msg in context[-10:]:  # Last 10 messages for context
            messages.append({
                "role": msg.get("role", "user"),
                "content": msg.get("content", "")
            })

        messages.append({
            "role": "user",
            "content": user_input
        })

        payload = {
            "messages": messages,
            "mode": self.provider_variant,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7),
            "stream": False  # For now, no streaming
        }

        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                # BUGFIX: the original wrapped this call in `with LogTimer(...)`,
                # but LogTimer is never imported anywhere in this file, so every
                # call raised NameError. Time the request manually instead.
                started = time.perf_counter()
                response = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=payload,
                    headers=headers
                )
                log_debug("Spark LLM request", elapsed_seconds=round(time.perf_counter() - started, 3))

                # Check for rate limiting
                if response.status_code == 429:
                    retry_after = response.headers.get("Retry-After", "60")
                    log_warning(f"Rate limited by Spark", retry_after=retry_after)
                    raise httpx.HTTPStatusError(
                        f"Rate limited. Retry after {retry_after}s",
                        request=response.request,
                        response=response
                    )

                response.raise_for_status()
                result = response.json()

                # Extract response
                content = result.get("model_answer", "")

                # Check response length
                if len(content) > MAX_RESPONSE_LENGTH:
                    log_warning(f"Response exceeded max length, truncating",
                                original_length=len(content),
                                max_length=MAX_RESPONSE_LENGTH)
                    content = content[:MAX_RESPONSE_LENGTH] + "..."

                return content

        except httpx.TimeoutException:
            log_error(f"Spark request timed out", timeout=self.timeout)
            raise
        except httpx.HTTPStatusError as e:
            log_error(f"Spark HTTP error",
                      status_code=e.response.status_code,
                      response=e.response.text[:500])
            raise
        except Exception as e:
            log_error("Spark unexpected error", error=str(e))
            raise

    def get_provider_name(self) -> str:
        """Return the provider identifier, e.g. 'spark-cloud'."""
        return f"spark-{self.provider_variant}"

    def get_model_info(self) -> Dict[str, Any]:
        """Return a metadata dict describing this provider instance."""
        return {
            "provider": "spark",
            "variant": self.provider_variant,
            "endpoint": self.spark_endpoint,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7)
        }