Bug in modeling_gemma.py in transformers 4.38.0

#46
by zlk - opened

Traceback (most recent call last):
  File "/mnt/bn/motor-nlp-team/mlx/users/zhangkaiqi.zlkqz/repo/5355/Personal_repo/test.py", line 15, in <module>
    outputs = model.generate(**input_ids, max_new_tokens=1024)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/generation/utils.py", line 1544, in generate
    return self.greedy_search(
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/generation/utils.py", line 2404, in greedy_search
    outputs = self(
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/gemma/modeling_gemma.py", line 1068, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/gemma/modeling_gemma.py", line 906, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/gemma/modeling_gemma.py", line 626, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/gemma/modeling_gemma.py", line 280, in forward
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
RuntimeError: shape '[1, 13, 3072]' is invalid for input of size 53248
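
For reference, a minimal script of the shape that produces this traceback might look like the sketch below. The model id and prompt are assumptions on my part; the traceback only shows model.generate(**input_ids, max_new_tokens=1024) being called from test.py.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-7b-it"  # assumption: a Gemma-7B checkpoint (hidden_size=3072, num_heads*head_dim=4096)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

input_ids = tokenizer("Write me a poem about Machine Learning.", return_tensors="pt")
outputs = model.generate(**input_ids, max_new_tokens=1024)  # raises the RuntimeError above on transformers 4.38.0
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```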

Solution:

In modeling_gemma.py in transformers 4.38.0, line 280, change the source code from:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
to:
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

The reshape to self.hidden_size fails whenever num_heads * head_dim != hidden_size, which is the case for Gemma-7B (16 * 256 = 4096 vs. 3072, matching the error above). The next line is:
attn_output = self.o_proj(attn_output)
and o_proj is defined as Linear(self.num_heads * self.head_dim, self.hidden_size), so the tensor fed into it must have num_heads * head_dim features, not hidden_size.
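
A quick way to see the arithmetic: with bsz = 1 and q_len = 13, the attention output holds 1 * 13 * 4096 = 53248 elements, which cannot be viewed as [1, 13, 3072]. The sketch below reproduces this with dummy tensors; the dimensions are the Gemma-7B config values, assumed here rather than quoted from this thread.

```python
import torch

# Gemma-7B attention dimensions (assumed from the model config, consistent with the error above)
bsz, q_len = 1, 13
num_heads, head_dim, hidden_size = 16, 256, 3072  # note: num_heads * head_dim = 4096 != hidden_size

# Dummy attention output with the same layout as in GemmaAttention.forward
attn_output = torch.zeros(bsz, num_heads, q_len, head_dim).transpose(1, 2).contiguous()
print(attn_output.numel())  # 53248, matching "input of size 53248" in the traceback

fixed = attn_output.reshape(bsz, q_len, num_heads * head_dim)  # works: shape [1, 13, 4096]
# attn_output.reshape(bsz, q_len, hidden_size)                 # RuntimeError: shape '[1, 13, 3072]' is invalid

# o_proj maps num_heads * head_dim -> hidden_size, so the corrected reshape feeds it the right width
o_proj = torch.nn.Linear(num_heads * head_dim, hidden_size, bias=False)
print(o_proj(fixed).shape)  # torch.Size([1, 13, 3072])
```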

hi @zlk
Thanks for the issue! Can you try upgrading transformers to 4.38.1? We made a patch release that includes the fix.
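
If it helps, a quick guard like this sketch (the error message wording is mine) can catch a stale install before running generation:

```python
# Sketch: fail fast if the installed transformers still has the 4.38.0 reshape bug
import transformers
from packaging import version

if version.parse(transformers.__version__) < version.parse("4.38.1"):
    raise RuntimeError(
        f"transformers {transformers.__version__} is affected by the Gemma reshape bug; "
        "please upgrade to 4.38.1 or later"
    )
```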

OK, thanks! But your article at https://huggingface.co/blog/gemma recommends 4.38.0, so you may also want to update the article.

Thanks @zlk, good point! I made https://github.com/huggingface/blog/pull/1851, which should adapt the blog post accordingly.

The post has now been updated, thank you @zlk and @ybelkada!

pcuenq changed discussion status to closed
