ciyidogan committed
Commit ab7e98d · verified · 1 Parent(s): 536e6ec

Create llm_interface.py

Files changed (1)
  1. llm_interface.py +123 -0
llm_interface.py ADDED
@@ -0,0 +1,123 @@
+ """
+ LLM Provider Interface for Flare
+ """
+ import os
+ from abc import ABC, abstractmethod
+ from typing import Dict, List, Optional
+ import httpx
+ from openai import AsyncOpenAI
+ from utils import log
+
+ class LLMInterface(ABC):
+     """Abstract base class for LLM providers"""
+
+     @abstractmethod
+     async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
+         """Generate response from LLM"""
+         pass
+
+     @abstractmethod
+     async def startup(self, project_config: Dict) -> bool:
+         """Initialize LLM with project config"""
+         pass
+
+ class SparkLLM(LLMInterface):
+     """Existing Spark integration"""
+
+     def __init__(self, spark_endpoint: str, spark_token: str):
+         self.spark_endpoint = spark_endpoint.rstrip("/")
+         self.spark_token = spark_token
+
+     async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
+         headers = {
+             "Authorization": f"Bearer {self.spark_token}",
+             "Content-Type": "application/json"
+         }
+
+         payload = {
+             "system_prompt": system_prompt,
+             "user_input": user_input,
+             "context": context
+         }
+
+         try:
+             async with httpx.AsyncClient(timeout=60) as client:
+                 response = await client.post(
+                     f"{self.spark_endpoint}/generate",
+                     json=payload,
+                     headers=headers
+                 )
+                 response.raise_for_status()
+                 data = response.json()
+
+                 # Try different response fields
+                 raw = data.get("model_answer", "").strip()
+                 if not raw:
+                     raw = (data.get("assistant") or data.get("text", "")).strip()
+
+                 return raw
+         except Exception as e:
+             log(f"❌ Spark error: {e}")
+             raise
+
+     async def startup(self, project_config: Dict) -> bool:
+         """Send startup request to Spark"""
+         # Existing Spark startup logic
+         return True
+
+ class GPT4oLLM(LLMInterface):
+     """OpenAI GPT integration"""
+
+     def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
+         self.api_key = api_key
+         self.model = model
+         self.client = AsyncOpenAI(api_key=api_key)
+         log(f"✅ Initialized GPT LLM with model: {model}")
+
+     async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
+         try:
+             # Convert context to OpenAI format
+             messages = []
+
+             # Add system prompt
+             messages.append({"role": "system", "content": system_prompt})
+
+             # Add conversation history
+             for msg in context[-10:]:  # Last 10 messages
+                 role = "user" if msg["role"] == "user" else "assistant"
+                 messages.append({"role": role, "content": msg["content"]})
+
+             # Add current user input
+             messages.append({"role": "user", "content": user_input})
+
+             # Call OpenAI API
+             response = await self.client.chat.completions.create(
+                 model=self.model,
+                 messages=messages,
+                 temperature=0.3,  # Low temperature for consistency
+                 max_tokens=512
+             )
+
+             content = response.choices[0].message.content.strip()
+             log(f"🪄 GPT response (first 120 chars): {content[:120]}")
+
+             # Log token usage for cost tracking
+             if response.usage:
+                 log(f"📊 Tokens used - Input: {response.usage.prompt_tokens}, Output: {response.usage.completion_tokens}")
+
+             return content
+
+         except Exception as e:
+             log(f"❌ GPT error: {e}")
+             raise
+
+     async def startup(self, project_config: Dict) -> bool:
+         """Validate API key"""
+         try:
+             # Test API key with a simple request
+             test_response = await self.client.models.list()
+             log(f"✅ OpenAI API key validated, available models: {len(test_response.data)}")
+             return True
+         except Exception as e:
+             log(f"❌ Invalid OpenAI API key: {e}")
+             return False
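
Taken together, the commit defines one LLMInterface contract with two interchangeable backends. A minimal usage sketch follows, assuming a hypothetical create_llm factory and a placeholder API key (neither is part of this commit); note that llm_interface.py also imports log from a local utils module, which must exist for the file to load.

    import asyncio
    from llm_interface import LLMInterface, SparkLLM, GPT4oLLM

    def create_llm(provider: str, **kwargs) -> LLMInterface:
        """Hypothetical factory: pick a backend by name (illustration only)."""
        if provider == "spark":
            return SparkLLM(spark_endpoint=kwargs["endpoint"], spark_token=kwargs["token"])
        if provider == "gpt4o":
            return GPT4oLLM(api_key=kwargs["api_key"], model=kwargs.get("model", "gpt-4o-mini"))
        raise ValueError(f"Unknown provider: {provider}")

    async def main() -> None:
        # Placeholder API key for illustration; startup() validates it against the API
        llm = create_llm("gpt4o", api_key="sk-...")
        if not await llm.startup(project_config={}):
            raise RuntimeError("LLM startup failed")

        # generate() expects history as {"role": ..., "content": ...} dicts,
        # matching the context loop in GPT4oLLM.generate
        reply = await llm.generate(
            system_prompt="You are a helpful assistant.",
            user_input="What providers does Flare support?",
            context=[
                {"role": "user", "content": "Hi"},
                {"role": "assistant", "content": "Hello!"},
            ],
        )
        print(reply)

    asyncio.run(main())

Because both classes satisfy the same abstract methods, swapping Spark for GPT (or adding a new provider) only changes the construction step, not the call sites.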