Spaces:
Sleeping
Sleeping
llms
Browse files
llms.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from vertexai.generative_models import GenerativeModel
|
3 |
+
from openai import OpenAI
|
4 |
+
import json
|
5 |
+
from typing import Dict, List, Optional, Union
|
6 |
+
|
7 |
+
class LLMProvider(ABC):
|
8 |
+
@abstractmethod
|
9 |
+
def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
|
10 |
+
pass
|
11 |
+
|
12 |
+
class GeminiProvider(LLMProvider):
|
13 |
+
def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
|
14 |
+
model_name = kwargs.get('model', 'gemini-pro')
|
15 |
+
|
16 |
+
model = GenerativeModel(model_name=model_name)
|
17 |
+
|
18 |
+
content = prompt
|
19 |
+
if messages and not prompt:
|
20 |
+
content = "\n".join([f"{m['role']}: {m['content']}" for m in messages])
|
21 |
+
|
22 |
+
response = model.generate_content(content)
|
23 |
+
return response.text
|
24 |
+
|
25 |
+
class OpenAIProvider(LLMProvider):
|
26 |
+
def __init__(self, client):
|
27 |
+
self.client = client
|
28 |
+
|
29 |
+
def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
|
30 |
+
if prompt and not messages:
|
31 |
+
messages = [{"role": "user", "content": prompt}]
|
32 |
+
|
33 |
+
completion_kwargs = {
|
34 |
+
"model": kwargs.get("model", "gpt-3.5-turbo"),
|
35 |
+
"messages": messages,
|
36 |
+
"max_tokens": kwargs.get("max_tokens", 1000),
|
37 |
+
"temperature": kwargs.get("temperature", 1.0),
|
38 |
+
}
|
39 |
+
|
40 |
+
if "response_format" in kwargs:
|
41 |
+
completion_kwargs["response_format"] = kwargs["response_format"]
|
42 |
+
|
43 |
+
response = self.client.chat.completions.create(**completion_kwargs)
|
44 |
+
return response.choices[0].message.content
|
45 |
+
|
46 |
+
class LLMService:
|
47 |
+
def __init__(self, provider: LLMProvider):
|
48 |
+
self.provider = provider
|
49 |
+
|
50 |
+
def chat(
|
51 |
+
self,
|
52 |
+
prompt: Optional[str] = None,
|
53 |
+
messages: Optional[List[Dict[str, str]]] = None,
|
54 |
+
**kwargs
|
55 |
+
) -> str:
|
56 |
+
try:
|
57 |
+
return self.provider.generate(prompt=prompt, messages=messages, **kwargs)
|
58 |
+
except Exception as e:
|
59 |
+
print(f"LLM API error: {str(e)}")
|
60 |
+
raise
|