yunyangx commited on
Commit
0a9a024
·
verified ·
1 Parent(s): 1effced

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -3
app.py CHANGED
@@ -77,6 +77,9 @@ def clear_points(image):
77
  image, # points_map
78
  ]
79
 
 
 
 
80
  def preprocess_video_in(video_path):
81
  if video_path is None:
82
  return None, gr.State([]), gr.State([]), None, None, None, None, None, None, gr.update(open=True)
@@ -149,6 +152,9 @@ def preprocess_video_in(video_path):
149
  ]
150
 
151
 
 
 
 
152
  def get_point(point_type, tracking_points, trackings_input_label, input_first_frame_image, evt: gr.SelectData):
153
  if input_first_frame_image is None:
154
  return gr.State([]), gr.State([]), None
@@ -193,7 +199,8 @@ if torch.cuda.get_device_properties(0).major >= 8:
193
  # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
194
  torch.backends.cuda.matmul.allow_tf32 = True
195
  torch.backends.cudnn.allow_tf32 = True
196
-
 
197
  def show_mask(mask, ax, obj_id=None, random_color=False):
198
  if random_color:
199
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
@@ -206,19 +213,20 @@ def show_mask(mask, ax, obj_id=None, random_color=False):
206
  ax.axis('off')
207
  ax.imshow(mask_image)
208
 
209
-
210
  def show_points(coords, labels, ax, marker_size=200):
211
  pos_points = coords[labels==1]
212
  neg_points = coords[labels==0]
213
  ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
214
  ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
215
 
 
216
  def show_box(box, ax):
217
  x0, y0 = box[0], box[1]
218
  w, h = box[2] - box[0], box[3] - box[1]
219
  ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
220
 
221
-
222
  def load_model(checkpoint):
223
  # Load model accordingly to user's choice
224
  if checkpoint == "efficienttam_s":
@@ -431,12 +439,14 @@ def propagate_to_all(tracking_points, video_in, checkpoint, stored_inference_sta
431
 
432
  return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
433
 
 
434
  def update_ui(vis_frame_type):
435
  if vis_frame_type == "coarse":
436
  return gr.update(visible=True), gr.update(visible=False)
437
  elif vis_frame_type == "fine":
438
  return gr.update(visible=False), gr.update(visible=True)
439
 
 
440
  def switch_working_frame(working_frame, scanned_frames, video_frames_dir):
441
  new_working_frame = None
442
  if working_frame == None:
@@ -452,6 +462,7 @@ def switch_working_frame(working_frame, scanned_frames, video_frames_dir):
452
  new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
453
  return gr.State([]), gr.State([]), new_working_frame, new_working_frame
454
 
 
455
  def reset_propagation(first_frame_path, predictor, stored_inference_state):
456
  predictor.reset_state(stored_inference_state)
457
  # print(f"RESET State: {stored_inference_state} ")
 
77
  image, # points_map
78
  ]
79
 
80
+ @spaces.GPU
81
+ @torch.inference_mode()
82
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
83
  def preprocess_video_in(video_path):
84
  if video_path is None:
85
  return None, gr.State([]), gr.State([]), None, None, None, None, None, None, gr.update(open=True)
 
152
  ]
153
 
154
 
155
+ @spaces.GPU
156
+ @torch.inference_mode()
157
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
158
  def get_point(point_type, tracking_points, trackings_input_label, input_first_frame_image, evt: gr.SelectData):
159
  if input_first_frame_image is None:
160
  return gr.State([]), gr.State([]), None
 
199
  # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
200
  torch.backends.cuda.matmul.allow_tf32 = True
201
  torch.backends.cudnn.allow_tf32 = True
202
+
203
+ @spaces.GPU
204
  def show_mask(mask, ax, obj_id=None, random_color=False):
205
  if random_color:
206
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
 
213
  ax.axis('off')
214
  ax.imshow(mask_image)
215
 
216
+ @spaces.GPU
217
  def show_points(coords, labels, ax, marker_size=200):
218
  pos_points = coords[labels==1]
219
  neg_points = coords[labels==0]
220
  ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
221
  ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
222
 
223
+ @spaces.GPU
224
  def show_box(box, ax):
225
  x0, y0 = box[0], box[1]
226
  w, h = box[2] - box[0], box[3] - box[1]
227
  ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
228
 
229
+ @spaces.GPU
230
  def load_model(checkpoint):
231
  # Load model accordingly to user's choice
232
  if checkpoint == "efficienttam_s":
 
439
 
440
  return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
441
 
442
+ @spaces.GPU
443
  def update_ui(vis_frame_type):
444
  if vis_frame_type == "coarse":
445
  return gr.update(visible=True), gr.update(visible=False)
446
  elif vis_frame_type == "fine":
447
  return gr.update(visible=False), gr.update(visible=True)
448
 
449
+ @spaces.GPU
450
  def switch_working_frame(working_frame, scanned_frames, video_frames_dir):
451
  new_working_frame = None
452
  if working_frame == None:
 
462
  new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
463
  return gr.State([]), gr.State([]), new_working_frame, new_working_frame
464
 
465
+ @spaces.GPU
466
  def reset_propagation(first_frame_path, predictor, stored_inference_state):
467
  predictor.reset_state(stored_inference_state)
468
  # print(f"RESET State: {stored_inference_state} ")