BenkHel commited on
Commit
b2ea8ea
·
verified ·
1 Parent(s): 98866e7

Update cumo/model/multimodal_encoder/clip_encoder.py

Browse files
cumo/model/multimodal_encoder/clip_encoder.py CHANGED
@@ -48,7 +48,9 @@ class CLIPVisionTower(nn.Module):
48
  self.is_loaded = True
49
 
50
  def feature_select(self, image_features):
51
- #image_features = image_forward_outs.hidden_states[self.select_layer]
 
 
52
  if self.select_feature == 'patch':
53
  image_features = image_features[:, 1:]
54
  elif self.select_feature == 'cls_patch':
@@ -57,6 +59,7 @@ class CLIPVisionTower(nn.Module):
57
  raise ValueError(f'Unexpected select feature: {self.select_feature}')
58
  return image_features
59
 
 
60
  def split_chessboard(self, x, num_split):
61
  """
62
  x: b * c * h * w
 
48
  self.is_loaded = True
49
 
50
  def feature_select(self, image_features):
51
+ # Take first element if output is a tuple
52
+ if isinstance(image_features, tuple):
53
+ image_features = image_features[0]
54
  if self.select_feature == 'patch':
55
  image_features = image_features[:, 1:]
56
  elif self.select_feature == 'cls_patch':
 
59
  raise ValueError(f'Unexpected select feature: {self.select_feature}')
60
  return image_features
61
 
62
+
63
  def split_chessboard(self, x, num_split):
64
  """
65
  x: b * c * h * w