model.generate cannot handle past_key_values correctly
#13 · opened by Zhuangl
Here is my script:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextStreamer,
)

model_name_or_path = './glm-4-9b-chat'
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).to(device).eval()

streamer = TextStreamer(tokenizer, skip_prompt=True, decode_kwargs=dict(skip_special_tokens=True))
gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1, "streamer": streamer, "return_dict_in_generate": True}

past_key_values = None
inputs = None
system_message = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
history = [{"role": "system", "content": system_message}]
past_key_values = None

while True:
    query = input("Human: ")
    if len(query.strip()) == 0:
        # empty input resets the conversation
        history = [{"role": "system", "content": system_message}]
        continue
    history.append({
        "role": "user",
        "content": query
    })
    inputs = tokenizer.apply_chat_template(
        history,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    )
    inputs = inputs.to(device)
    print(inputs['input_ids'].shape)
    with torch.no_grad():
        print("Assistant:")
        # reuse the KV cache returned by the previous round
        outputs = model.generate(**inputs, **gen_kwargs, past_key_values=past_key_values)
    past_key_values = outputs['past_key_values']
    outputs = outputs['sequences'][:, inputs['input_ids'].shape[1]:]
    outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
    history.append({
        "role": "assistant",
        "content": outputs
    })
The first round of chat works fine, but the second round (where past_key_values is passed into model.generate) fails with this error:
Traceback (most recent call last):
File "/home/ubuntu/llm/generation.txt", line 52, in <module>
outputs = model.generate(**inputs, **gen_kwargs, past_key_values=past_key_values)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 1914, in generate
result = self._sample(
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 2651, in _sample
outputs = self(
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 997, in forward
transformer_outputs = self.transformer(
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 882, in forward
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 783, in get_masks
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
RuntimeError: The size of tensor a (130) must match the size of tensor b (69) at non-singleton dimension 2
I tested this script on a Llama model and it works well, so maybe you could share some suggestions based on the ChatGLM modeling implementation.
past_key_values = outputs['past_key_values'] no longer works. transformers 4.42 changed how the cache is handled, so it is not written this way anymore. Please take a look at our latest model implementation, where this part has been replaced.
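For reference, here is a minimal sketch of what the transformers >= 4.42 pattern can look like, assuming a version of modeling_chatglm.py that supports the new Cache API; it reuses model, tokenizer, gen_kwargs, device and system_message from the script above and is not the exact code from the repo:

from transformers.cache_utils import DynamicCache

# one Cache object, created once and grown in place across rounds
past_key_values = DynamicCache()
history = [{"role": "system", "content": system_message}]

while True:
    query = input("Human: ")
    history.append({"role": "user", "content": query})
    inputs = tokenizer.apply_chat_template(
        history,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    ).to(device)
    with torch.no_grad():
        print("Assistant:")
        # pass the same Cache object every round; the model updates it in place,
        # so there is no need to read outputs['past_key_values'] back out
        outputs = model.generate(**inputs, **gen_kwargs, past_key_values=past_key_values)
    response = tokenizer.decode(
        outputs['sequences'][0, inputs['input_ids'].shape[1]:],
        skip_special_tokens=True
    )
    history.append({"role": "assistant", "content": response})

The key change is that the cache is now a Cache object (e.g. DynamicCache) that is updated in place, rather than a tuple of tensors that you read back from the generate output each round.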
@zRzRzRzRzRzRzR
Thanks for the reply.
I am using transformers==4.42.3, and it runs fine on Llama 3. Also, could you provide a link to the latest model implementation you mentioned? I am still using the modeling_chatglm.py from this repo.
That version is fine, and the weights have not changed. Just use the latest version of this repo (you need to pull the other configuration files), then try again.
The line full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) itself should be fine.
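If it helps, one way to refresh only the code and config files in an existing local copy is sketched below, assuming the files come from the THUDM/glm-4-9b-chat repo on the Hugging Face Hub (adjust repo_id and local_dir to your setup):

from huggingface_hub import snapshot_download

# fetch only the python/config files; the existing weight files are left untouched
snapshot_download(
    repo_id="THUDM/glm-4-9b-chat",
    local_dir="./glm-4-9b-chat",
    allow_patterns=["*.py", "*.json"],
)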