chongzhou committed
Commit 113b7b2 · 1 Parent(s): 9bc4638

move model to cpu when not using ZeroGPU

Files changed (1)
  1. app.py +113 -126
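
In summary, the commit removes the lazily built, @spaces.GPU-gated get_predictor() and instead builds the predictor once on the CPU at module load; GPU-bound handlers then move it to CUDA under bfloat16 autocast (with TF32 enabled on Ampere-or-newer GPUs), while the lightweight handlers keep the model parked on the CPU. A minimal sketch of that device-placement pattern follows (the import path is assumed from the SAM 2-style codebase EdgeTAM builds on, and the handler names are illustrative, not the Space's actual functions):

import torch

# Import path assumed from the SAM 2-style codebase that EdgeTAM builds on.
from sam2.build_sam import build_sam2_video_predictor

# Build the predictor once at import time on the CPU, as the new module-level code does.
predictor = build_sam2_video_predictor("edgetam.yaml", "checkpoints/edgetam.pt", device="cpu")

def run_on_gpu(inference_state):
    # Illustrative GPU handler (not a function in app.py): per-call setup taken from the diff.
    if torch.cuda.get_device_properties(0).major >= 8:
        # TF32 speeds up matmuls/convolutions on Ampere and newer GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        predictor.to("cuda")  # move the model to the GPU only while it is actually used
        outputs = list(predictor.propagate_in_video(inference_state))
    torch.cuda.empty_cache()  # release cached GPU memory when the handler finishes
    return outputs

def run_on_cpu():
    # CPU-only handlers (reset, clear_points, preprocess_video_in) just keep the model on the CPU.
    predictor.to("cpu")
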
app.py CHANGED
@@ -70,26 +70,9 @@ examples = [
 
 OBJ_ID = 0
 
-
-@spaces.GPU
-def get_predictor(session_state):
-    if "predictor" not in session_state:
-        sam2_checkpoint = "checkpoints/edgetam.pt"
-        model_cfg = "edgetam.yaml"
-        predictor = build_sam2_video_predictor(
-            model_cfg, sam2_checkpoint, device="cuda"
-        )
-        print("predictor loaded")
-
-        # use bfloat16 for the entire demo
-        torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-        if torch.cuda.get_device_properties(0).major >= 8:
-            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
-            torch.backends.cuda.matmul.allow_tf32 = True
-            torch.backends.cudnn.allow_tf32 = True
-
-        session_state["predictor"] = predictor
-    return session_state["predictor"]
+sam2_checkpoint = "checkpoints/edgetam.pt"
+model_cfg = "edgetam.yaml"
+predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
 
 
 def get_video_fps(video_path):
@@ -106,10 +89,8 @@ def get_video_fps(video_path):
     return fps
 
 
-@spaces.GPU
 def reset(session_state):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
+    predictor.to("cpu")
     session_state["input_points"] = []
     session_state["input_labels"] = []
     if session_state["inference_state"] is not None:
@@ -127,10 +108,8 @@ def reset(session_state):
     )
 
 
-@spaces.GPU
 def clear_points(session_state):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
+    predictor.to("cpu")
     session_state["input_points"] = []
     session_state["input_labels"] = []
     if session_state["inference_state"]["tracking_has_started"]:
@@ -143,10 +122,8 @@ def clear_points(session_state):
     )
 
 
-@spaces.GPU
 def preprocess_video_in(video_path, session_state):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
+    predictor.to("cpu")
     if video_path is None:
         return (
             gr.update(open=True),  # video_in_drawer
@@ -210,59 +187,62 @@ def segment_with_points(
     session_state,
     evt: gr.SelectData,
 ):
-    predictor = get_predictor(session_state)
-    predictor.to("cuda")
-    session_state["input_points"].append(evt.index)
-    print(f"TRACKING INPUT POINT: {session_state['input_points']}")
-
-    if point_type == "include":
-        session_state["input_labels"].append(1)
-    elif point_type == "exclude":
-        session_state["input_labels"].append(0)
-    print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")
-
-    # Open the image and get its dimensions
-    transparent_background = Image.fromarray(session_state["first_frame"]).convert(
-        "RGBA"
-    )
-    w, h = transparent_background.size
+    if torch.cuda.get_device_properties(0).major >= 8:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        predictor.to("cuda")
+        session_state["input_points"].append(evt.index)
+        print(f"TRACKING INPUT POINT: {session_state['input_points']}")
+
+        if point_type == "include":
+            session_state["input_labels"].append(1)
+        elif point_type == "exclude":
+            session_state["input_labels"].append(0)
+        print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")
+
+        # Open the image and get its dimensions
+        transparent_background = Image.fromarray(session_state["first_frame"]).convert(
+            "RGBA"
+        )
+        w, h = transparent_background.size
 
-    # Define the circle radius as a fraction of the smaller dimension
-    fraction = 0.01  # You can adjust this value as needed
-    radius = int(fraction * min(w, h))
+        # Define the circle radius as a fraction of the smaller dimension
+        fraction = 0.01  # You can adjust this value as needed
+        radius = int(fraction * min(w, h))
 
-    # Create a transparent layer to draw on
-    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
+        # Create a transparent layer to draw on
+        transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
 
-    for index, track in enumerate(session_state["input_points"]):
-        if session_state["input_labels"][index] == 1:
-            cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
-        else:
-            cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
+        for index, track in enumerate(session_state["input_points"]):
+            if session_state["input_labels"][index] == 1:
+                cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
+            else:
+                cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
 
-    # Convert the transparent layer back to an image
-    transparent_layer = Image.fromarray(transparent_layer, "RGBA")
-    selected_point_map = Image.alpha_composite(
-        transparent_background, transparent_layer
-    )
+        # Convert the transparent layer back to an image
+        transparent_layer = Image.fromarray(transparent_layer, "RGBA")
+        selected_point_map = Image.alpha_composite(
+            transparent_background, transparent_layer
+        )
 
-    # Let's add a positive click at (x, y) = (210, 350) to get started
-    points = np.array(session_state["input_points"], dtype=np.float32)
-    # for labels, `1` means positive click and `0` means negative click
-    labels = np.array(session_state["input_labels"], np.int32)
-    _, _, out_mask_logits = predictor.add_new_points(
-        inference_state=session_state["inference_state"],
-        frame_idx=0,
-        obj_id=OBJ_ID,
-        points=points,
-        labels=labels,
-    )
+        # Let's add a positive click at (x, y) = (210, 350) to get started
+        points = np.array(session_state["input_points"], dtype=np.float32)
+        # for labels, `1` means positive click and `0` means negative click
+        labels = np.array(session_state["input_labels"], np.int32)
+        _, _, out_mask_logits = predictor.add_new_points(
+            inference_state=session_state["inference_state"],
+            frame_idx=0,
+            obj_id=OBJ_ID,
+            points=points,
+            labels=labels,
+        )
 
-    mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
-    first_frame_output = Image.alpha_composite(transparent_background, mask_image)
+        mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
+        first_frame_output = Image.alpha_composite(transparent_background, mask_image)
 
-    torch.cuda.empty_cache()
-    return selected_point_map, first_frame_output, session_state
+        torch.cuda.empty_cache()
+        return selected_point_map, first_frame_output, session_state
 
 
 def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
@@ -285,61 +265,68 @@ def propagate_to_all(
     video_in,
     session_state,
 ):
-    predictor = get_predictor(session_state)
     predictor.to("cuda")
-    if (
-        len(session_state["input_points"]) == 0
-        or video_in is None
-        or session_state["inference_state"] is None
-    ):
-        return (
-            None,
-            session_state,
+    if torch.cuda.get_device_properties(0).major >= 8:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        if (
+            len(session_state["input_points"]) == 0
+            or video_in is None
+            or session_state["inference_state"] is None
+        ):
+            return (
+                None,
+                session_state,
+            )
+
+        # run propagation throughout the video and collect the results in a dict
+        video_segments = (
+            {}
+        )  # video_segments contains the per-frame segmentation results
+        print("starting propagate_in_video")
+        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
+            session_state["inference_state"]
+        ):
+            video_segments[out_frame_idx] = {
+                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                for i, out_obj_id in enumerate(out_obj_ids)
+            }
+
+        # obtain the segmentation results every few frames
+        vis_frame_stride = 1
+
+        output_frames = []
+        for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
+            transparent_background = Image.fromarray(
+                session_state["all_frames"][out_frame_idx]
+            ).convert("RGBA")
+            out_mask = video_segments[out_frame_idx][OBJ_ID]
+            mask_image = show_mask(out_mask)
+            output_frame = Image.alpha_composite(transparent_background, mask_image)
+            output_frame = np.array(output_frame)
+            output_frames.append(output_frame)
+
+        torch.cuda.empty_cache()
+
+        # Create a video clip from the image sequence
+        original_fps = get_video_fps(video_in)
+        fps = original_fps  # Frames per second
+        clip = ImageSequenceClip(output_frames, fps=fps)
+        # Write the result to a file
+        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
+        final_vid_output_path = f"output_video_{unique_id}.mp4"
+        final_vid_output_path = os.path.join(
+            tempfile.gettempdir(), final_vid_output_path
         )
 
-    # run propagation throughout the video and collect the results in a dict
-    video_segments = {}  # video_segments contains the per-frame segmentation results
-    print("starting propagate_in_video")
-    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
-        session_state["inference_state"]
-    ):
-        video_segments[out_frame_idx] = {
-            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
-            for i, out_obj_id in enumerate(out_obj_ids)
-        }
-
-    # obtain the segmentation results every few frames
-    vis_frame_stride = 1
-
-    output_frames = []
-    for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
-        transparent_background = Image.fromarray(
-            session_state["all_frames"][out_frame_idx]
-        ).convert("RGBA")
-        out_mask = video_segments[out_frame_idx][OBJ_ID]
-        mask_image = show_mask(out_mask)
-        output_frame = Image.alpha_composite(transparent_background, mask_image)
-        output_frame = np.array(output_frame)
-        output_frames.append(output_frame)
-
-    torch.cuda.empty_cache()
-
-    # Create a video clip from the image sequence
-    original_fps = get_video_fps(video_in)
-    fps = original_fps  # Frames per second
-    clip = ImageSequenceClip(output_frames, fps=fps)
-    # Write the result to a file
-    unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
-    final_vid_output_path = f"output_video_{unique_id}.mp4"
-    final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
-
-    # Write the result to a file
-    clip.write_videofile(final_vid_output_path, codec="libx264")
+        # Write the result to a file
+        clip.write_videofile(final_vid_output_path, codec="libx264")
 
-    return (
-        gr.update(value=final_vid_output_path),
-        session_state,
-    )
+        return (
+            gr.update(value=final_vid_output_path),
+            session_state,
+        )
 
 
 def update_ui():