It seems this project only supports a batch_size of 1 during inference?
#1 opened by howard-hou
I tried to make the input batch_size = 2 with

```python
inputs = tokenizer([prompt, prompt], return_tensors="pt")
output = model.generate(inputs["input_ids"], max_new_tokens=256)
```

and it raises a runtime error:
```
    175 # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
--> 176 key = self.key(key).to(torch.float32).view(T, H, S).transpose(0, 1).transpose(-2, -1)
    177 value = self.value(value).to(torch.float32).view(T, H, S).transpose(0, 1)
    178 receptance = self.receptance(receptance).to(torch.float32).view(T, H, S).transpose(0, 1)

RuntimeError: shape '[47, 32, 64]' is invalid for input of size 192512
```
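The numbers are consistent with the batch dimension being dropped: with T = 47, H = 32, S = 64, a single sequence has 47 × 32 × 64 = 96256 elements, and the failing tensor has exactly twice that (192512) because of the batch of 2. The `view(T, H, S)` reshape only works when batch_size == 1. As a workaround until batched inference is supported, here is a minimal sketch that generates one prompt at a time; the repo id and the `trust_remote_code=True` flag are assumptions for illustration, substitute whatever this project's README uses:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hypothetical repo id; replace with the actual checkpoint for this project.
model_id = "RWKV/rwkv-5-world-169m"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

prompts = ["first prompt", "second prompt"]

# Workaround: the attention projections are reshaped with view(T, H, S),
# which assumes batch_size == 1, so run generation per prompt instead of batching.
outputs = []
for prompt in prompts:
    inputs = tokenizer(prompt, return_tensors="pt")
    out = model.generate(inputs["input_ids"], max_new_tokens=256)
    outputs.append(tokenizer.decode(out[0], skip_special_tokens=True))

print(outputs)
```

This trades throughput for correctness; a real fix would need the modeling code to reshape with the batch dimension included (e.g. `view(B, T, H, S)`) and carry it through the rest of the attention computation.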
Same problem here.