JeffLiang commited on
Commit
e9b7645
1 Parent(s): adca407
Files changed (1) hide show
  1. open_vocab_seg/utils/predictor.py +2 -2
open_vocab_seg/utils/predictor.py CHANGED
@@ -191,7 +191,7 @@ class SAMVisualizationDemo(object):
191
 
192
  with torch.no_grad(), torch.cuda.amp.autocast():
193
  image_features = self.clip_model.encode_image(imgs.cuda().half())
194
- text_features = self.clip_model.encode_text(text.cuda().half())
195
  image_features /= image_features.norm(dim=-1, keepdim=True)
196
  text_features /= text_features.norm(dim=-1, keepdim=True)
197
 
@@ -210,7 +210,7 @@ class SAMVisualizationDemo(object):
210
  select_mask.extend(locs[0].tolist())
211
  for idx in select_mask:
212
  select_cls[idx] = class_preds[idx]
213
- semseg = torch.einsum("qc,qhw->chw", select_cls, pred_masks.tensor.float())
214
 
215
  r = semseg
216
  blank_area = (r[0] == 0)
 
191
 
192
  with torch.no_grad(), torch.cuda.amp.autocast():
193
  image_features = self.clip_model.encode_image(imgs.cuda().half())
194
+ text_features = self.clip_model.encode_text(text.cuda())
195
  image_features /= image_features.norm(dim=-1, keepdim=True)
196
  text_features /= text_features.norm(dim=-1, keepdim=True)
197
 
 
210
  select_mask.extend(locs[0].tolist())
211
  for idx in select_mask:
212
  select_cls[idx] = class_preds[idx]
213
+ semseg = torch.einsum("qc,qhw->chw", select_cls.float(), pred_masks.tensor.float().cuda())
214
 
215
  r = semseg
216
  blank_area = (r[0] == 0)