speed up video uploading
app.py CHANGED
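The speed-up in this commit comes from decoding and resizing the uploaded video with OpenCV instead of mediapy inside preprocess_video_input (see the new read_video_cv2 helper and the per-frame cv2.resize loops in the diff below). A minimal, self-contained sketch of that approach; the helper name and the 512x384 target size here are illustrative, not part of the commit:

import cv2
import numpy as np


def load_and_resize(video_path, target_wh=(512, 384)):
    # Decode the video with OpenCV and resize every frame to target_wh = (width, height).
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)  # FPS comes from the container metadata
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes frames as BGR
        frames.append(cv2.resize(frame, target_wh, interpolation=cv2.INTER_LINEAR))
    cap.release()
    return np.stack(frames), fps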
@@ -1,4 +1,4 @@
+# This Gradio demo code is from https://github.com/cvlab-kaist/locotrack/blob/main/demo/demo.py
# We updated it to work with CoTracker3 models. We thank authors of LocoTrack
# for such an amazing Gradio demo.

@@ -22,18 +22,35 @@ from visualizer import Visualizer

# Generate random colormaps for visualizing different points.
def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
+    """Gets colormap for points."""
+    colors = []
+    for i in np.arange(0.0, 360.0, 360.0 / num_colors):
+        hue = i / 360.0
+        lightness = (50 + np.random.rand() * 10) / 100.0
+        saturation = (90 + np.random.rand() * 10) / 100.0
+        color = colorsys.hls_to_rgb(hue, lightness, saturation)
+        colors.append((int(color[0] * 255), int(color[1] * 255), int(color[2] * 255)))
+    random.shuffle(colors)
+    return colors
+
+def read_video_cv2(video_path):
+    cap = cv2.VideoCapture(video_path)
+    frames = []
+
+    # Get FPS from video metadata
+    fps = cap.get(cv2.CAP_PROP_FPS)
+
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+        # Convert BGR to RGB
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        frames.append(frame)
+
+    cap.release()
+    video_arr = np.array(frames)
+    return video_arr, fps

def get_points_on_a_grid(
    size: int,
@@ -84,91 +101,99 @@ def get_points_on_a_grid(
    )
    return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)

+
def paint_point_track(
    frames: np.ndarray,
    point_tracks: np.ndarray,
    visibles: np.ndarray,
    colormap: Optional[List[Tuple[int, int, int]]] = None,
) -> np.ndarray:
+    """Converts a sequence of points to color code video.
+    Args:
+      frames: [num_frames, height, width, 3], np.uint8, [0, 255]
+      point_tracks: [num_points, num_frames, 2], np.float32, [0, width / height]
+      visibles: [num_points, num_frames], bool
+      colormap: colormap for points, each point has a different RGB color.
+    Returns:
+      video: [num_frames, height, width, 3], np.uint8, [0, 255]
+    """
+    num_points, num_frames = point_tracks.shape[0:2]
+    if colormap is None:
+        colormap = get_colors(num_colors=num_points)
+    height, width = frames.shape[1:3]
+    dot_size_as_fraction_of_min_edge = 0.015
+    radius = int(round(min(height, width) * dot_size_as_fraction_of_min_edge))
+    diam = radius * 2 + 1
+    quadratic_y = np.square(np.arange(diam)[:, np.newaxis] - radius - 1)
+    quadratic_x = np.square(np.arange(diam)[np.newaxis, :] - radius - 1)
+    icon = (quadratic_y + quadratic_x) - (radius**2) / 2.0
+    sharpness = 0.15
+    icon = np.clip(icon / (radius * 2 * sharpness), 0, 1)
+    icon = 1 - icon[:, :, np.newaxis]
+    icon1 = np.pad(icon, [(0, 1), (0, 1), (0, 0)])
+    icon2 = np.pad(icon, [(1, 0), (0, 1), (0, 0)])
+    icon3 = np.pad(icon, [(0, 1), (1, 0), (0, 0)])
+    icon4 = np.pad(icon, [(1, 0), (1, 0), (0, 0)])
+
+    video = frames.copy()
+    for t in range(num_frames):
+        # Pad so that points that extend outside the image frame don't crash us
+        image = np.pad(
+            video[t],
+            [
+                (radius + 1, radius + 1),
+                (radius + 1, radius + 1),
+                (0, 0),
+            ],
        )
+        for i in range(num_points):
+            # The icon is centered at the center of a pixel, but the input coordinates
+            # are raster coordinates. Therefore, to render a point at (1,1) (which
+            # lies on the corner between four pixels), we need 1/4 of the icon placed
+            # centered on the 0'th row, 0'th column, etc. We need to subtract
+            # 0.5 to make the fractional position come out right.
+            x, y = point_tracks[i, t, :] + 0.5
+            x = min(max(x, 0.0), width)
+            y = min(max(y, 0.0), height)
+
+            if visibles[i, t]:
+                x1, y1 = np.floor(x).astype(np.int32), np.floor(y).astype(np.int32)
+                x2, y2 = x1 + 1, y1 + 1
+
+                # bilinear interpolation
+                patch = (
+                    icon1 * (x2 - x) * (y2 - y)
+                    + icon2 * (x2 - x) * (y - y1)
+                    + icon3 * (x - x1) * (y2 - y)
+                    + icon4 * (x - x1) * (y - y1)
+                )
+                x_ub = x1 + 2 * radius + 2
+                y_ub = y1 + 2 * radius + 2
+                image[y1:y_ub, x1:x_ub, :] = (1 - patch) * image[
+                    y1:y_ub, x1:x_ub, :
+                ] + patch * np.array(colormap[i])[np.newaxis, np.newaxis, :]
+
+        # Remove the pad
+        video[t] = image[radius + 1 : -radius - 1, radius + 1 : -radius - 1].astype(
+            np.uint8
+        )
+    return video


-PREVIEW_WIDTH = 768
-VIDEO_INPUT_RESO = (384, 512)
-POINT_SIZE = 4
-FRAME_LIMIT =
+PREVIEW_WIDTH = 768  # Width of the preview video
+VIDEO_INPUT_RESO = (384, 512)  # Resolution of the input video
+POINT_SIZE = 4  # Size of the query point in the preview video
+FRAME_LIMIT = 600  # Limit the number of frames to process


+def get_point(
+    frame_num,
+    video_queried_preview,
+    query_points,
+    query_points_color,
+    query_count,
+    evt: gr.SelectData,
+):
    print(f"You selected {(evt.index[0], evt.index[1], frame_num)}")

    current_frame = video_queried_preview[int(frame_num)]
@@ -192,22 +217,29 @@ def get_point(frame_num, video_queried_preview, query_points, query_points_color
    # Update the query count
    query_count += 1
    return (
+        current_frame_draw,  # Updated frame for preview
+        video_queried_preview,  # Updated preview video
+        query_points,  # Updated query points
+        query_points_color,  # Updated query points color
+        query_count,  # Updated query count
    )


+def undo_point(
+    frame_num,
+    video_preview,
+    video_queried_preview,
+    query_points,
+    query_points_color,
+    query_count,
+):
    if len(query_points[int(frame_num)]) == 0:
        return (
            video_queried_preview[int(frame_num)],
            video_queried_preview,
            query_points,
            query_points_color,
+            query_count,
        )

    # Get the last point
@@ -216,9 +248,13 @@ def undo_point(frame_num, video_preview, video_queried_preview, query_points, qu

    # Redraw the frame
    current_frame_draw = video_preview[int(frame_num)].copy()
+    for point, color in zip(
+        query_points[int(frame_num)], query_points_color[int(frame_num)]
+    ):
        x, y, _ = point
+        current_frame_draw = cv2.circle(
+            current_frame_draw, (x, y), POINT_SIZE, color, -1
+        )

    # Update the query count
    query_count -= 1
@@ -226,15 +262,22 @@ def undo_point(frame_num, video_preview, video_queried_preview, query_points, qu
    # Update the frame
    video_queried_preview[int(frame_num)] = current_frame_draw
    return (
+        current_frame_draw,  # Updated frame for preview
+        video_queried_preview,  # Updated preview video
+        query_points,  # Updated query points
+        query_points_color,  # Updated query points color
+        query_count,  # Updated query count
    )


+def clear_frame_fn(
+    frame_num,
+    video_preview,
+    video_queried_preview,
+    query_points,
+    query_points_color,
+    query_count,
+):
    query_count -= len(query_points[int(frame_num)])

    query_points[int(frame_num)] = []
@@ -243,22 +286,21 @@ def clear_frame_fn(frame_num, video_preview, video_queried_preview, query_points
    video_queried_preview[int(frame_num)] = video_preview[int(frame_num)].copy()

    return (
+        video_preview[int(frame_num)],  # Set the preview frame to the original frame
+        video_queried_preview,
+        query_points,  # Cleared query points
+        query_points_color,  # Cleared query points color
+        query_count,  # New query count
    )


def clear_all_fn(frame_num, video_preview):
    return (
        video_preview[int(frame_num)],
        video_preview.copy(),
        [[] for _ in range(len(video_preview))],
        [[] for _ in range(len(video_preview))],
+        0,
    )


@@ -267,77 +309,98 @@ def choose_frame(frame_num, video_preview_array):


def preprocess_video_input(video_path):
+    import time
+    start_time = time.time()
+
+    # Read video and get FPS
+    video_arr, video_fps = read_video_cv2(video_path)
+    end_time = time.time()
+    print(f"Time taken to read video: {end_time - start_time} seconds")
+
+    # Apply frame limit
    num_frames = video_arr.shape[0]
    if num_frames > FRAME_LIMIT:
+        gr.Warning(
+            f"The video is too long. Only the first {FRAME_LIMIT} frames will be used.",
+            duration=5,
+        )
        video_arr = video_arr[:FRAME_LIMIT]
        num_frames = FRAME_LIMIT

-    height, width = video_arr.shape[1:3]
-    new_height, new_width = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH
-
-    preview_video = mediapy.resize_video(video_arr, (new_height, new_width))
-    input_video = mediapy.resize_video(video_arr, VIDEO_INPUT_RESO)
-
-    preview_video = np.array(preview_video)
-    input_video = np.array(input_video)
+    start_time = time.time()

+    # Resize preview video while maintaining aspect ratio
+    h, w = video_arr.shape[1:3]
+    aspect_ratio = w / h
+    if w > PREVIEW_WIDTH:
+        new_w = PREVIEW_WIDTH
+        new_h = int(new_w / aspect_ratio)
+        preview_video = np.zeros((len(video_arr), new_h, new_w, 3), dtype=np.uint8)
+        for i in range(len(video_arr)):
+            preview_video[i] = cv2.resize(video_arr[i], (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+    else:
+        preview_video = video_arr.copy()
+
+    # Resize input video for the model
+    input_video = np.zeros((len(video_arr), VIDEO_INPUT_RESO[0], VIDEO_INPUT_RESO[1], 3), dtype=np.uint8)
+    for i in range(len(video_arr)):
+        input_video[i] = cv2.resize(video_arr[i], (VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]), interpolation=cv2.INTER_LINEAR)
+
+    end_time = time.time()
+    print(f"Time taken to resize videos: {end_time - start_time} seconds")
+
    interactive = True

    return (
+        video_arr,  # Original video
+        preview_video,  # Preview video at resized resolution
+        preview_video.copy(),  # Copy for visualization
+        input_video,  # Resized input video for model
+        video_fps,
+        gr.update(open=False),
+        preview_video[0],
+        gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=interactive),
+        [[] for _ in range(num_frames)],
+        [[] for _ in range(num_frames)],
+        [[] for _ in range(num_frames)],
+        0,
+        gr.update(interactive=interactive),
        gr.update(interactive=interactive),
        gr.update(interactive=interactive),
        gr.update(interactive=True),
    )

+
@spaces.GPU
def track(
    video_preview,
+    video_input,
+    video_fps,
+    query_points,
+    query_points_color,
+    query_count,
):
+    tracking_mode = "selected"
+    if query_count == 0:
+        tracking_mode = "grid"
+
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float if device == "cuda" else torch.float

    # Convert query points to tensor, normalize to input resolution
+    if tracking_mode != "grid":
        query_points_tensor = []
        for frame_points in query_points:
            query_points_tensor.extend(frame_points)
+
        query_points_tensor = torch.tensor(query_points_tensor).float()
+        query_points_tensor *= torch.tensor(
+            [VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0], 1]
+        ) / torch.tensor([[video_preview.shape[2], video_preview.shape[1], 1]])
+        query_points_tensor = (
+            query_points_tensor[None].flip(-1).to(device, dtype)
+        )  # xyt -> tyx
+        query_points_tensor = query_points_tensor[:, :, [0, 2, 1]]  # tyx -> txy

    video_input = torch.tensor(video_input).unsqueeze(0)

@@ -345,10 +408,10 @@ def track(
    model = model.to(device)

    video_input = video_input.permute(0, 1, 4, 2, 3)
+    if tracking_mode == "grid":
        xy = get_points_on_a_grid(40, video_input.shape[3:], device=device)
        queries = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)  #
+        add_support_grid = False
        cmap = matplotlib.colormaps.get_cmap("gist_rainbow")
        query_points_color = [[]]
        query_count = queries.shape[1]
@@ -364,18 +427,33 @@ def track(
        # queries__ = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
        num_tracks = queries.shape[1]
        # queries = torch.cat([queries,queries__],dim=1)
+        add_support_grid = True
+
+    model(
+        video_chunk=video_input[:, :1].to(device, dtype),
+        is_first_step=True,
+        grid_size=0,
+        queries=queries,
+        add_support_grid=add_support_grid,
+    )
+    #
    for ind in range(0, video_input.shape[1] - model.step, model.step):
        pred_tracks, pred_visibility = model(
            video_chunk=video_input[:, ind : ind + model.step * 2].to(device, dtype),
+            grid_size=0,
+            queries=queries,
+            add_support_grid=add_support_grid,
        )  # B T N 2, B T N 1
+    tracks = (
+        (
+            pred_tracks
+            * torch.tensor([video_preview.shape[2], video_preview.shape[1]]).to(device)
+            / torch.tensor([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]]).to(device)
+        )[0]
+        .permute(1, 0, 2)
+        .cpu()
+        .numpy()
+    )
    pred_occ = torch.ones_like(pred_visibility[0]).permute(1, 0).cpu().numpy()

    # make color array
@@ -386,15 +464,33 @@ def track(

    # pred_tracks = torch.cat([pred_tracks[:,:1],(pred_tracks[:,:-2] + pred_tracks[:,1:-1] + pred_tracks[:,2:])/ 3, pred_tracks[:,-1:]],dim=1)
    # torch.cat([pred_tracks[:,:1],pred_tracks[:,1:]],dim=1)
+    pred_tracks = (
+        pred_tracks
+        * torch.tensor([video_preview.shape[2], video_preview.shape[1]]).to(device)
+        / torch.tensor([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]]).to(device)
+    )
+
+    vis = Visualizer(
+        save_dir="./saved_videos", pad_value=0, linewidth=2, tracks_leave_trace=0
+    )
    # segm_mask = torch.zeros(queries.shape[1])
    # segm_mask[:num_tracks] = 1
    # print('segm_mask',segm_mask.shape, segm_mask)
+    # segm_mask=segm_mask,
+    painted_video = (
+        vis.visualize(
+            torch.tensor(video_preview)
+            .permute(0, 3, 1, 2)[None]
+            .to(pred_tracks.device),
+            pred_tracks,
+            pred_visibility,
+            save_video=False,
+        )[0]
+        .permute(0, 2, 3, 1)
+        .cpu()
+        .numpy()
+    )
+
    # painted_video = paint_point_track(video_preview,tracks,pred_occ,colors)

    # save video
@@ -420,8 +516,11 @@ with gr.Blocks() as demo:
    is_tracked_query = gr.State([])
    query_count = gr.State(0)

+    gr.Markdown(
+        "# 🎨 CoTracker3: Simpler and Better Point Tracking by Pseudo-Labelling Real Videos"
+    )
+    gr.Markdown(
+        "<div style='text-align: left;'> \
        <p>Welcome to <a href='https://cotracker3.github.io/' target='_blank'>CoTracker</a>! This space demonstrates point (pixel) tracking in videos. \
        The model tracks points on a grid or points selected by you. </p> \
        <p> To get started, simply upload your <b>.mp4</b> video or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length <b>2-7 seconds</b>.</p> \
@@ -429,42 +528,59 @@ with gr.Blocks() as demo:
        <p style='text-align: left'>For more details, check out our <a href='https://github.com/facebookresearch/co-tracker' target='_blank'>GitHub Repo</a> ⭐. We thank the authors of LocoTrack for their interactive demo.</p> \
        </div>"
    )

+    gr.Markdown(
+        "## First step: upload your video or select an example video, and click submit."
+    )
    with gr.Row():

        with gr.Accordion("Your video input", open=True) as video_in_drawer:
            video_in = gr.Video(label="Video Input", format="mp4")
            submit = gr.Button("Submit", scale=0)

        import os
+
        apple = os.path.join(os.path.dirname(__file__), "videos", "apple.mp4")
        bear = os.path.join(os.path.dirname(__file__), "videos", "bear.mp4")
        paragliding_launch = os.path.join(
            os.path.dirname(__file__), "videos", "paragliding-launch.mp4"
        )
+        paragliding = os.path.join(
+            os.path.dirname(__file__), "videos", "paragliding.mp4"
+        )
        cat = os.path.join(os.path.dirname(__file__), "videos", "cat.mp4")
        pillow = os.path.join(os.path.dirname(__file__), "videos", "pillow.mp4")
        teddy = os.path.join(os.path.dirname(__file__), "videos", "teddy.mp4")
        backpack = os.path.join(os.path.dirname(__file__), "videos", "backpack.mp4")

+        gr.Examples(
+            examples=[
+                bear,
+                apple,
+                paragliding,
+                paragliding_launch,
+                cat,
+                pillow,
+                teddy,
+                backpack,
+            ],
+            inputs=[video_in],
+        )

+    gr.Markdown(
+        '## Second step: Simply click "Track" to track a grid of points or select query points on the video before clicking'
+    )
    with gr.Row():
        with gr.Column():
            with gr.Row():
                query_frames = gr.Slider(
+                    minimum=0,
+                    maximum=100,
+                    value=0,
+                    step=1,
+                    label="Choose Frame",
+                    interactive=False,
+                )
            with gr.Row():
                undo = gr.Button("Undo", interactive=False)
                clear_frame = gr.Button("Clear Frame", interactive=False)
@@ -472,11 +588,9 @@ with gr.Blocks() as demo:

            with gr.Row():
                current_frame = gr.Image(
+                    label="Click to add query points", type="numpy", interactive=False
                )
+
            with gr.Row():
                track_button = gr.Button("Track", interactive=False)

@@ -488,12 +602,10 @@ with gr.Blocks() as demo:
                loop=True,
            )

    submit.click(
+        fn=preprocess_video_input,
+        inputs=[video_in],
+        outputs=[
            video,
            video_preview,
            video_queried_preview,
@@ -511,97 +623,96 @@ with gr.Blocks() as demo:
            clear_all,
            track_button,
        ],
+        queue=False,
    )

    query_frames.change(
+        fn=choose_frame,
+        inputs=[query_frames, video_queried_preview],
+        outputs=[
            current_frame,
        ],
+        queue=False,
    )

    current_frame.select(
+        fn=get_point,
+        inputs=[
            query_frames,
            video_queried_preview,
            query_points,
            query_points_color,
            query_count,
+        ],
+        outputs=[
            current_frame,
            video_queried_preview,
            query_points,
            query_points_color,
+            query_count,
+        ],
+        queue=False,
    )
+
    undo.click(
+        fn=undo_point,
+        inputs=[
            query_frames,
            video_preview,
            video_queried_preview,
            query_points,
            query_points_color,
+            query_count,
        ],
+        outputs=[
            current_frame,
            video_queried_preview,
            query_points,
            query_points_color,
+            query_count,
        ],
+        queue=False,
    )

    clear_frame.click(
+        fn=clear_frame_fn,
+        inputs=[
            query_frames,
            video_preview,
            video_queried_preview,
            query_points,
            query_points_color,
+            query_count,
        ],
+        outputs=[
            current_frame,
            video_queried_preview,
            query_points,
            query_points_color,
+            query_count,
        ],
+        queue=False,
    )

    clear_all.click(
+        fn=clear_all_fn,
+        inputs=[
            query_frames,
            video_preview,
        ],
+        outputs=[
            current_frame,
            video_queried_preview,
            query_points,
            query_points_color,
+            query_count,
        ],
+        queue=False,
    )

    track_button.click(
+        fn=track,
+        inputs=[
            video_preview,
            video_input,
            video_fps,
@@ -609,11 +720,11 @@ with gr.Blocks() as demo:
            query_points_color,
            query_count,
        ],
+        outputs=[
            output_video,
        ],
+        queue=True,
    )

+
+demo.launch(show_api=False, show_error=True, debug=False, share=False)
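For reference, the chunked online inference pattern that the updated track function uses (one priming call with is_first_step=True, then overlapping windows of model.step * 2 frames) can be exercised outside the Gradio app roughly as below. The torch.hub entry point name is an assumption taken from the co-tracker repository, and the random video and single query point are placeholders, so treat this as a sketch rather than the demo's exact code:

import torch

# Assumed hub entry point for the online CoTracker3 model (facebookresearch/co-tracker).
model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

video = torch.rand(1, 48, 3, 384, 512, device=device) * 255  # B T C H W, placeholder frames
queries = torch.tensor([[[0.0, 256.0, 192.0]]], device=device)  # one (t, x, y) query point

# Prime the online model on the first frames, then feed overlapping chunks of 2 * model.step frames,
# mirroring the loop in track().
model(video_chunk=video[:, :1], is_first_step=True, grid_size=0, queries=queries, add_support_grid=True)
for ind in range(0, video.shape[1] - model.step, model.step):
    pred_tracks, pred_visibility = model(
        video_chunk=video[:, ind : ind + model.step * 2],
        grid_size=0,
        queries=queries,
        add_support_grid=True,
    )  # B T N 2, B T N 1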