duzx16
commited on
Commit
·
38fbc7c
1
Parent(s):
1676f07
Fix stream_chat
Browse files- modeling_chatglm.py +14 -9
modeling_chatglm.py
CHANGED
@@ -14,6 +14,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm
|
|
14 |
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
15 |
from torch.nn.utils import skip_init
|
16 |
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
|
|
17 |
|
18 |
from transformers.modeling_outputs import (
|
19 |
BaseModelOutputWithPast,
|
@@ -998,21 +999,24 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
998 |
|
999 |
def process_response(self, output, history):
|
1000 |
content = ""
|
|
|
1001 |
for response in output.split("<|assistant|>"):
|
1002 |
metadata, content = response.split("\n", maxsplit=1)
|
1003 |
-
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
1004 |
if not metadata.strip():
|
1005 |
content = content.strip()
|
|
|
1006 |
content = content.replace("[[训练时间]]", "2023年")
|
1007 |
else:
|
|
|
1008 |
content = "\n".join(content.split("\n")[1:-1])
|
1009 |
def tool_call(**kwargs):
|
1010 |
return kwargs
|
1011 |
-
|
|
|
1012 |
return content, history
|
1013 |
|
1014 |
@torch.inference_mode()
|
1015 |
-
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str =
|
1016 |
max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
|
1017 |
**kwargs):
|
1018 |
if history is None:
|
@@ -1027,16 +1031,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1027 |
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
1028 |
tokenizer.get_command("<|observation|>")]
|
1029 |
outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
|
1030 |
-
outputs = outputs.tolist()[0][len(inputs["input_ids"][0])
|
1031 |
response = tokenizer.decode(outputs)
|
1032 |
history.append({"role": role, "content": query})
|
1033 |
response, history = self.process_response(response, history)
|
1034 |
return response, history
|
1035 |
|
1036 |
@torch.inference_mode()
|
1037 |
-
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str =
|
1038 |
past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
|
1039 |
logits_processor=None, return_past_key_values=False, **kwargs):
|
|
|
1040 |
if history is None:
|
1041 |
history = []
|
1042 |
if logits_processor is None:
|
@@ -1050,7 +1055,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1050 |
inputs = tokenizer.build_chat_input(query, history=history, role=role)
|
1051 |
else:
|
1052 |
inputs = tokenizer.build_chat_input(query, role=role)
|
1053 |
-
|
1054 |
if past_key_values is not None:
|
1055 |
past_length = past_key_values[0][0].shape[0]
|
1056 |
if self.transformer.pre_seq_len is not None:
|
@@ -1059,16 +1064,16 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1059 |
attention_mask = inputs.attention_mask
|
1060 |
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
|
1061 |
inputs['attention_mask'] = attention_mask
|
|
|
1062 |
for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
|
1063 |
eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
|
1064 |
**gen_kwargs):
|
1065 |
if return_past_key_values:
|
1066 |
outputs, past_key_values = outputs
|
1067 |
-
outputs = outputs.tolist()[0][len(inputs["input_ids"][0])
|
1068 |
response = tokenizer.decode(outputs)
|
1069 |
if response and response[-1] != "�":
|
1070 |
-
response = self.process_response(response)
|
1071 |
-
new_history = history + [(query, response)]
|
1072 |
if return_past_key_values:
|
1073 |
yield response, new_history, past_key_values
|
1074 |
else:
|
|
|
14 |
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
15 |
from torch.nn.utils import skip_init
|
16 |
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
17 |
+
from copy import deepcopy
|
18 |
|
19 |
from transformers.modeling_outputs import (
|
20 |
BaseModelOutputWithPast,
|
|
|
999 |
|
1000 |
def process_response(self, output, history):
|
1001 |
content = ""
|
1002 |
+
history = deepcopy(history)
|
1003 |
for response in output.split("<|assistant|>"):
|
1004 |
metadata, content = response.split("\n", maxsplit=1)
|
|
|
1005 |
if not metadata.strip():
|
1006 |
content = content.strip()
|
1007 |
+
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
1008 |
content = content.replace("[[训练时间]]", "2023年")
|
1009 |
else:
|
1010 |
+
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
1011 |
content = "\n".join(content.split("\n")[1:-1])
|
1012 |
def tool_call(**kwargs):
|
1013 |
return kwargs
|
1014 |
+
parameters = eval(content)
|
1015 |
+
content = {"name": metadata.strip(), "parameters": parameters}
|
1016 |
return content, history
|
1017 |
|
1018 |
@torch.inference_mode()
|
1019 |
+
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
|
1020 |
max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
|
1021 |
**kwargs):
|
1022 |
if history is None:
|
|
|
1031 |
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
1032 |
tokenizer.get_command("<|observation|>")]
|
1033 |
outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
|
1034 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
1035 |
response = tokenizer.decode(outputs)
|
1036 |
history.append({"role": role, "content": query})
|
1037 |
response, history = self.process_response(response, history)
|
1038 |
return response, history
|
1039 |
|
1040 |
@torch.inference_mode()
|
1041 |
+
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
|
1042 |
past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
|
1043 |
logits_processor=None, return_past_key_values=False, **kwargs):
|
1044 |
+
print(history)
|
1045 |
if history is None:
|
1046 |
history = []
|
1047 |
if logits_processor is None:
|
|
|
1055 |
inputs = tokenizer.build_chat_input(query, history=history, role=role)
|
1056 |
else:
|
1057 |
inputs = tokenizer.build_chat_input(query, role=role)
|
1058 |
+
inputs = inputs.to(self.device)
|
1059 |
if past_key_values is not None:
|
1060 |
past_length = past_key_values[0][0].shape[0]
|
1061 |
if self.transformer.pre_seq_len is not None:
|
|
|
1064 |
attention_mask = inputs.attention_mask
|
1065 |
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
|
1066 |
inputs['attention_mask'] = attention_mask
|
1067 |
+
history.append({"role": role, "content": query})
|
1068 |
for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
|
1069 |
eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
|
1070 |
**gen_kwargs):
|
1071 |
if return_past_key_values:
|
1072 |
outputs, past_key_values = outputs
|
1073 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
1074 |
response = tokenizer.decode(outputs)
|
1075 |
if response and response[-1] != "�":
|
1076 |
+
response, new_history = self.process_response(response, history)
|
|
|
1077 |
if return_past_key_values:
|
1078 |
yield response, new_history, past_key_values
|
1079 |
else:
|