chongzhou committed on
Commit 3da2a0c
1 Parent(s): 0a7fba1

make gr.State individual
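
This commit replaces the single dict-valued gr.State (holding first_frame, all_frames, input_points and input_labels) with one gr.State component per value, passed explicitly through each event handler, and it keys the per-session inference state on gr.Request.session_id instead of id(session_state). Below is a minimal sketch of that pattern, not this app's code; the add_point handler, the button, and the click coordinates are hypothetical:

import gradio as gr

inference_states = {}  # illustrative global cache, keyed by Gradio session id


def add_point(points, labels, request: gr.Request):
    # Gradio injects the Request; session_id identifies the browser session
    session_id = request.session_id
    points = points + [(210, 350)]  # hypothetical click coordinates
    labels = labels + [1]           # 1 = positive click, 0 = negative click
    inference_states[session_id] = {"num_points": len(points)}
    return points, labels  # written back into the individual gr.State components


with gr.Blocks() as demo:
    input_points = gr.State([])  # one State per value ...
    input_labels = gr.State([])  # ... instead of a single dict-valued State
    add_btn = gr.Button("Add point")
    add_btn.click(
        fn=add_point,
        inputs=[input_points, input_labels],
        outputs=[input_points, input_labels],
    )

demo.launch()

Returning the values from the handler keeps each session's points isolated, and request.session_id gives a stable per-session key for the module-level inference cache.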

Files changed (1)
  1. app.py +111 -67
app.py CHANGED
@@ -90,16 +90,22 @@ def get_video_fps(video_path):
     return fps
 
 
-def reset():
+def reset(
+    session_first_frame,
+    session_all_frames,
+    session_input_points,
+    session_input_labels,
+    request: gr.Request,
+):
+    session_id = request.session_id
     predictor.to("cpu")
-    session_state["input_points"] = []
-    session_state["input_labels"] = []
+    session_input_points = []
+    session_input_labels = []
 
-    session_id = id(session_state)
     if global_inference_states[session_id] is not None:
         predictor.reset_state(global_inference_states[session_id])
-    session_state["first_frame"] = None
-    session_state["all_frames"] = None
+    session_first_frame = None
+    session_all_frames = None
     global_inference_states[session_id] = None
     return (
         None,
@@ -107,26 +113,38 @@ def reset():
         None,
         None,
         gr.update(value=None, visible=False),
-        session_state,
+        session_first_frame,
+        session_all_frames,
+        session_input_points,
+        session_input_labels,
     )
 
 
-def clear_points(session_state):
+def clear_points(session_input_points, session_input_labels, request: gr.Request,):
+    session_id = request.session_id
     predictor.to("cpu")
-    session_state["input_points"] = []
-    session_state["input_labels"] = []
-    session_id = id(session_state)
+    session_input_points = []
+    session_input_labels = []
     if global_inference_states[session_id]["tracking_has_started"]:
         predictor.reset_state(global_inference_states[session_id])
     return (
-        session_state["first_frame"],
+        session_first_frame,
         None,
         gr.update(value=None, visible=False),
-        session_state,
+        session_input_points,
+        session_input_labels,
     )
 
 
-def preprocess_video_in(video_path, session_state):
+def preprocess_video_in(
+    video_path,
+    session_first_frame,
+    session_all_frames,
+    session_input_points,
+    session_input_labels,
+    request: gr.Request,
+):
+    session_id = request.session_id
     predictor.to("cpu")
     if video_path is None:
         return (
@@ -134,7 +152,10 @@ def preprocess_video_in(video_path, session_state):
             None,  # points_map
             None,  # output_image
             gr.update(value=None, visible=False),  # output_video
-            session_state,
+            session_first_frame,
+            session_all_frames,
+            session_input_points,
+            session_input_labels,
         )
 
     # Read the first frame
@@ -146,7 +167,10 @@ def preprocess_video_in(video_path, session_state):
             None,  # points_map
             None,  # output_image
             gr.update(value=None, visible=False),  # output_video
-            session_state,
+            session_first_frame,
+            session_all_frames,
+            session_input_points,
+            session_input_labels,
         )
 
     frame_number = 0
@@ -169,46 +193,51 @@ def preprocess_video_in(video_path, session_state):
         frame_number += 1
 
     cap.release()
-    session_state["first_frame"] = copy.deepcopy(first_frame)
-    session_state["all_frames"] = all_frames
+    session_first_frame = copy.deepcopy(first_frame)
+    session_all_frames = all_frames
 
-    session_id = id(session_state)
     global_inference_states[session_id] = predictor.init_state(video_path=video_path)
 
-    session_state["input_points"] = []
-    session_state["input_labels"] = []
+    session_input_points = []
+    session_input_labels = []
 
     return [
         gr.update(open=False),  # video_in_drawer
         first_frame,  # points_map
         None,  # output_image
         gr.update(value=None, visible=False),  # output_video
-        session_state,
+        session_first_frame,
+        session_all_frames,
+        session_input_points,
+        session_input_labels,
     ]
 
 
 @spaces.GPU
 def segment_with_points(
     point_type,
-    session_state,
+    session_input_points,
+    session_input_labels,
     evt: gr.SelectData,
+    request: gr.Request,
 ):
+    session_id = request.session_id
     if torch.cuda.get_device_properties(0).major >= 8:
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
         predictor.to("cuda")
-        session_state["input_points"].append(evt.index)
-        print(f"TRACKING INPUT POINT: {session_state['input_points']}")
+        session_input_points.append(evt.index)
+        print(f"TRACKING INPUT POINT: {session_input_points}")
 
         if point_type == "include":
-            session_state["input_labels"].append(1)
+            session_input_labels.append(1)
        elif point_type == "exclude":
-            session_state["input_labels"].append(0)
-        print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")
+            session_input_labels.append(0)
+        print(f"TRACKING INPUT LABEL: {session_input_labels}")
 
         # Open the image and get its dimensions
-        transparent_background = Image.fromarray(session_state["first_frame"]).convert(
+        transparent_background = Image.fromarray(session_first_frame).convert(
             "RGBA"
         )
         w, h = transparent_background.size
@@ -220,8 +249,8 @@ def segment_with_points(
         # Create a transparent layer to draw on
         transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
 
-        for index, track in enumerate(session_state["input_points"]):
-            if session_state["input_labels"][index] == 1:
+        for index, track in enumerate(session_input_points):
+            if session_input_labels[index] == 1:
                 cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
             else:
                 cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
@@ -233,10 +262,9 @@ def segment_with_points(
         )
 
         # Let's add a positive click at (x, y) = (210, 350) to get started
-        points = np.array(session_state["input_points"], dtype=np.float32)
+        points = np.array(session_input_points, dtype=np.float32)
         # for labels, `1` means positive click and `0` means negative click
-        labels = np.array(session_state["input_labels"], np.int32)
-        session_id = id(session_state)
+        labels = np.array(session_input_labels, dtype=np.int32)
         _, _, out_mask_logits = predictor.add_new_points(
             inference_state=global_inference_states[session_id],
             frame_idx=0,
@@ -249,7 +277,7 @@ def segment_with_points(
         first_frame_output = Image.alpha_composite(transparent_background, mask_image)
 
         torch.cuda.empty_cache()
-        return selected_point_map, first_frame_output, session_state
+        return selected_point_map, first_frame_output, session_input_points, session_input_labels
 
 
 def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
@@ -270,23 +298,21 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
 @spaces.GPU
 def propagate_to_all(
     video_in,
-    session_state,
+    session_all_frames,
+    request: gr.Request,
 ):
+    session_id = request.session_id
     predictor.to("cuda")
     if torch.cuda.get_device_properties(0).major >= 8:
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-        session_id = id(session_state)
         if (
-            len(session_state["input_points"]) == 0
+            len(session_input_points) == 0
             or video_in is None
             or global_inference_states[session_id] is None
         ):
-            return (
-                None,
-                session_state,
-            )
+            return None
 
         # run propagation throughout the video and collect the results in a dict
         video_segments = (
@@ -307,7 +333,7 @@ def propagate_to_all(
         output_frames = []
         for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
             transparent_background = Image.fromarray(
-                session_state["all_frames"][out_frame_idx]
+                session_all_frames[out_frame_idx]
             ).convert("RGBA")
             out_mask = video_segments[out_frame_idx][OBJ_ID]
             mask_image = show_mask(out_mask)
@@ -331,10 +357,7 @@
         # Write the result to a file
         clip.write_videofile(final_vid_output_path, codec="libx264")
 
-        return (
-            gr.update(value=final_vid_output_path),
-            session_state,
-        )
+        return gr.update(value=final_vid_output_path)
 
 
 def update_ui():
@@ -342,14 +365,10 @@ def update_ui():
 
 
 with gr.Blocks() as demo:
-    session_state = gr.State(
-        {
-            "first_frame": None,
-            "all_frames": None,
-            "input_points": [],
-            "input_labels": [],
-        }
-    )
+    first_frame = gr.State(None)
+    all_frames = gr.State(None)
+    input_points = gr.State([])
+    input_labels = gr.State([])
 
     with gr.Column():
         # Title
@@ -399,14 +418,20 @@ with gr.Blocks() as demo:
         fn=preprocess_video_in,
         inputs=[
             video_in,
-            session_state,
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
         ],
         outputs=[
             video_in_drawer,  # Accordion to hide uploaded video player
             points_map,  # Image component where we add new tracking points
             output_image,
             output_video,
-            session_state,
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
         ],
         queue=False,
     )
@@ -415,14 +440,20 @@ with gr.Blocks() as demo:
         fn=preprocess_video_in,
         inputs=[
             video_in,
-            session_state,
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
         ],
         outputs=[
             video_in_drawer,  # Accordion to hide uploaded video player
             points_map,  # Image component where we add new tracking points
             output_image,
             output_video,
-            session_state,
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
        ],
        queue=False,
    )
@@ -432,12 +463,14 @@ with gr.Blocks() as demo:
         fn=segment_with_points,
         inputs=[
             point_type,  # "include" or "exclude"
-            session_state,
+            input_points,
+            input_labels,
         ],
         outputs=[
             points_map,  # updated image with points
             output_image,
-            session_state,
+            input_points,
+            input_labels,
         ],
         queue=False,
     )
@@ -445,26 +478,38 @@ with gr.Blocks() as demo:
     # Clear every points clicked and added to the map
     clear_points_btn.click(
         fn=clear_points,
-        inputs=session_state,
+        inputs=[
+            input_points,
+            input_labels,
+        ],
         outputs=[
             points_map,
             output_image,
             output_video,
-            session_state,
+            input_points,
+            input_labels,
         ],
         queue=False,
     )
 
     reset_btn.click(
         fn=reset,
-        inputs=session_state,
+        inputs=[
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
+        ],
         outputs=[
             video_in,
             video_in_drawer,
             points_map,
             output_image,
             output_video,
-            session_state,
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
         ],
         queue=False,
     )
@@ -478,11 +523,10 @@ with gr.Blocks() as demo:
         fn=propagate_to_all,
         inputs=[
             video_in,
-            session_state,
+            all_frames,
         ],
         outputs=[
             output_video,
-            session_state,
         ],
         concurrency_limit=10,
         queue=False,