"""
LLM Provider Interface for Flare
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any
import httpx
from openai import AsyncOpenAI
from utils import log

class LLMInterface(ABC):
    """Abstract base class for LLM providers"""
    
    def __init__(self, settings: Optional[Dict[str, Any]] = None):
        """Initialize with settings"""
        self.settings = settings or {}
        self.internal_prompt = self.settings.get("internal_prompt", "")
        self.parameter_collection_config = self.settings.get("parameter_collection_config", {})
    
    @abstractmethod
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from LLM"""
        pass
    
    @abstractmethod
    async def startup(self, project_config: Dict) -> bool:
        """Initialize LLM with project config"""
        pass

class SparkLLM(LLMInterface):
    """Spark integration for HuggingFace"""
    
    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "spark-cloud", settings: Optional[Dict[str, Any]] = None):
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        log(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}")
    
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response using Spark"""
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }
        
        payload = {
            "system_prompt": system_prompt,
            "user_input": user_input,
            "context": context
        }
        
        try:
            async with httpx.AsyncClient(timeout=60) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=payload,
                    headers=headers
                )
                response.raise_for_status()
                data = response.json()
                
                # Try different response fields
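                # (field names assumed from observed responses, not a published schema)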
                raw = data.get("model_answer", "").strip()
                if not raw:
                    raw = (data.get("assistant") or data.get("text", "")).strip()
                
                return raw
        except Exception as e:
            log(f"❌ Spark error: {e}")
            raise
    
    async def startup(self, project_config: Dict) -> bool:
        """Send startup request to Spark"""
        # Implement if needed for Spark startup notification
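        # A minimal sketch of what this might look like, assuming the Spark
        # service exposes a /startup endpoint (hypothetical, not confirmed):
        #
        #     async with httpx.AsyncClient(timeout=10) as client:
        #         resp = await client.post(
        #             f"{self.spark_endpoint}/startup",
        #             json=project_config,
        #             headers={"Authorization": f"Bearer {self.spark_token}"},
        #         )
        #         return resp.status_code == 200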
        return True

class GPT4oLLM(LLMInterface):
    """OpenAI GPT integration"""
    
    def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Optional[Dict[str, Any]] = None):
        super().__init__(settings)
        self.api_key = api_key
        self.model = model
        self.client = AsyncOpenAI(api_key=api_key)
        
        # Extract generation settings; super().__init__ guarantees self.settings is a dict
        self.temperature = self.settings.get("temperature", 0.7)
        self.max_tokens = self.settings.get("max_tokens", 4096)
        
        log(f"βœ… Initialized GPT LLM with model: {model}")
    
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response using OpenAI GPT"""
        try:
            # Build messages
            messages = [{"role": "system", "content": system_prompt}]
            
            # Add context
            for msg in context:
                messages.append({
                    "role": msg.get("role", "user"),
                    "content": msg.get("content", "")
                })
            
            # Add current user input
            messages.append({"role": "user", "content": user_input})
            
            # Generate response
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )
            
            return response.choices[0].message.content.strip()
            
        except Exception as e:
            log(f"❌ GPT error: {e}")
            raise
    
    async def startup(self, project_config: Dict) -> bool:
        """Validate API key"""
        try:
            # Test API key with a simple request
            response = await self.client.models.list()
            log(f"βœ… OpenAI API key validated, available models: {len(response.data)}")
            return True
        except Exception as e:
            log(f"❌ Invalid OpenAI API key: {e}")
            return False
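
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): wiring up a provider and generating a
# reply. The environment variable, model name, and prompts below are
# placeholder assumptions; this block is a demo, not part of the library API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio
    import os

    async def _demo() -> None:
        llm = GPT4oLLM(
            api_key=os.environ["OPENAI_API_KEY"],  # assumed env var
            model="gpt-4o-mini",
            settings={"temperature": 0.3, "max_tokens": 512},
        )
        if await llm.startup({}):
            answer = await llm.generate(
                system_prompt="You are a helpful assistant.",
                user_input="Say hello in one sentence.",
                context=[],  # prior turns: [{"role": ..., "content": ...}]
            )
            log(answer)

    asyncio.run(_demo())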