"""
LLM Provider Interface for Flare
"""
from abc import ABC, abstractmethod
from typing import Dict, List
import httpx
from openai import AsyncOpenAI
from utils import log

class LLMInterface(ABC):
    """Abstract base class for LLM providers"""
    
    @abstractmethod
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from LLM"""
        pass
    
    @abstractmethod
    async def startup(self, project_config: Dict) -> bool:
        """Initialize LLM with project config"""
        pass

class SparkLLM(LLMInterface):
    """Existing Spark integration"""
    
    def __init__(self, spark_endpoint: str, spark_token: str):
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
    
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }
        
        payload = {
            "system_prompt": system_prompt,
            "user_input": user_input,
            "context": context
        }
        
        try:
            async with httpx.AsyncClient(timeout=60) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=payload,
                    headers=headers
                )
                response.raise_for_status()
                data = response.json()
                
                # Spark responses have used several field names; check them in order,
                # guarding against keys that are present but set to None
                raw = (data.get("model_answer") or "").strip()
                if not raw:
                    raw = (data.get("assistant") or data.get("text") or "").strip()
                
                return raw
        except Exception as e:
            log(f"❌ Spark error: {e}")
            raise
    
    async def startup(self, project_config: Dict) -> bool:
        """Send startup request to Spark"""
        # The actual startup handshake lives in the existing Spark integration;
        # this stub assumes success so the interface contract is satisfied.
        return True
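
    # A minimal sketch (assumption) of what a concrete Spark startup call
    # might look like. The "/startup" path and payload shape below are
    # hypothetical and not confirmed by this codebase, so the sketch is
    # left commented out rather than wired in.
    #
    # async def startup(self, project_config: Dict) -> bool:
    #     try:
    #         async with httpx.AsyncClient(timeout=30) as client:
    #             response = await client.post(
    #                 f"{self.spark_endpoint}/startup",
    #                 json=project_config,
    #                 headers={"Authorization": f"Bearer {self.spark_token}"},
    #             )
    #             response.raise_for_status()
    #             return True
    #     except Exception as e:
    #         log(f"❌ Spark startup error: {e}")
    #         return False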

class GPT4oLLM(LLMInterface):
    """OpenAI GPT integration"""
    
    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        self.api_key = api_key
        self.model = model
        self.client = AsyncOpenAI(api_key=api_key)
        log(f"βœ… Initialized GPT LLM with model: {model}")
    
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        try:
            # Convert context to OpenAI format
            messages = []
            
            # Add system prompt
            messages.append({"role": "system", "content": system_prompt})
            
            # Add conversation history
            for msg in context[-10:]:  # Last 10 messages
                role = "user" if msg["role"] == "user" else "assistant"
                messages.append({"role": role, "content": msg["content"]})
            
            # Add current user input
            messages.append({"role": "user", "content": user_input})
            
            # Call OpenAI API
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.3,  # Low temperature for consistency
                max_tokens=512
            )
            
            # Guard against a None content field (e.g. tool-call responses)
            content = (response.choices[0].message.content or "").strip()
            log(f"🪄 GPT response (first 120 chars): {content[:120]}")
            
            # Log token usage for cost tracking
            if response.usage:
                log(f"πŸ“Š Tokens used - Input: {response.usage.prompt_tokens}, Output: {response.usage.completion_tokens}")
            
            return content
            
        except Exception as e:
            log(f"❌ GPT error: {e}")
            raise
    
    async def startup(self, project_config: Dict) -> bool:
        """Validate API key"""
        try:
            # Test API key with a simple request
            test_response = await self.client.models.list()
            log(f"βœ… OpenAI API key validated, available models: {len(test_response.data)}")
            return True
        except Exception as e:
            log(f"❌ Invalid OpenAI API key: {e}")
            return False
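
# --- Usage sketch (assumption) ---
# Minimal demo of driving a provider through the shared interface when the
# module is run directly. The OPENAI_API_KEY environment variable name and
# the sample prompt are illustrative, not part of this module's contract.
if __name__ == "__main__":
    import asyncio
    import os

    async def _demo() -> None:
        # Any LLMInterface subclass can be swapped in here
        llm: LLMInterface = GPT4oLLM(api_key=os.environ["OPENAI_API_KEY"])
        if await llm.startup({}):
            answer = await llm.generate(
                system_prompt="You are a helpful assistant.",
                user_input="Hello!",
                context=[],
            )
            print(answer)

    asyncio.run(_demo())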