fix?: necessarily use CUDA
Browse files
app.py
CHANGED
@@ -15,11 +15,11 @@ from src.plot_utils import export_mask
|
|
15 |
|
16 |
@spaces.GPU()
|
17 |
def predict(model_choice, annotations: Dict[str, Any]):
|
18 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
sam2_model = load_model(
|
20 |
variant=model_choice,
|
21 |
ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
|
22 |
-
device=
|
23 |
)
|
24 |
predictor = SAM2ImagePredictor(sam2_model) # type:ignore
|
25 |
predictor.set_image(annotations["image"])
|
|
|
15 |
|
16 |
@spaces.GPU()
|
17 |
def predict(model_choice, annotations: Dict[str, Any]):
|
18 |
+
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
sam2_model = load_model(
|
20 |
variant=model_choice,
|
21 |
ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
|
22 |
+
device="cuda",
|
23 |
)
|
24 |
predictor = SAM2ImagePredictor(sam2_model) # type:ignore
|
25 |
predictor.set_image(annotations["image"])
|