Spaces:
Running
on
Zero
Running
on
Zero
Update utils/model.py
Browse files- utils/model.py +3 -1
utils/model.py
CHANGED
@@ -424,7 +424,9 @@ class OwlViTForClassification(nn.Module):
|
|
424 |
txt_embeds = self.owlvit.text_model.embeddings.token_embedding(text_inputs_parts['input_ids'])
|
425 |
print(f"position_embedding: {self.owlvit.text_model.embeddings.position_embedding(position_ids).shape}")
|
426 |
print(f"text_embeds: {txt_embeds.shape}")
|
427 |
-
|
|
|
|
|
428 |
text_inputs_parts["position_ids"] = position_ids
|
429 |
position_ids = position_ids.repeat(1, 1, txt_embeds.size(-1) // position_ids.size(-1))
|
430 |
print(f"pos + emb: {(txt_embeds + position_ids).shape}")
|
|
|
424 |
txt_embeds = self.owlvit.text_model.embeddings.token_embedding(text_inputs_parts['input_ids'])
|
425 |
print(f"position_embedding: {self.owlvit.text_model.embeddings.position_embedding(position_ids).shape}")
|
426 |
print(f"text_embeds: {txt_embeds.shape}")
|
427 |
+
|
428 |
+
device_ = txt_embeds.device
|
429 |
+
position_ids = position_ids.to(device_)
|
430 |
text_inputs_parts["position_ids"] = position_ids
|
431 |
position_ids = position_ids.repeat(1, 1, txt_embeds.size(-1) // position_ids.size(-1))
|
432 |
print(f"pos + emb: {(txt_embeds + position_ids).shape}")
|