youngtsai committed
Commit ac10d4c · 1 Parent(s): a8d1987
Files changed (1)
  1. llms.py +78 -0
llms.py ADDED
@@ -0,0 +1,78 @@
+ from abc import ABC, abstractmethod
+ from typing import Dict, List, Optional
+
+ from openai import OpenAI
+ from vertexai.generative_models import GenerativeModel
+
+
+ class LLMProvider(ABC):
+     """Abstract interface over a chat-completion backend."""
+
+     @abstractmethod
+     def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
+         pass
+
+
+ class GeminiProvider(LLMProvider):
+     def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
+         if not prompt and not messages:
+             raise ValueError("Either prompt or messages must be provided")
+
+         model_name = kwargs.get("model", "gemini-pro")
+         model = GenerativeModel(model_name=model_name)
+
+         # generate_content expects a single string, so flatten a chat
+         # history into "role: content" lines when no plain prompt is given.
+         content = prompt
+         if messages and not prompt:
+             content = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
+
+         response = model.generate_content(content)
+         return response.text
+
+
+ class OpenAIProvider(LLMProvider):
+     def __init__(self, client: OpenAI):
+         self.client = client
+
+     def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
+         # The Chat Completions API expects a message list; wrap a bare
+         # prompt as a single user message.
+         if prompt and not messages:
+             messages = [{"role": "user", "content": prompt}]
+         if not messages:
+             raise ValueError("Either prompt or messages must be provided")
+
+         completion_kwargs = {
+             "model": kwargs.get("model", "gpt-3.5-turbo"),
+             "messages": messages,
+             "max_tokens": kwargs.get("max_tokens", 1000),
+             "temperature": kwargs.get("temperature", 1.0),
+         }
+
+         # Forward response_format (e.g. {"type": "json_object"}) only
+         # when the caller supplies it.
+         if "response_format" in kwargs:
+             completion_kwargs["response_format"] = kwargs["response_format"]
+
+         response = self.client.chat.completions.create(**completion_kwargs)
+         return response.choices[0].message.content
+
+
+ class LLMService:
+     """Thin facade that delegates chat calls to the configured provider."""
+
+     def __init__(self, provider: LLMProvider):
+         self.provider = provider
+
+     def chat(
+         self,
+         prompt: Optional[str] = None,
+         messages: Optional[List[Dict[str, str]]] = None,
+         **kwargs
+     ) -> str:
+         try:
+             return self.provider.generate(prompt=prompt, messages=messages, **kwargs)
+         except Exception as e:
+             print(f"LLM API error: {e}")
+             raise
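
A minimal usage sketch, assuming llms.py is importable and OPENAI_API_KEY is set in the environment; the prompts, options, and message contents below are illustrative:

    from openai import OpenAI
    from llms import LLMService, OpenAIProvider

    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    service = LLMService(provider=OpenAIProvider(client))

    # A bare prompt; OpenAIProvider wraps it into a single user message.
    print(service.chat(prompt="Summarize PEP 8 in one sentence."))

    # A full chat history, with per-call options passed through as kwargs.
    reply = service.chat(
        messages=[
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Reply with a JSON object {\"ok\": true}."},
        ],
        response_format={"type": "json_object"},
        temperature=0.2,
    )
    print(reply)

Swapping in GeminiProvider() needs no call-site changes, since both backends sit behind the same LLMProvider.generate signature; it does assume the Vertex AI SDK has already been initialized (e.g. vertexai.init(project=..., location=...)).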