Bug in modeling_gemma.py in transformers 4.38.0
Traceback (most recent call last):
File "/mnt/bn/motor-nlp-team/mlx/users/zhangkaiqi.zlkqz/repo/5355/Personal_repo/test.py", line 15, in
outputs = model.generate(**input_ids, max_new_tokens=1024)
File "/usr/local/lib/python3.9/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/tiger/.local/lib/python3.9/site-packages/transformers/generation/utils.py", line 1544, in generate
return self.greedy_search(
File "/home/tiger/.local/lib/python3.9/site-packages/transformers/generation/utils.py", line 2404, in greedy_search
outputs = self(
File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/gemma/modeling_gemma.py", line 1068, in forward
outputs = self.model(
File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/gemma/modeling_gemma.py", line 906, in forward
layer_outputs = decoder_layer(
File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/gemma/modeling_gemma.py", line 626, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/gemma/modeling_gemma.py", line 280, in forward
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
RuntimeError: shape '[1, 13, 3072]' is invalid for input of size 53248
In modeling_gemma.py in transformers 4.38.0, line 280 reads:

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

It should be changed to:

attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

In Gemma, num_heads * head_dim does not have to equal hidden_size. In the run above, the tensor has 53248 elements for 13 tokens, i.e. 53248 / 13 = 4096 features per token (16 heads * head_dim 256), while hidden_size is 3072, so the reshape to [1, 13, 3072] fails. The next line is:

attn_output = self.o_proj(attn_output)

and o_proj is a Linear(self.num_heads * self.head_dim, self.hidden_size) projection, so its input must have num_heads * head_dim features, not hidden_size. With the corrected reshape, the projection maps the tensor back to hidden_size as intended.
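The shape mismatch can be reproduced in isolation with a minimal sketch, using the dimensions implied by the traceback (bsz=1, q_len=13, and assumed Gemma-7B values num_heads=16, head_dim=256, hidden_size=3072):

```python
import torch

# Assumed Gemma-7B dimensions: 16 heads of dim 256, hidden_size 3072,
# so num_heads * head_dim (4096) != hidden_size (3072).
bsz, q_len, num_heads, head_dim, hidden_size = 1, 13, 16, 256, 3072

# attn_output as it arrives at line 280: (bsz, q_len, num_heads, head_dim)
attn_output = torch.randn(bsz, q_len, num_heads, head_dim)

# Buggy reshape: 13 * 16 * 256 = 53248 elements cannot form [1, 13, 3072].
try:
    attn_output.reshape(bsz, q_len, hidden_size)
except RuntimeError as e:
    print("buggy reshape fails:", e)

# Fixed reshape matches o_proj's input features (num_heads * head_dim),
# and o_proj projects back down to hidden_size.
o_proj = torch.nn.Linear(num_heads * head_dim, hidden_size, bias=False)
out = o_proj(attn_output.reshape(bsz, q_len, num_heads * head_dim))
print(out.shape)  # torch.Size([1, 13, 3072])
```

The fix works for any model where the per-head dimensions are decoupled from hidden_size, which is why reshaping to num_heads * head_dim is the safer general form.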