BenkHel commited on
Commit
6f484e9
·
verified ·
1 Parent(s): b2ea8ea

Update cumo/model/multimodal_encoder/clip_encoder.py

Browse files
cumo/model/multimodal_encoder/clip_encoder.py CHANGED
@@ -48,9 +48,7 @@ class CLIPVisionTower(nn.Module):
48
  self.is_loaded = True
49
 
50
  def feature_select(self, image_features):
51
- # Take first element if output is a tuple
52
- if isinstance(image_features, tuple):
53
- image_features = image_features[0]
54
  if self.select_feature == 'patch':
55
  image_features = image_features[:, 1:]
56
  elif self.select_feature == 'cls_patch':
@@ -59,7 +57,6 @@ class CLIPVisionTower(nn.Module):
59
  raise ValueError(f'Unexpected select feature: {self.select_feature}')
60
  return image_features
61
 
62
-
63
  def split_chessboard(self, x, num_split):
64
  """
65
  x: b * c * h * w
@@ -87,13 +84,9 @@ class CLIPVisionTower(nn.Module):
87
  if type(images) is list:
88
  image_features = []
89
  for image in images:
90
- dev = image.device if hasattr(image, "device") else torch.device("cuda" if torch.cuda.is_available() else "cpu")
91
- dt = image.dtype if hasattr(image, "dtype") else torch.float16
92
- print("Image shape before vision_model:", image.shape)
93
- image_forward_out = self.vision_model(image.to(device=dev, dtype=dt))
94
  image_feature = self.feature_select(image_forward_out).to(image.dtype)
95
  image_features.append(image_feature)
96
-
97
  else:
98
  input_size = images.shape[3]
99
  img_sizes = [int(input_size * scale) for scale in self.scales]
 
48
  self.is_loaded = True
49
 
50
  def feature_select(self, image_features):
51
+ #image_features = image_forward_outs.hidden_states[self.select_layer]
 
 
52
  if self.select_feature == 'patch':
53
  image_features = image_features[:, 1:]
54
  elif self.select_feature == 'cls_patch':
 
57
  raise ValueError(f'Unexpected select feature: {self.select_feature}')
58
  return image_features
59
 
 
60
  def split_chessboard(self, x, num_split):
61
  """
62
  x: b * c * h * w
 
84
  if type(images) is list:
85
  image_features = []
86
  for image in images:
87
+ image_forward_out = self.vision_model(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
 
 
 
88
  image_feature = self.feature_select(image_forward_out).to(image.dtype)
89
  image_features.append(image_feature)
 
90
  else:
91
  input_size = images.shape[3]
92
  img_sizes = [int(input_size * scale) for scale in self.scales]