SauravMaheshkar commited on
Commit
3dcca3c
·
unverified ·
1 Parent(s): 3ac5059

fix?: necessarily use CUDA

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -15,11 +15,11 @@ from src.plot_utils import export_mask
15
 
16
  @spaces.GPU()
17
  def predict(model_choice, annotations: Dict[str, Any]):
18
- device = "cuda" if torch.cuda.is_available() else "cpu"
19
  sam2_model = load_model(
20
  variant=model_choice,
21
  ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
22
- device=device,
23
  )
24
  predictor = SAM2ImagePredictor(sam2_model) # type:ignore
25
  predictor.set_image(annotations["image"])
 
15
 
16
  @spaces.GPU()
17
  def predict(model_choice, annotations: Dict[str, Any]):
18
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
19
  sam2_model = load_model(
20
  variant=model_choice,
21
  ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
22
+ device="cuda",
23
  )
24
  predictor = SAM2ImagePredictor(sam2_model) # type:ignore
25
  predictor.set_image(annotations["image"])