Error when running Qwen/Qwen-14B-Chat with int4 quantization

#2
by Trenx - opened

Issue: after int4 quantization, key.dtype is float16 but query.dtype is still float32, so the dot product of query and key raises a dtype error.
Suggestion: at the first line of the _attn function (line 258), check whether query and key share the same dtype and, if not, cast them to a common dtype.
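
A minimal sketch of that check (the standalone helper attn_scores_with_dtype_check is hypothetical; in modeling_qwen.py the same two lines would sit at the top of _attn):

```python
import torch

def attn_scores_with_dtype_check(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    # If 4-bit loading left query in float32 while key is float16, cast query
    # to key's dtype so torch.matmul sees matching scalar types.
    if query.dtype != key.dtype:
        query = query.to(key.dtype)
    return torch.matmul(query, key.transpose(-1, -2))

# Reproduce the mismatch from the traceback below: float32 query, float16 key.
query = torch.randn(1, 2, 4, 8, dtype=torch.float32)
key = torch.randn(1, 2, 4, 8, dtype=torch.float16)
scores = attn_scores_with_dtype_check(query, key)  # no "expected scalar type Half" error
```

Casting query down to key's dtype, rather than key up to float32, keeps the attention matmul in half precision, which is the point of quantized inference.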

+1. Loading in 4-bit raises the following error:

  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 1109, in forward
    transformer_outputs = self.transformer(
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 930, in forward
    outputs = block(
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 631, in forward
    attn_outputs = self.attn(
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/wz/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 556, in forward
    attn_output, attn_weight = self._attn(
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/qwen-chat-pytorch-14b/modeling_qwen.py", line 314, in _attn
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
RuntimeError: [address=127.0.0.1:36263, pid=21908] expected scalar type Half but found Float
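
For context, the load that triggers this looks roughly as follows (a sketch, not the exact script; load_in_4bit is the standard transformers/bitsandbytes flag):

```python
from transformers import AutoModelForCausalLM

# Sketch of the failing 4-bit load: load_in_4bit routes linear layers through
# bitsandbytes NF4, while some activations can remain in float32.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-14B-Chat",
    device_map="auto",
    trust_remote_code=True,
    load_in_4bit=True,
)
```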

How about using Qwen-14B-Int4? It performs better than the BNB (bitsandbytes) path. See the quantization section in our GitHub README for more information.
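
Loading the GPTQ checkpoint follows the usual pattern (a sketch along the lines of the README; it assumes the published Qwen/Qwen-14B-Chat-Int4 repo plus the auto-gptq and optimum packages):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen-14B-Chat-Int4", trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-14B-Chat-Int4",
    device_map="auto",
    trust_remote_code=True,
).eval()
# chat() is Qwen's remote-code helper for multi-turn generation.
response, history = model.chat(tokenizer, "Hello", history=None)
```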

jklj077 changed discussion status to closed