Update modeling_kosmos2.py
Browse files- modeling_kosmos2.py +1 -1
modeling_kosmos2.py
CHANGED
@@ -1007,7 +1007,7 @@ class Kosmos2TextTransformer(nn.Module):
|
|
1007 |
inputs_embeds = self.embed_tokens(input_ids)
|
1008 |
|
1009 |
if img_features is not None:
|
1010 |
-
inputs_embeds[img_input_mask.to(dtype=torch.bool)] = img_features
|
1011 |
|
1012 |
inputs_embeds = inputs_embeds * self.embed_scale
|
1013 |
|
|
|
1007 |
inputs_embeds = self.embed_tokens(input_ids)
|
1008 |
|
1009 |
if img_features is not None:
|
1010 |
+
inputs_embeds[img_input_mask.to(dtype=torch.bool)] = img_features.view(-1, img_features.size(-1))
|
1011 |
|
1012 |
inputs_embeds = inputs_embeds * self.embed_scale
|
1013 |
|