import os
from string import Template
from typing import Dict, List, Optional, Union

from transformers import AutoTokenizer, AutoModelForCausalLM
from zhipuai import ZhipuAI


class GLM:
    """Local Haruhi-Zero model based on ChatGLM3, loaded via transformers."""

    def __init__(self, model_name="silk-road/Haruhi-Zero-GLM3-6B-0_4"):
        # The checkpoint ships custom modeling code, so trust_remote_code is
        # required; device_map="auto" places the weights on available devices.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        client = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True, device_map="auto"
        )
        self.client = client.eval()

    def message2query(self, messages: List[Dict[str, str]]) -> str:
        # Flatten a list of {"role": ..., "content": ...} dicts into the
        # ChatGLM3 prompt layout, e.g.
        #   <|system|>
        #   You are a helpful assistant.
        #   <|user|>
        #   Hello!
        template = Template("<|$role|>\n$content\n")
        return "".join([template.substitute(message) for message in messages])

    def get_response(
        self,
        message: Union[str, List[Dict[str, str]]],
        history: Optional[List[Dict[str, str]]] = None,
    ):
        if isinstance(message, str):
            # Plain string: single-turn chat without history.
            response, history = self.client.chat(self.tokenizer, message)
        elif isinstance(message, list):
            # Message list: the last entry is the current query and the
            # preceding entries are passed as history.
            response, history = self.client.chat(
                self.tokenizer, message[-1]["content"], history=message[:-1]
            )
        else:
            raise TypeError(f"message must be a str or a list of dicts, got {type(message)}")

        print(response)
        return response


class GLM_api:
    """GLM chat via the ZhipuAI web API (requires ZHIPU_API_KEY)."""

    def __init__(self, model_name="glm-4"):
        API_KEY = os.environ.get("ZHIPU_API_KEY")
        self.client = ZhipuAI(api_key=API_KEY)
        self.model = model_name

    def chat(self, message):
        try:
            response = self.client.chat.completions.create(
                model=self.model, messages=message
            )
        except Exception as e:
            print(e)
            return "Model connection failed"

        return response.choices[0].message.content
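

# Usage sketch (not part of the library): the local GLM class needs the
# Haruhi-Zero checkpoint above (and a GPU to be practical), while GLM_api
# assumes a valid ZHIPU_API_KEY in the environment. The message contents
# below are placeholders for illustration.
if __name__ == "__main__":
    demo_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Introduce yourself briefly."},
    ]

    # Hosted API path.
    api_llm = GLM_api()
    print(api_llm.chat(demo_messages))

    # Local model path; uncomment if the checkpoint is available.
    # local_llm = GLM()
    # local_llm.get_response(demo_messages)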