Mar2Ding commited on
Commit
0ad1803
·
verified ·
1 Parent(s): f43ee2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -241,9 +241,11 @@ def get_mask_sam_process(
241
 
242
  # set predictor
243
  if torch.cuda.is_available():
244
- predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
 
245
  else:
246
- predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')
 
247
 
248
  print("PREDICTOR READY")
249
 
@@ -326,12 +328,15 @@ def get_mask_sam_process(
326
  def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
327
  #### PROPAGATION ####
328
  sam2_checkpoint, model_cfg = load_model(checkpoint)
 
329
  if torch.cuda.is_available():
 
330
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
331
  else:
 
332
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')
333
-
334
 
 
335
  inference_state = stored_inference_state
336
  frame_names = stored_frame_names
337
  video_dir = video_frames_dir
 
241
 
242
  # set predictor
243
  if torch.cuda.is_available():
244
+ inference_state["device"] = 'cuda'
245
+ # predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
246
  else:
247
+ inference_state["device"] = 'cpu'
248
+ predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')
249
 
250
  print("PREDICTOR READY")
251
 
 
328
  def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
329
  #### PROPAGATION ####
330
  sam2_checkpoint, model_cfg = load_model(checkpoint)
331
+ # set predictor
332
  if torch.cuda.is_available():
333
+ inference_state["device"] = 'cuda'
334
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
335
  else:
336
+ inference_state["device"] = 'cpu'
337
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')
 
338
 
339
+
340
  inference_state = stored_inference_state
341
  frame_names = stored_frame_names
342
  video_dir = video_frames_dir