yunyangx commited on
Commit
15b7466
·
verified ·
1 Parent(s): e868605

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -257,7 +257,10 @@ def load_model(checkpoint):
257
  efficienttam_checkpoint = "./checkpoints/demo/efficienttam_s.pt"
258
  model_cfg = "efficienttam_s.yaml"
259
  return [efficienttam_checkpoint, model_cfg]
260
-
 
 
 
261
  def get_mask_sam_process(
262
  stored_inference_state,
263
  input_first_frame_image,
@@ -349,6 +352,9 @@ def get_mask_sam_process(
349
 
350
  return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
351
 
 
 
 
352
  def propagate_to_all(tracking_points, video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame):
353
  if tracking_points is None or video_in is None or checkpoint is None or stored_inference_state is None:
354
  return gr.update(value=None), gr.update(value=None), gr.update(value=None), available_frames_to_check, gr.update(visible=False)
 
257
  efficienttam_checkpoint = "./checkpoints/demo/efficienttam_s.pt"
258
  model_cfg = "efficienttam_s.yaml"
259
  return [efficienttam_checkpoint, model_cfg]
260
+
261
+ @spaces.GPU
262
+ @torch.inference_mode()
263
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
264
  def get_mask_sam_process(
265
  stored_inference_state,
266
  input_first_frame_image,
 
352
 
353
  return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
354
 
355
+ @spaces.GPU
356
+ @torch.inference_mode()
357
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
358
  def propagate_to_all(tracking_points, video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame):
359
  if tracking_points is None or video_in is None or checkpoint is None or stored_inference_state is None:
360
  return gr.update(value=None), gr.update(value=None), gr.update(value=None), available_frames_to_check, gr.update(visible=False)