duzx16 commited on
Commit
38fbc7c
·
1 Parent(s): 1676f07

Fix stream_chat

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +14 -9
modeling_chatglm.py CHANGED
@@ -14,6 +14,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm
14
  from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
15
  from torch.nn.utils import skip_init
16
  from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 
17
 
18
  from transformers.modeling_outputs import (
19
  BaseModelOutputWithPast,
@@ -998,21 +999,24 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
998
 
999
  def process_response(self, output, history):
1000
  content = ""
 
1001
  for response in output.split("<|assistant|>"):
1002
  metadata, content = response.split("\n", maxsplit=1)
1003
- history.append({"role": "assistant", "metadata": metadata, "content": content})
1004
  if not metadata.strip():
1005
  content = content.strip()
 
1006
  content = content.replace("[[训练时间]]", "2023年")
1007
  else:
 
1008
  content = "\n".join(content.split("\n")[1:-1])
1009
  def tool_call(**kwargs):
1010
  return kwargs
1011
- content = eval(content)
 
1012
  return content, history
1013
 
1014
  @torch.inference_mode()
1015
- def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = None,
1016
  max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
1017
  **kwargs):
1018
  if history is None:
@@ -1027,16 +1031,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1027
  eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1028
  tokenizer.get_command("<|observation|>")]
1029
  outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
1030
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
1031
  response = tokenizer.decode(outputs)
1032
  history.append({"role": role, "content": query})
1033
  response, history = self.process_response(response, history)
1034
  return response, history
1035
 
1036
  @torch.inference_mode()
1037
- def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = None,
1038
  past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
1039
  logits_processor=None, return_past_key_values=False, **kwargs):
 
1040
  if history is None:
1041
  history = []
1042
  if logits_processor is None:
@@ -1050,7 +1055,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1050
  inputs = tokenizer.build_chat_input(query, history=history, role=role)
1051
  else:
1052
  inputs = tokenizer.build_chat_input(query, role=role)
1053
- input = inputs.to(self.device)
1054
  if past_key_values is not None:
1055
  past_length = past_key_values[0][0].shape[0]
1056
  if self.transformer.pre_seq_len is not None:
@@ -1059,16 +1064,16 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1059
  attention_mask = inputs.attention_mask
1060
  attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
1061
  inputs['attention_mask'] = attention_mask
 
1062
  for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
1063
  eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
1064
  **gen_kwargs):
1065
  if return_past_key_values:
1066
  outputs, past_key_values = outputs
1067
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
1068
  response = tokenizer.decode(outputs)
1069
  if response and response[-1] != "�":
1070
- response = self.process_response(response)
1071
- new_history = history + [(query, response)]
1072
  if return_past_key_values:
1073
  yield response, new_history, past_key_values
1074
  else:
 
14
  from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
15
  from torch.nn.utils import skip_init
16
  from typing import Optional, Tuple, Union, List, Callable, Dict, Any
17
+ from copy import deepcopy
18
 
19
  from transformers.modeling_outputs import (
20
  BaseModelOutputWithPast,
 
999
 
1000
  def process_response(self, output, history):
1001
  content = ""
1002
+ history = deepcopy(history)
1003
  for response in output.split("<|assistant|>"):
1004
  metadata, content = response.split("\n", maxsplit=1)
 
1005
  if not metadata.strip():
1006
  content = content.strip()
1007
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
1008
  content = content.replace("[[训练时间]]", "2023年")
1009
  else:
1010
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
1011
  content = "\n".join(content.split("\n")[1:-1])
1012
  def tool_call(**kwargs):
1013
  return kwargs
1014
+ parameters = eval(content)
1015
+ content = {"name": metadata.strip(), "parameters": parameters}
1016
  return content, history
1017
 
1018
  @torch.inference_mode()
1019
+ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
1020
  max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
1021
  **kwargs):
1022
  if history is None:
 
1031
  eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1032
  tokenizer.get_command("<|observation|>")]
1033
  outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
1034
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1035
  response = tokenizer.decode(outputs)
1036
  history.append({"role": role, "content": query})
1037
  response, history = self.process_response(response, history)
1038
  return response, history
1039
 
1040
  @torch.inference_mode()
1041
+ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
1042
  past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
1043
  logits_processor=None, return_past_key_values=False, **kwargs):
1044
+ print(history)
1045
  if history is None:
1046
  history = []
1047
  if logits_processor is None:
 
1055
  inputs = tokenizer.build_chat_input(query, history=history, role=role)
1056
  else:
1057
  inputs = tokenizer.build_chat_input(query, role=role)
1058
+ inputs = inputs.to(self.device)
1059
  if past_key_values is not None:
1060
  past_length = past_key_values[0][0].shape[0]
1061
  if self.transformer.pre_seq_len is not None:
 
1064
  attention_mask = inputs.attention_mask
1065
  attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
1066
  inputs['attention_mask'] = attention_mask
1067
+ history.append({"role": role, "content": query})
1068
  for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
1069
  eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
1070
  **gen_kwargs):
1071
  if return_past_key_values:
1072
  outputs, past_key_values = outputs
1073
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1074
  response = tokenizer.decode(outputs)
1075
  if response and response[-1] != "�":
1076
+ response, new_history = self.process_response(response, history)
 
1077
  if return_past_key_values:
1078
  yield response, new_history, past_key_values
1079
  else: