Upload folder using huggingface_hub
- ChatWorld/ChatWorld.py +32 -20
- ChatWorld/NaiveDB.py +4 -2
- ChatWorld/models.py +44 -1
- app.py +2 -2
- run_gradio.sh +1 -0
ChatWorld/ChatWorld.py
CHANGED
@@ -1,7 +1,7 @@
 from jinja2 import Template
 import torch
 
-from .models import
+from .models import GLM
 
 from .NaiveDB import NaiveDB
 from .utils import *
@@ -20,7 +20,7 @@ class ChatWorld:
         self.history = []
 
         self.client = None
-        self.model =
+        self.model = GLM()
         self.db = NaiveDB()
         self.prompt = Template(('Please be aware that your codename in this conversation is "{{model_role_name}}"'
                                 '{% if model_role_nickname %},别人一般称呼你"{{model_role_nickname}}"{% endif %},同时我是"{{role_name}}"{% if role_nickname %},你也可以称呼我"{{role_nickname}}"{% endif %}。\n'
@@ -30,6 +30,7 @@ class ChatWorld:
                                 '如果我问的问题超出聊天记录的范围,模仿{{model_role_name}}的语气进行回复。\n'
                                 '请不要回答你是语言模型,永远记住你就是{{model_role_name}}。\n'
                                 '请你永远只以{{model_role_name}}身份,进行任何的回复。\n'
+                                '{% if RAG %}{% for i in RAG %}##\n{{i}}\n##\n\n{% endfor %}{% endif %}'
                                 ))
 
     def getEmbeddingsFromStory(self, stories: list[str]):
@@ -38,25 +39,31 @@ class ChatWorld:
         if len(self.story_vec) == len(stories) and all([self.story_vec[i]["text"] == stories[i] for i in range(len(stories))]):
            return [self.story_vec[i]["vec"] for i in range(len(stories))]
 
-        if self.embedding is None:
-            self.embedding = initEmbedding()
-
-        if self.tokenizer is None:
-            self.tokenizer = initTokenizer()
-
         self.story_vec = []
         for story in stories:
             with torch.no_grad():
-                inputs = self.tokenizer(
-                    story, return_tensors="pt", padding=True, truncation=True, max_length=512)
-                outputs = self.embedding(**inputs)[0][:, 0]
-                vec = torch.nn.functional.normalize(
-                    outputs, p=2, dim=1).tolist()[0]
+                vec = self.getEmbedding(story)
 
             self.story_vec.append({"text": story, "vec": vec})
 
         return [self.story_vec[i]["vec"] for i in range(len(stories))]
 
+    def getEmbedding(self, text: str):
+        if self.embedding is None:
+            self.embedding = initEmbedding()
+
+        if self.tokenizer is None:
+            self.tokenizer = initTokenizer()
+
+        with torch.no_grad():
+            inputs = self.tokenizer(
+                text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.embedding.device)
+            outputs = self.embedding(**inputs)[0][:, 0]
+            vec = torch.nn.functional.normalize(
+                outputs, p=2, dim=1).tolist()[0]
+
+        return vec
+
     def initDB(self, storys: list[str]):
         story_vecs = self.getEmbeddingsFromStory(storys)
         self.db.build_db(storys, story_vecs)
@@ -65,21 +72,26 @@ class ChatWorld:
         self.model_role_name = role_name
         self.model_role_nickname = role_nick_name
 
-    def getSystemPrompt(self, role_name, role_nick_name):
+    def getSystemPrompt(self, text, role_name, role_nick_name):
         assert self.model_role_name, "Please set model role name first"
 
-
+        query = self.getEmbedding(text)
+        rag = self.db.search(query, 5)
 
-
-        message = [self.getSystemPrompt(
-            user_role_name, user_role_nick_name)] + self.history
+        return {"role": "system", "content": self.prompt.render(model_role_name=self.model_role_name, model_role_nickname=self.model_role_nickname, role_name=role_name, role_nickname=role_nick_name, RAG=rag)}
 
+    def chat(self, text: str, user_role_name: str, user_role_nick_name: str = None, use_local_model=False):
+        message = [self.getSystemPrompt(text,
+                                        user_role_name, user_role_nick_name)] + self.history
+        print(message)
         if use_local_model:
             response = self.model.get_response(message)
         else:
             response = self.client.chat(
                 user_role_name, text, user_role_nick_name)
 
-        self.history.append(
-
+        self.history.append(
+            {"role": "user", "content": f"{user_role_name}:「{text}」"})
+        self.history.append(
+            {"role": "assistant", "content": f"{self.model_role_name}:「{response}」"})
         return response
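Taken together, these changes route both the story snippets and the live user message through the single getEmbedding path, retrieve the five closest snippets from NaiveDB, and render them into the new {% if RAG %} block of the system prompt on every turn. A minimal usage sketch of the new flow, assuming the package re-exports ChatWorld, that ChatWorld() needs no required arguments, and that initEmbedding/initTokenizer in ChatWorld.utils can fetch the embedding model; the story strings below are placeholders, not repository data:

    from ChatWorld import ChatWorld

    cw = ChatWorld()

    # Build the in-memory vector store from raw story snippets (placeholder texts).
    cw.initDB([
        "凉宫春日在开学第一天就宣布只对外星人、未来人和超能力者感兴趣。",
        "SOS团在旧文艺部教室里召开了第一次会议。",
    ])

    # Fix the character to role-play, then chat: the message is embedded,
    # the top-5 snippets are injected into the system prompt, and the
    # user/assistant turns are appended to history.
    cw.setRoleName("凉宫春日", "春日")
    reply = cw.chat("你为什么要创建SOS团?", user_role_name="阿虚", use_local_model=True)
    print(reply)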
ChatWorld/NaiveDB.py
CHANGED
@@ -81,5 +81,7 @@ class NaiveDB:
         similarities.sort(key=lambda x: x[0], reverse=True)
         self.last_search_ids = [x[1] for x in similarities[:n_results]]
 
-
-
+
+
+        top_stories = [self.stories[_id] for _id in self.last_search_ids]
+        return top_stories
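The change makes search return the retrieved story texts themselves (top_stories) rather than only caching their indices, which is what getSystemPrompt now splices into the RAG slot of the prompt. A rough standalone sketch of the same ranking logic, with hypothetical names rather than the NaiveDB API:

    import math

    def cosine_similarity(a: list[float], b: list[float]) -> float:
        dot = sum(x * y for x, y in zip(a, b))
        norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b))
        return dot / norm if norm else 0.0

    def search(stories: list[str], vecs: list[list[float]], query: list[float], n: int) -> list[str]:
        # Rank every stored vector against the query and return the n best texts,
        # mirroring what NaiveDB.search does with the data passed to build_db.
        scored = sorted(((cosine_similarity(query, v), i) for i, v in enumerate(vecs)), reverse=True)
        return [stories[i] for _, i in scored[:n]]

Since getEmbedding already L2-normalizes each vector, the cosine here reduces to a plain dot product.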
ChatWorld/models.py
CHANGED
@@ -1,4 +1,7 @@
+import os
+from string import Template
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from zhipuai import ZhipuAI
 
 
 class qwen_model:
@@ -11,7 +14,9 @@ class qwen_model:
     def get_response(self, message):
         message = self.tokenizer.apply_chat_template(
             message, tokenize=False, add_generation_prompt=True)
-
+        print(message)
+        model_inputs = self.tokenizer(
+            [message], return_tensors="pt").to(self.model.device)
         generated_ids = self.model.generate(
             model_inputs.input_ids,
             max_new_tokens=512
@@ -22,4 +27,42 @@ class qwen_model:
 
         response = self.tokenizer.batch_decode(
             generated_ids, skip_special_tokens=True)[0]
+
         return response
+
+
+class GLM():
+    def __init__(self, model_name="silk-road/Haruhi-Zero-GLM3-6B-0_4"):
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, trust_remote_code=True)
+        client = AutoModelForCausalLM.from_pretrained(
+            model_name, trust_remote_code=True, device_map="auto")
+
+        client = client.eval()
+
+    def message2query(messages) -> str:
+        # [{'role': 'user', 'content': '老师: 同学请自我介绍一下'}]
+        # <|system|>
+        # You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
+        # <|user|>
+        # Hello
+        # <|assistant|>
+        # Hello, I'm ChatGLM3. What can I assist you today?
+        template = Template("<|$role|>\n$content\n")
+
+        return "".join([template.substitute(message) for message in messages])
+
+    def get_response(self, message):
+        response, history = self.client.chat(self.tokenizer, message)
+        return response
+
+
+class GLM_api:
+    def __init__(self, model_name="glm-4"):
+        self.client = ZhipuAI(api_key=os.environ["ZHIPU_API_KEY"])
+        self.model = model_name
+
+    def getResponse(self, message):
+        response = self.client.chat.completions.create(
+            model=self.model, prompt=message)
+        return response.choices[0].message
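The comments inside message2query document the ChatGLM3 turn format this class targets: each OpenAI-style message is flattened to "<|role|>\ncontent\n". A small standalone sketch of that conversion, using the same string.Template logic as the diff but lifted out of the class for illustration:

    from string import Template

    def message2query(messages: list[dict]) -> str:
        # Flatten OpenAI-style messages into ChatGLM3's "<|role|>\ncontent\n" layout.
        template = Template("<|$role|>\n$content\n")
        return "".join(template.substitute(m) for m in messages)

    print(message2query([
        {"role": "system", "content": "你是凉宫春日。"},
        {"role": "user", "content": "阿虚:「你为什么要创建SOS团?」"},
    ]))
    # <|system|>
    # 你是凉宫春日。
    # <|user|>
    # 阿虚:「你为什么要创建SOS团?」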
app.py
CHANGED
@@ -38,8 +38,8 @@ def getContent(input_file):
 
 def submit_message(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
     chatWorld.setRoleName(model_role_name, model_role_nickname)
-    response = chatWorld.chat(
-
+    response = chatWorld.chat(message,
+                              role_name, role_nickname, use_local_model=True)
     return response
 
 
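submit_message is the Gradio callback: it now forwards the raw message into chatWorld.chat and forces the local GLM path with use_local_model=True. One plausible way such a handler is wired up, purely a hypothetical sketch (the rest of app.py is not part of this diff), assuming a gr.ChatInterface whose additional_inputs line up with the extra parameters:

    import gradio as gr

    # Hypothetical wiring: ChatInterface calls fn(message, history, *additional_inputs),
    # which matches submit_message's signature.
    demo = gr.ChatInterface(
        fn=submit_message,
        additional_inputs=[
            gr.Textbox(label="model_role_name"),
            gr.Textbox(label="role_name"),
            gr.Textbox(label="model_role_nickname"),
            gr.Textbox(label="role_nickname"),
        ],
    )
    demo.launch()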
run_gradio.sh
CHANGED
@@ -1,5 +1,6 @@
 export CUDA_VISIBLE_DEVICES=0
 export HF_HOME="/workspace/jyh/.cache/huggingface"
+export HF_ENDPOINT="https://hf-mirror.com"
 
 # Start the gradio server
 /workspace/jyh/miniconda3/envs/ChatWorld/bin/python /workspace/jyh/Zero-Haruhi/app.py
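The new HF_ENDPOINT export points Hugging Face Hub downloads (the GLM checkpoint and the embedding model pulled at startup) at the hf-mirror.com mirror. A quick Python check that the variable is picked up, assuming a recent huggingface_hub that exposes the resolved endpoint in its constants module:

    import os
    os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")  # must be set before importing huggingface_hub

    from huggingface_hub import constants
    print(constants.ENDPOINT)  # expected: https://hf-mirror.com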