fix chat streaming
modeling_qwen.py  +79 −10  CHANGED
@@ -5,7 +5,7 @@
 
 import importlib
 import math
-from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
+from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
 
 import torch
 import torch.nn.functional as F
@@ -53,6 +53,13 @@ _CONFIG_FOR_DOC = "QWenConfig"
 
 QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
 
+_ERROR_BAD_CHAT_FORMAT = """\
+We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
+If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
+我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
+如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
+"""
+
 apply_rotary_emb_func = None
 rms_norm = None
 flash_attn_unpadded_func = None
@@ -971,6 +978,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         stop_words_ids: Optional[List[List[int]]] = None,
         **kwargs,
     ) -> Tuple[str, HistoryType]:
+        assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
         if stop_words_ids is None:
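With this guard in place, chat() refuses to run when generation_config.chat_format is not "chatml", which is exactly the situation the bilingual _ERROR_BAD_CHAT_FORMAT message above describes (typically the base pretrained checkpoint loaded instead of the chat checkpoint). A minimal sketch of the resulting behavior; the model loading is illustrative only and assumes the usual trust_remote_code setup:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustration: the base (non-chat) checkpoint ships a generation_config whose
# chat_format is not "chatml", so the new assert fires as soon as chat() is called.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True).eval()

try:
    response, history = model.chat(tokenizer, "Hello", history=None)
except AssertionError as err:
    print(err)  # prints _ERROR_BAD_CHAT_FORMAT, pointing the user to Qwen/Qwen-7B-Chat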
@@ -990,14 +998,17 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         ))
         input_ids = torch.tensor([context_tokens]).to(self.device)
         if stream:
-            assert self.generation_config.chat_format == 'chatml'
+            logger.warn(
+                "[WARNING] This usage is deprecated and marked for removal."
+                "Please use chat_stream() instead of chat(stream=True)."
+            )
             from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
-            self.__class__.generate = NewGenerationMixin.generate
+            self.__class__.generate_stream = NewGenerationMixin.generate
             self.__class__.sample_stream = NewGenerationMixin.sample_stream
             stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
             def stream_generator():
                 outputs = []
-                for token in self.generate(
+                for token in self.generate_stream(
                         input_ids, return_dict_in_generate=False, generation_config=stream_config, **kwargs):
                     outputs.append(token.item())
                     if outputs[-1] in (tokenizer.im_end_id, tokenizer.im_start_id):
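The two renamed lines are the heart of the "fix chat streaming" change: binding the streaming mixin's generate onto self.__class__.generate replaces QWenLMHeadModel's own generate() for every later call on the class, streaming or not, while binding it under the separate name generate_stream leaves the regular generate() untouched. A simplified, self-contained sketch of that failure mode (plain Python, not Qwen code):

# Assigning to the *class* attribute swaps the method for all future calls.
class Model:
    def generate(self):
        return "normal decoding"

def stream_generate(self):
    return "streaming decoding"

m = Model()
Model.generate = stream_generate     # old approach: clobbers the original method
print(m.generate())                  # "streaming decoding", even when streaming was not requested

class Model2:
    def generate(self):
        return "normal decoding"

Model2.generate_stream = stream_generate   # new approach: bind under a separate name
m2 = Model2()
print(m2.generate())                 # "normal decoding" is preserved
print(m2.generate_stream())          # "streaming decoding" available on demand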
@@ -1027,6 +1038,62 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
         return response, history
 
+    def chat_stream(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        query: str,
+        history: Optional[HistoryType],
+        system: str = "You are a helpful assistant.",
+        stop_words_ids: Optional[List[List[int]]] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        **kwargs,
+    ) -> Generator[str, Any, None]:
+        assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        if history is None:
+            history = []
+        if stop_words_ids is None:
+            stop_words_ids = []
+
+        raw_text, context_tokens = make_context(
+            tokenizer,
+            query,
+            history=history,
+            system=system,
+            max_window_size=6144,
+            chat_format=self.generation_config.chat_format,
+        )
+
+        stop_words_ids.extend(get_stop_words_ids(
+            self.generation_config.chat_format, tokenizer
+        ))
+        if stop_words_ids is not None:
+            stop_words_logits_processor = StopWordsLogitsProcessor(
+                stop_words_ids=stop_words_ids,
+                eos_token_id=self.generation_config.eos_token_id,
+            )
+            if logits_processor is None:
+                logits_processor = LogitsProcessorList([stop_words_logits_processor])
+            else:
+                logits_processor.append(stop_words_logits_processor)
+        input_ids = torch.tensor([context_tokens]).to(self.device)
+
+        from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
+        self.__class__.generate_stream = NewGenerationMixin.generate
+        self.__class__.sample_stream = NewGenerationMixin.sample_stream
+        stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
+        def stream_generator():
+            outputs = []
+            for token in self.generate_stream(
+                    input_ids,
+                    return_dict_in_generate=False,
+                    generation_config=stream_config,
+                    logits_processor=logits_processor,
+                    **kwargs):
+                outputs.append(token.item())
+                yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
+
+        return stream_generator()
+
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
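The new chat_stream() returns a generator that, for every generated token, yields the decoded response accumulated so far. A minimal usage sketch; the checkpoint name and printing loop are illustrative, and the transformers_stream_generator package must be installed:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True).eval()

# Each yielded value is the full response decoded so far, so printing only the new
# suffix gives a live "typing" effect (assumes the already-decoded prefix stays stable).
printed = ""
for partial in model.chat_stream(tokenizer, "Tell me a short joke.", history=None):
    print(partial[len(printed):], end="", flush=True)
    printed = partial
print()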
@@ -1037,6 +1104,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             Callable[[int, torch.Tensor], List[int]]
         ] = None,
         synced_gpus: Optional[bool] = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
         streamer: Optional["BaseStreamer"] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
@@ -1059,12 +1127,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
         return super().generate(
             inputs,
-            generation_config,
-            logits_processor,
-            stopping_criteria,
-            prefix_allowed_tokens_fn,
-            synced_gpus,
-            streamer,
+            generation_config=generation_config,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            synced_gpus=synced_gpus,
+            assistant_model=assistant_model,
+            streamer=streamer,
             **kwargs,
         )
 
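Forwarding the arguments to super().generate() by keyword, and passing the new assistant_model through, matters because newer transformers releases place assistant_model ahead of streamer in GenerationMixin.generate's parameter list; with the old positional call, streamer would silently bind to the wrong parameter. A simplified, self-contained sketch of that hazard (plain Python, not the transformers API):

# Hypothetical parent: an older signature was (inputs, generation_config=None, streamer=None);
# the newer one inserts assistant_model just before streamer.
class ParentV2:
    def generate(self, inputs, generation_config=None, assistant_model=None, streamer=None):
        return {"config": generation_config, "assistant": assistant_model, "streamer": streamer}

config, streamer = "cfg", "my-streamer"

# Positional forwarding written against the old order misbinds under the new signature:
print(ParentV2().generate("ids", config, streamer))
# {'config': 'cfg', 'assistant': 'my-streamer', 'streamer': None}

# Keyword forwarding stays correct no matter where new parameters are inserted:
print(ParentV2().generate("ids", generation_config=config, streamer=streamer))
# {'config': 'cfg', 'assistant': None, 'streamer': 'my-streamer'}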