howard-hou commited on
Commit
45cab51
·
1 Parent(s): cbf04ef

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +1 -2
modeling.py CHANGED
@@ -59,8 +59,7 @@ class EmbeddingMixer(nn.Module):
59
  self.image_start_index = len(original_embedding)
60
 
61
  def set_image_embeddings(self, image_embeddings):
62
- if len(image_embeddings.shape) == 3:
63
- image_embeddings = image_embeddings.squeeze(0) # remove batch dim
64
  end_index = self.image_start_index + image_embeddings.shape[0]
65
  self.embedding[self.image_start_index:end_index] = image_embeddings
66
 
 
59
  self.image_start_index = len(original_embedding)
60
 
61
  def set_image_embeddings(self, image_embeddings):
62
+ assert len(image_embeddings.shape) == 2, "image_embeddings should be 2D"
 
63
  end_index = self.image_start_index + image_embeddings.shape[0]
64
  self.embedding[self.image_start_index:end_index] = image_embeddings
65