BenkHel committed on
Commit
6014840
·
verified ·
1 Parent(s): 788a968

Update cumo/model/multimodal_encoder/clip_encoder.py

Browse files
cumo/model/multimodal_encoder/clip_encoder.py CHANGED
@@ -84,9 +84,12 @@ class CLIPVisionTower(nn.Module):
84
  if type(images) is list:
85
  image_features = []
86
  for image in images:
87
- image_forward_out = self.vision_model(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
 
 
88
  image_feature = self.feature_select(image_forward_out).to(image.dtype)
89
  image_features.append(image_feature)
 
90
  else:
91
  input_size = images.shape[3]
92
  img_sizes = [int(input_size * scale) for scale in self.scales]
 
84
  if type(images) is list:
85
  image_features = []
86
  for image in images:
87
+ dev = image.device if hasattr(image, "device") else torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
+ dt = image.dtype if hasattr(image, "dtype") else torch.float16
89
+ image_forward_out = self.vision_model(image.to(device=dev, dtype=dt).unsqueeze(0), output_hidden_states=True)
90
  image_feature = self.feature_select(image_forward_out).to(image.dtype)
91
  image_features.append(image_feature)
92
+
93
  else:
94
  input_size = images.shape[3]
95
  img_sizes = [int(input_size * scale) for scale in self.scales]