Use bool directly instead of torch.float16 to avoid a crash on ASICs such as HPU that do not support float16
72fc4ea
sywangyi
committed on
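
The diff itself is not shown here, so the following is only a minimal sketch of the kind of change the commit message describes, assuming the original code allocated a tensor as `torch.float16` and then cast it to bool. The helper name `make_mask` and its parameters are hypothetical and do not come from the commit.

```python
import torch


def make_mask(seq_len: int, device: str = "cpu") -> torch.Tensor:
    """Hypothetical helper: build an all-True mask of shape (1, seq_len)."""
    # Assumed earlier pattern (not confirmed by the source):
    #   torch.ones((1, seq_len), dtype=torch.float16, device=device).bool()
    # Allocating as bool directly never creates a float16 tensor, so the code
    # does not depend on the device supporting float16.
    return torch.ones((1, seq_len), dtype=torch.bool, device=device)


mask = make_mask(4)
```

Allocating the tensor in its final dtype also skips an intermediate allocation and a cast, which is a small side benefit on any device.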