yuxin commited on
Commit
e46628a
·
1 Parent(s): a3e2e90
Files changed (1) hide show
  1. model_segvol_single.py +1 -1
model_segvol_single.py CHANGED
@@ -101,7 +101,7 @@ class SegVolModel(PreTrainedModel):
101
  ## inference
102
  with torch.no_grad():
103
  logits_single_cropped = sliding_window_inference(
104
- image_single_cropped.cuda(), prompt_reflection,
105
  self.config.spatial_size, 1, self.model, 0.5,
106
  text=text_prompt,
107
  use_box=bbox_prompt is not None,
 
101
  ## inference
102
  with torch.no_grad():
103
  logits_single_cropped = sliding_window_inference(
104
+ image_single_cropped.to(self.custom_device), prompt_reflection,
105
  self.config.spatial_size, 1, self.model, 0.5,
106
  text=text_prompt,
107
  use_box=bbox_prompt is not None,