JiangYH committed (verified)
Commit 6146562 · Parent(s): 403e597

Upload folder using huggingface_hub

Files changed (5):
  1. ChatWorld/ChatWorld.py +32 -20
  2. ChatWorld/NaiveDB.py +4 -2
  3. ChatWorld/models.py +44 -1
  4. app.py +2 -2
  5. run_gradio.sh +1 -0
ChatWorld/ChatWorld.py CHANGED
@@ -1,7 +1,7 @@
 from jinja2 import Template
 import torch
 
-from .models import qwen_model
+from .models import GLM
 
 from .NaiveDB import NaiveDB
 from .utils import *
@@ -20,7 +20,7 @@ class ChatWorld:
         self.history = []
 
         self.client = None
-        self.model = qwen_model(pretrained_model_name_or_path)
+        self.model = GLM()
         self.db = NaiveDB()
         self.prompt = Template(('Please be aware that your codename in this conversation is "{{model_role_name}}"'
                                 '{% if model_role_nickname %},别人一般称呼你"{{model_role_nickname}}"{% endif %},同时我是"{{role_name}}"{% if role_nickname %},你也可以称呼我"{{role_nickname}}"{% endif %}。\n'
@@ -30,6 +30,7 @@ class ChatWorld:
                                 '如果我问的问题超出聊天记录的范围,模仿{{model_role_name}}的语气进行回复。\n'
                                 '请不要回答你是语言模型,永远记住你就是{{model_role_name}}。\n'
                                 '请你永远只以{{model_role_name}}身份,进行任何的回复。\n'
+                                '{% if RAG %}{% for i in RAG %}##\n{{i}}\n##\n\n{% endfor %}{% endif %}'
                                 ))
 
     def getEmbeddingsFromStory(self, stories: list[str]):
@@ -38,25 +39,31 @@ class ChatWorld:
         if len(self.story_vec) == len(stories) and all([self.story_vec[i]["text"] == stories[i] for i in range(len(stories))]):
             return [self.story_vec[i]["vec"] for i in range(len(stories))]
 
-        if self.embedding is None:
-            self.embedding = initEmbedding()
-
-        if self.tokenizer is None:
-            self.tokenizer = initTokenizer()
-
         self.story_vec = []
         for story in stories:
             with torch.no_grad():
-                inputs = self.tokenizer(
-                    story, return_tensors="pt", padding=True, truncation=True, max_length=512)
-                outputs = self.embedding(**inputs)[0][:, 0]
-                vec = torch.nn.functional.normalize(
-                    outputs, p=2, dim=1).tolist()[0]
+                vec = self.getEmbedding(story)
 
             self.story_vec.append({"text": story, "vec": vec})
 
         return [self.story_vec[i]["vec"] for i in range(len(stories))]
 
+    def getEmbedding(self, text: str):
+        if self.embedding is None:
+            self.embedding = initEmbedding()
+
+        if self.tokenizer is None:
+            self.tokenizer = initTokenizer()
+
+        with torch.no_grad():
+            inputs = self.tokenizer(
+                text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.embedding.device)
+            outputs = self.embedding(**inputs)[0][:, 0]
+            vec = torch.nn.functional.normalize(
+                outputs, p=2, dim=1).tolist()[0]
+
+        return vec
+
     def initDB(self, storys: list[str]):
         story_vecs = self.getEmbeddingsFromStory(storys)
         self.db.build_db(storys, story_vecs)
@@ -65,21 +72,26 @@ class ChatWorld:
         self.model_role_name = role_name
         self.model_role_nickname = role_nick_name
 
-    def getSystemPrompt(self, role_name, role_nick_name):
+    def getSystemPrompt(self, text, role_name, role_nick_name):
         assert self.model_role_name, "Please set model role name first"
 
-        return {"role": "system", "content": self.prompt.render(model_role_name=self.model_role_name, model_role_nickname=self.model_role_nickname, role_name=role_name, role_nickname=role_nick_name)}
+        query = self.getEmbedding(text)
+        rag = self.db.search(query, 5)
 
-    def chat(self, user_role_name: str, text: str, user_role_nick_name: str = None, use_local_model=False):
-        message = [self.getSystemPrompt(
-            user_role_name, user_role_nick_name)] + self.history
+        return {"role": "system", "content": self.prompt.render(model_role_name=self.model_role_name, model_role_nickname=self.model_role_nickname, role_name=role_name, role_nickname=role_nick_name, RAG=rag)}
 
+    def chat(self, text: str, user_role_name: str, user_role_nick_name: str = None, use_local_model=False):
+        message = [self.getSystemPrompt(text,
+                                        user_role_name, user_role_nick_name)] + self.history
+        print(message)
         if use_local_model:
             response = self.model.get_response(message)
         else:
            response = self.client.chat(
                 user_role_name, text, user_role_nick_name)
 
-        self.history.append({"role": "user", "content": text})
-        self.history.append({"role": "model", "content": response})
+        self.history.append(
+            {"role": "user", "content": f"{user_role_name}:「{text}」"})
+        self.history.append(
+            {"role": "assistant", "content": f"{self.model_role_name}:「{response}」"})
         return response
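
The main change here: getSystemPrompt now embeds the incoming message, pulls the five nearest stories from NaiveDB, and injects them into the system prompt through the new RAG template variable. A minimal sketch of the rendering step (the role name and snippets below are placeholder values, not data from this repo):

from jinja2 import Template

# Same template fragment as in the diff above: each retrieved story is
# wrapped in "##" fences and appended to the system prompt.
prompt = Template('请你永远只以{{model_role_name}}身份,进行任何的回复。\n'
                  '{% if RAG %}{% for i in RAG %}##\n{{i}}\n##\n\n{% endfor %}{% endif %}')

print(prompt.render(model_role_name="凉宫春日",
                    RAG=["story snippet 1", "story snippet 2"]))

One thing worth flagging: at generation time, message is only the system prompt plus self.history; the current text joins self.history after the response is produced, so the model sees the current turn only through the retrieved RAG snippets. That may be intentional, but it looks like an oversight.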
ChatWorld/NaiveDB.py CHANGED
@@ -81,5 +81,7 @@ class NaiveDB:
         similarities.sort(key=lambda x: x[0], reverse=True)
         self.last_search_ids = [x[1] for x in similarities[:n_results]]
 
-        top_indices = [x[1] for x in similarities[:n_results]]
-        return top_indices
+
+
+        top_stories = [self.stories[_id] for _id in self.last_search_ids]
+        return top_stories
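
search now returns the matched story texts instead of their indices, which is what getSystemPrompt splices into the prompt. For context, a hedged reconstruction of the surrounding method (the diff shows only its tail; the vecs parameter and standalone-function form are assumptions):

import numpy as np

def search(stories, vecs, query_vec, n_results):
    # Story vectors are L2-normalized upstream (see getEmbedding), so a
    # plain dot product is cosine similarity.
    similarities = [(float(np.dot(query_vec, vec)), idx)
                    for idx, vec in enumerate(vecs)]
    similarities.sort(key=lambda x: x[0], reverse=True)
    last_search_ids = [x[1] for x in similarities[:n_results]]

    # New behavior in this commit: return the texts, not the indices.
    return [stories[_id] for _id in last_search_ids]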
ChatWorld/models.py CHANGED
@@ -1,4 +1,7 @@
+import os
+from string import Template
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from zhipuai import ZhipuAI
 
 
 class qwen_model:
@@ -11,7 +14,9 @@ class qwen_model:
     def get_response(self, message):
         message = self.tokenizer.apply_chat_template(
             message, tokenize=False, add_generation_prompt=True)
-        model_inputs = self.tokenizer([message], return_tensors="pt")
+        print(message)
+        model_inputs = self.tokenizer(
+            [message], return_tensors="pt").to(self.model.device)
         generated_ids = self.model.generate(
             model_inputs.input_ids,
             max_new_tokens=512
@@ -22,4 +27,42 @@ class qwen_model:
 
         response = self.tokenizer.batch_decode(
             generated_ids, skip_special_tokens=True)[0]
+
         return response
+
+
+class GLM():
+    def __init__(self, model_name="silk-road/Haruhi-Zero-GLM3-6B-0_4"):
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, trust_remote_code=True)
+        client = AutoModelForCausalLM.from_pretrained(
+            model_name, trust_remote_code=True, device_map="auto")
+
+        client = client.eval()
+
+    def message2query(messages) -> str:
+        # [{'role': 'user', 'content': '老师: 同学请自我介绍一下'}]
+        # <|system|>
+        # You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
+        # <|user|>
+        # Hello
+        # <|assistant|>
+        # Hello, I'm ChatGLM3. What can I assist you today?
+        template = Template("<|$role|>\n$content\n")
+
+        return "".join([template.substitute(message) for message in messages])
+
+    def get_response(self, message):
+        response, history = self.client.chat(self.tokenizer, message)
+        return response
+
+
+class GLM_api:
+    def __init__(self, model_name="glm-4"):
+        self.client = ZhipuAI(api_key=os.environ["ZHIPU_API_KEY"])
+        self.model = model_name
+
+    def getResponse(self, message):
+        response = self.client.chat.completions.create(
+            model=self.model, prompt=message)
+        return response.choices[0].message
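
As committed, GLM will not run: __init__ binds tokenizer and client to locals rather than to self, message2query is missing its self parameter, and get_response never applies it before calling the model. A corrected sketch of the apparent intent (an assumption, not the committed code):

from string import Template
from transformers import AutoTokenizer, AutoModelForCausalLM


class GLM:
    def __init__(self, model_name="silk-road/Haruhi-Zero-GLM3-6B-0_4"):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True)
        self.client = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True, device_map="auto").eval()

    def message2query(self, messages) -> str:
        # Flatten chat messages into ChatGLM3's "<|role|>\ncontent" layout.
        template = Template("<|$role|>\n$content\n")
        return "".join(template.substitute(m) for m in messages)

    def get_response(self, message):
        # ChatGLM3 checkpoints loaded with trust_remote_code expose
        # model.chat(tokenizer, query); pass the flattened conversation.
        response, _history = self.client.chat(
            self.tokenizer, self.message2query(message))
        return response

GLM_api.getResponse has a similar problem: the zhipuai v2 SDK's client.chat.completions.create expects messages= (a list of role/content dicts), not prompt=.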
app.py CHANGED
@@ -38,8 +38,8 @@ def getContent(input_file):
 
 def submit_message(message, history, model_role_name, role_name, model_role_nickname, role_nickname):
     chatWorld.setRoleName(model_role_name, model_role_nickname)
-    response = chatWorld.chat(
-        role_name, message, role_nickname, use_local_model=True)
+    response = chatWorld.chat(message,
+                              role_name, role_nickname, use_local_model=True)
     return response
 
 
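This call site simply follows the reordered ChatWorld.chat(text, user_role_name, ...) signature. A minimal usage sketch with placeholder names:

chatWorld.setRoleName("凉宫春日", "春日")  # model-side role and nickname (placeholders)
response = chatWorld.chat("你好!", "阿虚", "Kyon", use_local_model=True)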
run_gradio.sh CHANGED
@@ -1,5 +1,6 @@
 export CUDA_VISIBLE_DEVICES=0
 export HF_HOME="/workspace/jyh/.cache/huggingface"
+export HF_ENDPOINT="https://hf-mirror.com"
 
 # Start the gradio server
 /workspace/jyh/miniconda3/envs/ChatWorld/bin/python /workspace/jyh/Zero-Haruhi/app.py