Update modeling_jasper_vl.py
Browse files- modeling_jasper_vl.py +1 -1
modeling_jasper_vl.py
CHANGED
@@ -1168,7 +1168,7 @@ class JasperVL(PreTrainedModel):
|
|
1168 |
vit_embeds = self.vision_model(pixel_values=pixel_values, return_dict=True)["last_hidden_state"]
|
1169 |
vit_embeds = self.adaptive_avg_pool2d(vit_embeds)
|
1170 |
selected = (input_ids.reshape(B * N) == self.config.img_token_id)
|
1171 |
-
vit_embeds = vit_embeds.to(dtype=inputs_embeds)
|
1172 |
inputs_embeds[selected] = vit_embeds.reshape(-1, C)
|
1173 |
inputs_embeds = inputs_embeds.reshape(B, N, C)
|
1174 |
last_hidden_state = self.model(
|
|
|
1168 |
vit_embeds = self.vision_model(pixel_values=pixel_values, return_dict=True)["last_hidden_state"]
|
1169 |
vit_embeds = self.adaptive_avg_pool2d(vit_embeds)
|
1170 |
selected = (input_ids.reshape(B * N) == self.config.img_token_id)
|
1171 |
+
vit_embeds = vit_embeds.to(dtype=inputs_embeds.dtype)
|
1172 |
inputs_embeds[selected] = vit_embeds.reshape(-1, C)
|
1173 |
inputs_embeds = inputs_embeds.reshape(B, N, C)
|
1174 |
last_hidden_state = self.model(
|