"""

Spark LLM Implementation

"""
import os
from typing import Any, Dict, List, Optional

import httpx

from .llm_interface import LLMInterface
from utils.logger import log_info, log_error, log_warning, LogTimer  # LogTimer assumed exported alongside the log_* helpers

# Defaults, overridable via environment variables
DEFAULT_LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT_SECONDS", "60"))
MAX_RESPONSE_LENGTH = int(os.getenv("LLM_MAX_RESPONSE_LENGTH", "4096"))
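
# Illustrative shell configuration for the two knobs above (values are examples,
# not recommendations):
#   export LLM_TIMEOUT_SECONDS=120
#   export LLM_MAX_RESPONSE_LENGTH=8192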

class SparkLLM(LLMInterface):
    """Spark LLM integration with improved error handling"""

    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Optional[Dict[str, Any]] = None):
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        self.timeout = self.settings.get("timeout", DEFAULT_LLM_TIMEOUT)
        log_info("πŸ”Œ SparkLLM initialized", endpoint=self.spark_endpoint, timeout=self.timeout)

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response with improved error handling and streaming support"""
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }

        # Build context messages
        messages = []
        if system_prompt:
            messages.append({
                "role": "system",
                "content": system_prompt
            })

        for msg in context[-10:]:  # Keep only the 10 most recent messages to bound prompt size
            messages.append({
                "role": msg.get("role", "user"),
                "content": msg.get("content", "")
            })

        messages.append({
            "role": "user",
            "content": user_input
        })

        payload = {
            "messages": messages,
            "mode": self.provider_variant,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7),
            "stream": False  # For now, no streaming
        }
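
        # Wire contract assumed by this client (inferred from this module, not a
        # published Spark spec): POST {spark_endpoint}/generate with the payload
        # above; the JSON reply carries the generated text under "model_answer".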

        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                with LogTimer("Spark LLM request"):
                    response = await client.post(
                        f"{self.spark_endpoint}/generate",
                        json=payload,
                        headers=headers
                    )

                    # Check for rate limiting
                    if response.status_code == 429:
                        retry_after = response.headers.get("Retry-After", "60")
                        log_warning("Rate limited by Spark", retry_after=retry_after)
                        raise httpx.HTTPStatusError(
                            f"Rate limited. Retry after {retry_after}s",
                            request=response.request,
                            response=response
                        )

                    response.raise_for_status()
                    result = response.json()

                    # Extract response
                    content = result.get("model_answer", "")

                    # Check response length
                    if len(content) > MAX_RESPONSE_LENGTH:
                        log_warning("Response exceeded max length, truncating",
                                    original_length=len(content),
                                    max_length=MAX_RESPONSE_LENGTH)
                        content = content[:MAX_RESPONSE_LENGTH] + "..."

                    return content

        except httpx.TimeoutException:
            log_error("Spark request timed out", timeout=self.timeout)
            raise
        except httpx.HTTPStatusError as e:
            log_error("Spark HTTP error",
                      status_code=e.response.status_code,
                      response=e.response.text[:500])
            raise
        except Exception as e:
            log_error("Spark unexpected error", error=str(e))
            raise

    def get_provider_name(self) -> str:
        return f"spark-{self.provider_variant}"

    def get_model_info(self) -> Dict[str, Any]:
        return {
            "provider": "spark",
            "variant": self.provider_variant,
            "endpoint": self.spark_endpoint,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7)
        }
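

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal async driver, assuming SPARK_ENDPOINT / SPARK_TOKEN point at a
# reachable Spark server and that LLMInterface.__init__ stores `settings` on
# the instance. The except branch shows one way a caller might honor the
# Retry-After contract that generate() re-raises on HTTP 429.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = SparkLLM(
            spark_endpoint=os.getenv("SPARK_ENDPOINT", "http://localhost:8080"),
            spark_token=os.getenv("SPARK_TOKEN", ""),
            settings={"temperature": 0.3, "max_tokens": 1024},
        )
        try:
            answer = await llm.generate(
                system_prompt="You are a helpful assistant.",
                user_input="Hello!",
                context=[],
            )
        except httpx.HTTPStatusError as exc:
            if exc.response.status_code != 429:
                raise
            # Rate limited: wait out Retry-After once, then retry the call.
            await asyncio.sleep(int(exc.response.headers.get("Retry-After", "60")))
            answer = await llm.generate(
                system_prompt="You are a helpful assistant.",
                user_input="Hello!",
                context=[],
            )
        print(llm.get_provider_name(), "->", answer)

    asyncio.run(_demo())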