alexlimh commited on
Commit
b10a706
·
verified ·
1 Parent(s): 2466c58

Update modeling_jasper_vl.py

Browse files
Files changed (1) hide show
  1. modeling_jasper_vl.py +1 -1
modeling_jasper_vl.py CHANGED
@@ -1168,7 +1168,7 @@ class JasperVL(PreTrainedModel):
1168
  vit_embeds = self.vision_model(pixel_values=pixel_values, return_dict=True)["last_hidden_state"]
1169
  vit_embeds = self.adaptive_avg_pool2d(vit_embeds)
1170
  selected = (input_ids.reshape(B * N) == self.config.img_token_id)
1171
- vit_embeds = vit_embeds.to(dtype=inputs_embeds)
1172
  inputs_embeds[selected] = vit_embeds.reshape(-1, C)
1173
  inputs_embeds = inputs_embeds.reshape(B, N, C)
1174
  last_hidden_state = self.model(
 
1168
  vit_embeds = self.vision_model(pixel_values=pixel_values, return_dict=True)["last_hidden_state"]
1169
  vit_embeds = self.adaptive_avg_pool2d(vit_embeds)
1170
  selected = (input_ids.reshape(B * N) == self.config.img_token_id)
1171
+ vit_embeds = vit_embeds.to(dtype=inputs_embeds.dtype)
1172
  inputs_embeds[selected] = vit_embeds.reshape(-1, C)
1173
  inputs_embeds = inputs_embeds.reshape(B, N, C)
1174
  last_hidden_state = self.model(