# ChatWorld — src/Models/models.py
# Uploaded by JiangYH via huggingface_hub (commit 6f179e7, 2.1 kB)
import os
from string import Template
from typing import Dict, List, Union
from transformers import AutoTokenizer, AutoModelForCausalLM
from zhipuai import ZhipuAI
class GLM:
    """Local ChatGLM3-style chat model loaded through Hugging Face ``transformers``.

    The loaded model (stored as ``self.client``) must expose a
    ``chat(tokenizer, query, history=...)`` method returning
    ``(response, history)``. This is presumably provided by the checkpoint's
    remote code (it is loaded with ``trust_remote_code=True``).
    """

    def __init__(self, model_name="silk-road/Haruhi-Zero-GLM3-6B-0_4"):
        """Load the tokenizer and causal-LM weights for ``model_name``.

        Args:
            model_name: Hugging Face Hub id or local path of the checkpoint.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        # device_map="auto" lets transformers place the weights on whatever
        # devices are available.
        client = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True, device_map="auto"
        )
        # Inference only: put the model in eval mode.
        self.client = client.eval()

    def message2query(self, messages) -> str:
        """Render chat messages into a ChatGLM3-style prompt string.

        Each message is a mapping with at least ``role`` and ``content`` keys,
        e.g. ``[{'role': 'user', 'content': 'Teacher: please introduce yourself'}]``.
        Each message becomes ``<|role|>\\ncontent\\n``, and the pieces are
        concatenated. For example, a system turn renders as::

            <|system|>
            You are ChatGLM3, ...
            <|user|>
            Hello

        Raises:
            KeyError: if a message lacks a ``role`` or ``content`` key.
        """
        template = Template("<|$role|>\n$content\n")
        return "".join(template.substitute(message) for message in messages)

    def get_response(
        self,
        message: Union[str, list[dict[str, str]]],
        history: Union[List[Dict[str, str]], None] = None,
    ):
        """Generate a reply from the local model.

        Args:
            message: Either a single query string, or a full list of
                ``{"role": ..., "content": ...}`` messages. In the list form,
                the last entry is the query and earlier entries are the
                conversation so far.
            history: Prior turns to use when ``message`` is a string. It is
                ignored when ``message`` is a list, because the list already
                carries its own history.

        Returns:
            The model's reply text, which is also printed to stdout.

        Raises:
            ValueError: if ``message`` is an empty list.
            TypeError: if ``message`` is neither a string nor a list.
        """
        if isinstance(message, str):
            # Forward the caller-supplied history so multi-turn calls that pass
            # a plain query string keep their context.
            response, _ = self.client.chat(self.tokenizer, message, history=history)
        elif isinstance(message, list):
            if not message:
                raise ValueError("message list must contain at least one message")
            *previous, last = message
            response, _ = self.client.chat(
                self.tokenizer, last["content"], history=previous
            )
        else:
            raise TypeError(
                "message must be a str or a list of dicts, "
                f"got {type(message).__name__}"
            )
        print(response)
        return response
class GLM_api:
    """Thin wrapper around the Zhipu AI hosted chat-completions API."""

    def __init__(self, model_name="glm-4"):
        """Create the API client.

        Args:
            model_name: Zhipu model id sent with every request.

        The API key is read from the ``ZHIPU_API_KEY`` environment variable.
        If that variable is unset, ``None`` is passed to ``ZhipuAI``.
        TODO confirm how the SDK resolves a missing key.
        """
        api_key = os.environ.get("ZHIPU_API_KEY")
        self.client = ZhipuAI(api_key=api_key)
        self.model = model_name

    def chat(self, message):
        """Send a conversation to the API and return the reply text.

        Args:
            message: A list of ``{"role": ..., "content": ...}`` dicts, or a
                plain string, which is sent as a single user message.

        Returns:
            The content of the first completion choice. If the request raises,
            the exception is printed and a fixed failure string is returned
            instead.
        """
        if isinstance(message, str):
            # Accept a bare query string, like GLM.get_response does; the API
            # itself only accepts a list of message dicts.
            message = [{"role": "user", "content": message}]
        try:
            response = self.client.chat.completions.create(
                model=self.model, messages=message
            )
        except Exception as e:
            # Deliberate best-effort: callers always get a string back.
            # NOTE(review): the literal below looks mis-encoded (presumably
            # "模型连接失败", "model connection failed") — verify the file's encoding.
            print(e)
            return "ζ¨‘εž‹θΏžζŽ₯ε€±θ΄₯"
        return response.choices[0].message.content