make gr.State individual
app.py CHANGED
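The commit replaces the single dict-valued session state with one gr.State component per value (first_frame, all_frames, input_points, input_labels) and keys the module-level inference-state dict by an identifier taken from the Gradio request. For reference only (this is not the Space's code), a minimal, self-contained sketch of the pattern of individual gr.State values round-tripping through an event handler; the component and function names here are illustrative:

# Minimal sketch of per-session state with individual gr.State components.
# Names (add_point, points, status) are illustrative, not from app.py.
import gradio as gr

def add_point(x, points):
    # `points` arrives as this session's current value; the updated value
    # must be returned through outputs= for the change to stick.
    points = points + [x]
    return f"{len(points)} point(s): {points}", points

with gr.Blocks() as demo:
    points = gr.State([])          # one state per value, per browser session
    x_in = gr.Number(label="x")
    status = gr.Textbox(label="status")
    add_btn = gr.Button("Add point")
    add_btn.click(add_point, inputs=[x_in, points], outputs=[status, points])

if __name__ == "__main__":
    demo.launch()

This mirrors how the diff below threads the four session values through the inputs= and outputs= lists of each .click handler.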
@@ -90,16 +90,22 @@ def get_video_fps(video_path):
     return fps


-def reset():
+def reset(
+    session_first_frame,
+    session_all_frames,
+    session_input_points,
+    session_input_labels,
+    request: gr.Request,
+):
+    session_id = request.session_id
     predictor.to("cpu")
-
-
+    session_input_points = []
+    session_input_labels = []

-    session_id = id(session_state)
     if global_inference_states[session_id] is not None:
         predictor.reset_state(global_inference_states[session_id])
-
-
+    session_first_frame = None
+    session_all_frames = None
     global_inference_states[session_id] = None
     return (
         None,
@@ -107,26 +113,38 @@ def reset():
         None,
         None,
         gr.update(value=None, visible=False),
-
+        session_first_frame,
+        session_all_frames,
+        session_input_points,
+        session_input_labels,
     )


-def clear_points(
+def clear_points(session_input_points, session_input_labels, request: gr.Request,):
+    session_id = request.session_id
     predictor.to("cpu")
-
-
-    session_id = id(session_state)
+    session_input_points = []
+    session_input_labels = []
     if global_inference_states[session_id]["tracking_has_started"]:
         predictor.reset_state(global_inference_states[session_id])
     return (
-
+        session_first_frame,
         None,
         gr.update(value=None, visible=False),
-
+        session_input_points,
+        session_input_labels,
     )


-def preprocess_video_in(video_path, session_state):
+def preprocess_video_in(
+    video_path,
+    session_first_frame,
+    session_all_frames,
+    session_input_points,
+    session_input_labels,
+    request: gr.Request,
+):
+    session_id = request.session_id
     predictor.to("cpu")
     if video_path is None:
         return (
@@ -134,7 +152,10 @@ def preprocess_video_in(video_path, session_state):
             None, # points_map
             None, # output_image
             gr.update(value=None, visible=False), # output_video
-
+            session_first_frame,
+            session_all_frames,
+            session_input_points,
+            session_input_labels,
         )

     # Read the first frame
@@ -146,7 +167,10 @@ def preprocess_video_in(video_path, session_state):
             None, # points_map
             None, # output_image
             gr.update(value=None, visible=False), # output_video
-
+            session_first_frame,
+            session_all_frames,
+            session_input_points,
+            session_input_labels,
         )

     frame_number = 0
@@ -169,46 +193,51 @@ def preprocess_video_in(video_path, session_state):
         frame_number += 1

     cap.release()
-
-
+    session_first_frame = copy.deepcopy(first_frame)
+    session_all_frames = all_frames

-    session_id = id(session_state)
     global_inference_states[session_id] = predictor.init_state(video_path=video_path)

-
-
+    session_input_points = []
+    session_input_labels = []

     return [
         gr.update(open=False), # video_in_drawer
         first_frame, # points_map
         None, # output_image
         gr.update(value=None, visible=False), # output_video
-
+        session_first_frame,
+        session_all_frames,
+        session_input_points,
+        session_input_labels,
     ]


 @spaces.GPU
 def segment_with_points(
     point_type,
-
+    session_input_points,
+    session_input_labels,
     evt: gr.SelectData,
+    request: gr.Request,
 ):
+    session_id = request.session_id
     if torch.cuda.get_device_properties(0).major >= 8:
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
         predictor.to("cuda")
-
-        print(f"TRACKING INPUT POINT: {
+        session_input_points.append(evt.index)
+        print(f"TRACKING INPUT POINT: {session_input_points}")

         if point_type == "include":
-
+            session_input_labels.append(1)
         elif point_type == "exclude":
-
-        print(f"TRACKING INPUT LABEL: {
+            session_input_labels.append(0)
+        print(f"TRACKING INPUT LABEL: {session_input_labels}")

         # Open the image and get its dimensions
-        transparent_background
+        transparent_background = Image.fromarray(session_first_frame).convert(
             "RGBA"
         )
         w, h = transparent_background.size
@@ -220,8 +249,8 @@ def segment_with_points(
         # Create a transparent layer to draw on
         transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)

-        for index, track in enumerate(
-            if
+        for index, track in enumerate(session_input_points):
+            if session_input_labels[index] == 1:
                 cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
             else:
                 cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
@@ -233,10 +262,9 @@ def segment_with_points(
         )

         # Let's add a positive click at (x, y) = (210, 350) to get started
-        points = np.array(
+        points = np.array(session_input_points, dtype=np.float32)
         # for labels, `1` means positive click and `0` means negative click
-        labels = np.array(
-        session_id = id(session_state)
+        labels = np.array(session_input_labels, dtype=np.int32)
         _, _, out_mask_logits = predictor.add_new_points(
             inference_state=global_inference_states[session_id],
             frame_idx=0,
@@ -249,7 +277,7 @@ def segment_with_points(
         first_frame_output = Image.alpha_composite(transparent_background, mask_image)

         torch.cuda.empty_cache()
-        return selected_point_map, first_frame_output,
+        return selected_point_map, first_frame_output, session_input_points, session_input_labels


 def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
@@ -270,23 +298,21 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
 @spaces.GPU
 def propagate_to_all(
     video_in,
-
+    session_all_frames,
+    request: gr.Request,
 ):
+    session_id = request.session_id
     predictor.to("cuda")
     if torch.cuda.get_device_properties(0).major >= 8:
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-        session_id = id(session_state)
         if (
-            len(
+            len(session_input_points) == 0
             or video_in is None
             or global_inference_states[session_id] is None
         ):
-            return
-                None,
-                session_state,
-            )
+            return None

         # run propagation throughout the video and collect the results in a dict
         video_segments = (
@@ -307,7 +333,7 @@ def propagate_to_all(
         output_frames = []
         for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
             transparent_background = Image.fromarray(
-
+                session_all_frames[out_frame_idx]
             ).convert("RGBA")
             out_mask = video_segments[out_frame_idx][OBJ_ID]
             mask_image = show_mask(out_mask)
@@ -331,10 +357,7 @@ def propagate_to_all(
         # Write the result to a file
         clip.write_videofile(final_vid_output_path, codec="libx264")

-        return (
-            gr.update(value=final_vid_output_path),
-            session_state,
-        )
+        return gr.update(value=final_vid_output_path)


 def update_ui():
@@ -342,14 +365,10 @@ def update_ui():


 with gr.Blocks() as demo:
-
-
-
-
-            "input_points": [],
-            "input_labels": [],
-        }
-    )
+    first_frame = gr.State(None)
+    all_frames = gr.State(None)
+    input_points = gr.State([])
+    input_labels = gr.State([])

     with gr.Column():
         # Title
@@ -399,14 +418,20 @@ with gr.Blocks() as demo:
         fn=preprocess_video_in,
         inputs=[
             video_in,
-
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
         ],
         outputs=[
             video_in_drawer, # Accordion to hide uploaded video player
             points_map, # Image component where we add new tracking points
             output_image,
             output_video,
-
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
         ],
         queue=False,
     )
@@ -415,14 +440,20 @@ with gr.Blocks() as demo:
         fn=preprocess_video_in,
         inputs=[
             video_in,
-
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
         ],
         outputs=[
             video_in_drawer, # Accordion to hide uploaded video player
             points_map, # Image component where we add new tracking points
             output_image,
             output_video,
-
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
         ],
         queue=False,
     )
@@ -432,12 +463,14 @@ with gr.Blocks() as demo:
         fn=segment_with_points,
         inputs=[
             point_type, # "include" or "exclude"
-
+            input_points,
+            input_labels,
         ],
         outputs=[
             points_map, # updated image with points
             output_image,
-
+            input_points,
+            input_labels,
         ],
         queue=False,
     )
@@ -445,26 +478,38 @@ with gr.Blocks() as demo:
     # Clear every points clicked and added to the map
     clear_points_btn.click(
         fn=clear_points,
-        inputs=
+        inputs=[
+            input_points,
+            input_labels,
+        ],
         outputs=[
             points_map,
             output_image,
             output_video,
-
+            input_points,
+            input_labels,
         ],
         queue=False,
     )

     reset_btn.click(
         fn=reset,
-        inputs=
+        inputs=[
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
+        ],
         outputs=[
             video_in,
             video_in_drawer,
             points_map,
             output_image,
             output_video,
-
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
        ],
         queue=False,
     )
@@ -478,11 +523,10 @@ with gr.Blocks() as demo:
         fn=propagate_to_all,
         inputs=[
             video_in,
-
+            all_frames,
         ],
         outputs=[
             output_video,
-            session_state,
         ],
         concurrency_limit=10,
         queue=False,