Commit 612947c (verified) · Parent(s): 7456938
Mar2Ding committed: Update app.py

Files changed (1): app.py (+23 −36)

app.py CHANGED
@@ -55,8 +55,8 @@ def clear_points(image):
     # we clean all
     return [
         image, # first_frame_path
-        [], # tracking_points
-        [], # trackings_input_label
+        gr.State([]), # tracking_points
+        gr.State([]), # trackings_input_label
         image, # points_map
         #gr.State() # stored_inference_state
     ]
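The handlers touched by this commit reset their state holders by returning fresh values for the gr.State outputs: Gradio writes whatever a handler returns for a State output back into that State. A minimal sketch of the reset pattern, with hypothetical names (points_store, clear_btn):

import gradio as gr

# Minimal sketch (hypothetical names) of the reset pattern clear_points uses:
# a gr.State listed as an output is overwritten with whatever the handler
# returns in that position.
with gr.Blocks() as demo:
    points_store = gr.State([])      # per-session list of clicked points
    clear_btn = gr.Button("Clear points")

    def clear_points_sketch():
        return []                    # new value written back into points_store

    clear_btn.click(clear_points_sketch, inputs=None, outputs=[points_store])

demo.launch()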
@@ -119,8 +119,8 @@ def preprocess_video_in(video_path):
 
     return [
         first_frame, # first_frame_path
-        gr.State([]), # tracking_points
-        gr.State([]),
+        [], # tracking_points
+        [], # trackings_input_label
         first_frame, # input_first_frame_image
         first_frame, # points_map
         extracted_frames_output_dir, # video_frames_dir
@@ -130,7 +130,6 @@ def preprocess_video_in(video_path):
         gr.update(open=False) # video_in_drawer
     ]
 
-@spaces.GPU(duration=120)
 def get_point(point_type, tracking_points, trackings_input_label, input_first_frame_image, evt: gr.SelectData):
     print(f"You selected {evt.value} at {evt.index} from {evt.target}")
 
@@ -166,13 +165,13 @@ def get_point(point_type, tracking_points, trackings_input_label, input_first_frame_image, evt: gr.SelectData):
 
     return tracking_points, trackings_input_label, selected_point_map
 
-# # use bfloat16 for the entire notebook
-# torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+# use bfloat16 for the entire notebook
+torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
 
-# if torch.cuda.get_device_properties(0).major >= 8:
-#     # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
-#     torch.backends.cuda.matmul.allow_tf32 = True
-#     torch.backends.cudnn.allow_tf32 = True
+if torch.cuda.get_device_properties(0).major >= 8:
+    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
 
 def show_mask(mask, ax, obj_id=None, random_color=False):
     if random_color:
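Uncommenting the autocast block with a bare __enter__() keeps bfloat16 active for the whole process, matching the SAM 2 example notebooks. A scoped alternative, sketched here with a hypothetical run_inference helper, confines mixed precision (and gradient-free inference) to the call that needs it:

import torch

# Scoped alternative to the module-level __enter__() call above (run_inference
# is a hypothetical helper): the autocast region is entered and exited around
# each inference call instead of staying open for the process lifetime.
def run_inference(model, frames):
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        with torch.inference_mode():
            return model(frames)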
@@ -218,8 +217,7 @@ def load_model(checkpoint):
     # return [sam2_checkpoint, model_cfg]
 
 
-
-@spaces.GPU(duration=120)
+
 def get_mask_sam_process(
     stored_inference_state,
     input_first_frame_image,
@@ -315,7 +313,7 @@ def get_mask_sam_process(
     # return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
     return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=False)
 
-@spaces.GPU(duration=120)
+@spaces.GPU(duration=180)
 def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
     #### PROPAGATION ####
     sam2_checkpoint, model_cfg = load_model(checkpoint)
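spaces.GPU is the ZeroGPU decorator from Hugging Face's spaces package; duration requests a per-call GPU time budget in seconds. The commit strips the decorator from quick callbacks (get_point, get_mask_sam_process, reset_propagation) and raises the budget only on propagate_to_all, where masks are propagated through every frame. A minimal sketch with a hypothetical function:

import spaces
import torch

# Hypothetical stand-in for propagate_to_all: under ZeroGPU, only functions
# carrying @spaces.GPU get a GPU attached, and `duration` (seconds) is sized
# here to cover propagation over a full video rather than a single frame.
@spaces.GPU(duration=180)
def propagate_sketch(frames: torch.Tensor) -> torch.Tensor:
    return frames.to("cuda").float().mean(dim=0)  # placeholder GPU work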
@@ -415,36 +413,25 @@ def switch_working_frame(working_frame, scanned_frames, video_frames_dir):
         frame_number = int(match.group(1))
         ann_frame_idx = frame_number
         new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
-    return [], [], new_working_frame, new_working_frame
-
+    return gr.State([]), gr.State([]), new_working_frame, new_working_frame
 
-@spaces.GPU(duration=120)
 def reset_propagation(first_frame_path, predictor, stored_inference_state):
 
     predictor.reset_state(stored_inference_state)
     # print(f"RESET State: {stored_inference_state} ")
-    return first_frame_path, [], [], gr.update(value=None, visible=False), stored_inference_state, None, ["frame_0.jpg"], first_frame_path, "frame_0.jpg", gr.update(visible=False)
+    return first_frame_path, gr.State([]), gr.State([]), gr.update(value=None, visible=False), stored_inference_state, None, ["frame_0.jpg"], first_frame_path, "frame_0.jpg", gr.update(visible=False)
 
 
 with gr.Blocks(css=css) as demo:
-    # first_frame_path = gr.State()
-    # tracking_points = gr.State([])
-    # trackings_input_label = gr.State([])
-    # video_frames_dir = gr.State()
-    # scanned_frames = gr.State()
-    # loaded_predictor = gr.State()
-    # stored_inference_state = gr.State()
-    # stored_frame_names = gr.State()
-    # available_frames_to_check = gr.State([])
-    first_frame_path = None
-    tracking_points = []
-    trackings_input_label = []
-    video_frames_dir = None
-    scanned_frames = None
-    loaded_predictor = None
-    stored_inference_state = None
-    stored_frame_names = None
-    available_frames_to_check = []
+    first_frame_path = gr.State()
+    tracking_points = gr.State([])
+    trackings_input_label = gr.State([])
+    video_frames_dir = gr.State()
+    scanned_frames = gr.State()
+    loaded_predictor = gr.State()
+    stored_inference_state = gr.State()
+    stored_frame_names = gr.State()
+    available_frames_to_check = gr.State([])
     with gr.Column():
         gr.Markdown(
             """
 