nikkar committed
Commit 8a1b8f9 · verified · 1 Parent(s): 8d0855d

speed up video uploading
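The speedup comes from swapping mediapy-based decoding and resizing for per-frame OpenCV calls (see the new read_video_cv2 helper and the cv2.resize loops in the diff below). A minimal usage sketch of the new helper; the clip path is one of the demo's bundled example videos, and the printed values are illustrative:

    # read_video_cv2 is defined in the new app.py below: it decodes every frame,
    # converts OpenCV's BGR output to RGB, and returns the stacked frames plus the FPS.
    video_arr, fps = read_video_cv2("videos/bear.mp4")
    print(video_arr.shape, video_arr.dtype, fps)  # (num_frames, height, width, 3) uint8, container FPS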

Files changed (1)
  1. app.py +342 -231
app.py CHANGED
@@ -1,4 +1,4 @@
- # This Gradio demo code is from https://github.com/cvlab-kaist/locotrack/blob/main/demo/demo.py
  # We updated it to work with CoTracker3 models. We thank authors of LocoTrack
  # for such an amazing Gradio demo.
 
@@ -22,18 +22,35 @@ from visualizer import Visualizer

  # Generate random colormaps for visualizing different points.
  def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
-     """Gets colormap for points."""
-     colors = []
-     for i in np.arange(0.0, 360.0, 360.0 / num_colors):
-         hue = i / 360.0
-         lightness = (50 + np.random.rand() * 10) / 100.0
-         saturation = (90 + np.random.rand() * 10) / 100.0
-         color = colorsys.hls_to_rgb(hue, lightness, saturation)
-         colors.append(
-             (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
-         )
-     random.shuffle(colors)
-     return colors

  def get_points_on_a_grid(
      size: int,
@@ -84,91 +101,99 @@ def get_points_on_a_grid(
      )
      return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)

  def paint_point_track(
      frames: np.ndarray,
      point_tracks: np.ndarray,
      visibles: np.ndarray,
      colormap: Optional[List[Tuple[int, int, int]]] = None,
  ) -> np.ndarray:
-     """Converts a sequence of points to color code video.
-     Args:
-         frames: [num_frames, height, width, 3], np.uint8, [0, 255]
-         point_tracks: [num_points, num_frames, 2], np.float32, [0, width / height]
-         visibles: [num_points, num_frames], bool
-         colormap: colormap for points, each point has a different RGB color.
-     Returns:
-         video: [num_frames, height, width, 3], np.uint8, [0, 255]
-     """
-     num_points, num_frames = point_tracks.shape[0:2]
-     if colormap is None:
-         colormap = get_colors(num_colors=num_points)
-     height, width = frames.shape[1:3]
-     dot_size_as_fraction_of_min_edge = 0.015
-     radius = int(round(min(height, width) * dot_size_as_fraction_of_min_edge))
-     diam = radius * 2 + 1
-     quadratic_y = np.square(np.arange(diam)[:, np.newaxis] - radius - 1)
-     quadratic_x = np.square(np.arange(diam)[np.newaxis, :] - radius - 1)
-     icon = (quadratic_y + quadratic_x) - (radius**2) / 2.0
-     sharpness = 0.15
-     icon = np.clip(icon / (radius * 2 * sharpness), 0, 1)
-     icon = 1 - icon[:, :, np.newaxis]
-     icon1 = np.pad(icon, [(0, 1), (0, 1), (0, 0)])
-     icon2 = np.pad(icon, [(1, 0), (0, 1), (0, 0)])
-     icon3 = np.pad(icon, [(0, 1), (1, 0), (0, 0)])
-     icon4 = np.pad(icon, [(1, 0), (1, 0), (0, 0)])
- 
-     video = frames.copy()
-     for t in range(num_frames):
-         # Pad so that points that extend outside the image frame don't crash us
-         image = np.pad(
-             video[t],
-             [
-                 (radius + 1, radius + 1),
-                 (radius + 1, radius + 1),
-                 (0, 0),
-             ],
-         )
-         for i in range(num_points):
-             # The icon is centered at the center of a pixel, but the input coordinates
-             # are raster coordinates. Therefore, to render a point at (1,1) (which
-             # lies on the corner between four pixels), we need 1/4 of the icon placed
-             # centered on the 0'th row, 0'th column, etc. We need to subtract
-             # 0.5 to make the fractional position come out right.
-             x, y = point_tracks[i, t, :] + 0.5
-             x = min(max(x, 0.0), width)
-             y = min(max(y, 0.0), height)
- 
-             if visibles[i, t]:
-                 x1, y1 = np.floor(x).astype(np.int32), np.floor(y).astype(np.int32)
-                 x2, y2 = x1 + 1, y1 + 1
- 
-                 # bilinear interpolation
-                 patch = (
-                     icon1 * (x2 - x) * (y2 - y)
-                     + icon2 * (x2 - x) * (y - y1)
-                     + icon3 * (x - x1) * (y2 - y)
-                     + icon4 * (x - x1) * (y - y1)
                  )
-                 x_ub = x1 + 2 * radius + 2
-                 y_ub = y1 + 2 * radius + 2
-                 image[y1:y_ub, x1:x_ub, :] = (1 - patch) * image[
-                     y1:y_ub, x1:x_ub, :
-                 ] + patch * np.array(colormap[i])[np.newaxis, np.newaxis, :]
- 
-         # Remove the pad
-         video[t] = image[
-             radius + 1 : -radius - 1, radius + 1 : -radius - 1
-         ].astype(np.uint8)
-     return video


- PREVIEW_WIDTH = 768  # Width of the preview video
- VIDEO_INPUT_RESO = (384, 512)  # Resolution of the input video
- POINT_SIZE = 4  # Size of the query point in the preview video
- FRAME_LIMIT = 300  # Limit the number of frames to process


- def get_point(frame_num, video_queried_preview, query_points, query_points_color, query_count, evt: gr.SelectData):
      print(f"You selected {(evt.index[0], evt.index[1], frame_num)}")

      current_frame = video_queried_preview[int(frame_num)]
@@ -192,22 +217,29 @@ def get_point(frame_num, video_queried_preview, query_points, query_points_color
      # Update the query count
      query_count += 1
      return (
-         current_frame_draw,  # Updated frame for preview
-         video_queried_preview,  # Updated preview video
-         query_points,  # Updated query points
-         query_points_color,  # Updated query points color
-         query_count  # Updated query count
      )


- def undo_point(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
      if len(query_points[int(frame_num)]) == 0:
          return (
              video_queried_preview[int(frame_num)],
              video_queried_preview,
              query_points,
              query_points_color,
-             query_count
          )

      # Get the last point
@@ -216,9 +248,13 @@ def undo_point(frame_num, video_preview, video_queried_preview, query_points, qu

      # Redraw the frame
      current_frame_draw = video_preview[int(frame_num)].copy()
-     for point, color in zip(query_points[int(frame_num)], query_points_color[int(frame_num)]):
          x, y, _ = point
-         current_frame_draw = cv2.circle(current_frame_draw, (x, y), POINT_SIZE, color, -1)

      # Update the query count
      query_count -= 1
@@ -226,15 +262,22 @@
      # Update the frame
      video_queried_preview[int(frame_num)] = current_frame_draw
      return (
-         current_frame_draw,  # Updated frame for preview
-         video_queried_preview,  # Updated preview video
-         query_points,  # Updated query points
-         query_points_color,  # Updated query points color
-         query_count  # Updated query count
      )


- def clear_frame_fn(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
      query_count -= len(query_points[int(frame_num)])

      query_points[int(frame_num)] = []
@@ -243,22 +286,21 @@ def clear_frame_fn(frame_num, video_preview, video_queried_preview, query_points
      video_queried_preview[int(frame_num)] = video_preview[int(frame_num)].copy()

      return (
-         video_preview[int(frame_num)],  # Set the preview frame to the original frame
-         video_queried_preview,
-         query_points,  # Cleared query points
-         query_points_color,  # Cleared query points color
-         query_count  # New query count
      )


- 
  def clear_all_fn(frame_num, video_preview):
      return (
          video_preview[int(frame_num)],
          video_preview.copy(),
          [[] for _ in range(len(video_preview))],
          [[] for _ in range(len(video_preview))],
-         0
      )
@@ -267,77 +309,98 @@ def choose_frame(frame_num, video_preview_array):


  def preprocess_video_input(video_path):
-     video_arr = mediapy.read_video(video_path)
-     video_fps = video_arr.metadata.fps
      num_frames = video_arr.shape[0]
      if num_frames > FRAME_LIMIT:
-         gr.Warning(f"The video is too long. Only the first {FRAME_LIMIT} frames will be used.", duration=5)
          video_arr = video_arr[:FRAME_LIMIT]
          num_frames = FRAME_LIMIT

-     # Resize to preview size for faster processing, width = PREVIEW_WIDTH
-     height, width = video_arr.shape[1:3]
-     new_height, new_width = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH
- 
-     preview_video = mediapy.resize_video(video_arr, (new_height, new_width))
-     input_video = mediapy.resize_video(video_arr, VIDEO_INPUT_RESO)
- 
-     preview_video = np.array(preview_video)
-     input_video = np.array(input_video)

      interactive = True

      return (
-         video_arr,  # Original video
-         preview_video,  # Original preview video, resized for faster processing
-         preview_video.copy(),  # Copy of preview video for visualization
-         input_video,  # Resized video input for model
-         # None, # video_feature, # Extracted feature
-         video_fps,  # Set the video FPS
-         gr.update(open=False),  # Close the video input drawer
-         # tracking_mode, # Set the tracking mode
-         preview_video[0],  # Set the preview frame to the first frame
-         gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=interactive),  # Set slider interactive
-         [[] for _ in range(num_frames)],  # Set query_points to empty
-         [[] for _ in range(num_frames)],  # Set query_points_color to empty
-         [[] for _ in range(num_frames)],
-         0,  # Set query count to 0
-         gr.update(interactive=interactive),  # Make the buttons interactive
          gr.update(interactive=interactive),
          gr.update(interactive=interactive),
          gr.update(interactive=True),
      )

  @spaces.GPU
  def track(
      video_preview,
-     video_input,
-     video_fps,
-     query_points,
-     query_points_color,
-     query_count,
  ):
-     tracking_mode = 'selected'
-     if query_count == 0:
-         tracking_mode='grid'
- 
      device = "cuda" if torch.cuda.is_available() else "cpu"
      dtype = torch.float if device == "cuda" else torch.float

      # Convert query points to tensor, normalize to input resolution
-     if tracking_mode!='grid':
          query_points_tensor = []
          for frame_points in query_points:
              query_points_tensor.extend(frame_points)
- 
          query_points_tensor = torch.tensor(query_points_tensor).float()
-         query_points_tensor *= torch.tensor([
-             VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0], 1
-         ]) / torch.tensor([
-             [video_preview.shape[2], video_preview.shape[1], 1]
-         ])
-         query_points_tensor = query_points_tensor[None].flip(-1).to(device, dtype)  # xyt -> tyx
-         query_points_tensor = query_points_tensor[:, :, [0, 2, 1]]  # tyx -> txy

      video_input = torch.tensor(video_input).unsqueeze(0)
@@ -345,10 +408,10 @@ def track(
      model = model.to(device)

      video_input = video_input.permute(0, 1, 4, 2, 3)
-     if tracking_mode=='grid':
          xy = get_points_on_a_grid(40, video_input.shape[3:], device=device)
          queries = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)  #
-         add_support_grid=False
          cmap = matplotlib.colormaps.get_cmap("gist_rainbow")
          query_points_color = [[]]
          query_count = queries.shape[1]
@@ -364,18 +427,33 @@
          # queries__ = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
          num_tracks = queries.shape[1]
          # queries = torch.cat([queries,queries__],dim=1)
-         add_support_grid=True
- 
-     model(video_chunk=video_input[:,:1].to(device, dtype), is_first_step=True, grid_size=0, queries=queries, add_support_grid=add_support_grid)
-     #
      for ind in range(0, video_input.shape[1] - model.step, model.step):
          pred_tracks, pred_visibility = model(
              video_chunk=video_input[:, ind : ind + model.step * 2].to(device, dtype),
-             grid_size=0,
-             queries=queries,
-             add_support_grid=add_support_grid
          )  # B T N 2, B T N 1
-     tracks = (pred_tracks * torch.tensor([video_preview.shape[2], video_preview.shape[1]]).to(device) / torch.tensor([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]]).to(device))[0].permute(1, 0, 2).cpu().numpy()
      pred_occ = torch.ones_like(pred_visibility[0]).permute(1, 0).cpu().numpy()

      # make color array
@@ -386,15 +464,33 @@

      # pred_tracks = torch.cat([pred_tracks[:,:1],(pred_tracks[:,:-2] + pred_tracks[:,1:-1] + pred_tracks[:,2:])/ 3, pred_tracks[:,-1:]],dim=1)
      # torch.cat([pred_tracks[:,:1],pred_tracks[:,1:]],dim=1)
-     pred_tracks = (pred_tracks * torch.tensor([video_preview.shape[2], video_preview.shape[1]]).to(device) / torch.tensor([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]]).to(device))
- 
-     vis = Visualizer(save_dir="./saved_videos", pad_value=0, linewidth=2, tracks_leave_trace=0)
      # segm_mask = torch.zeros(queries.shape[1])
      # segm_mask[:num_tracks] = 1
      # print('segm_mask',segm_mask.shape, segm_mask)
-     # segm_mask=segm_mask,
-     painted_video = vis.visualize(torch.tensor(video_preview).permute(0, 3, 1, 2)[None].to(pred_tracks.device), pred_tracks, pred_visibility, save_video=False)[0].permute(0, 2, 3, 1).cpu().numpy()
- 
      # painted_video = paint_point_track(video_preview,tracks,pred_occ,colors)

      # save video
@@ -420,8 +516,11 @@ with gr.Blocks() as demo:
      is_tracked_query = gr.State([])
      query_count = gr.State(0)

-     gr.Markdown("# 🎨 CoTracker3: Simpler and Better Point Tracking by Pseudo-Labelling Real Videos")
-     gr.Markdown("<div style='text-align: left;'> \
      <p>Welcome to <a href='https://cotracker3.github.io/' target='_blank'>CoTracker</a>! This space demonstrates point (pixel) tracking in videos. \
      The model tracks points on a grid or points selected by you. </p> \
      <p> To get started, simply upload your <b>.mp4</b> video or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length <b>2-7 seconds</b>.</p> \
@@ -429,42 +528,59 @@
      <p style='text-align: left'>For more details, check out our <a href='https://github.com/facebookresearch/co-tracker' target='_blank'>GitHub Repo</a> ⭐. We thank the authors of LocoTrack for their interactive demo.</p> \
      </div>"
      )
- 

-     gr.Markdown("## First step: upload your video or select an example video, and click submit.")
      with gr.Row():
- 

          with gr.Accordion("Your video input", open=True) as video_in_drawer:
              video_in = gr.Video(label="Video Input", format="mp4")
              submit = gr.Button("Submit", scale=0)

          import os
          apple = os.path.join(os.path.dirname(__file__), "videos", "apple.mp4")
          bear = os.path.join(os.path.dirname(__file__), "videos", "bear.mp4")
          paragliding_launch = os.path.join(
              os.path.dirname(__file__), "videos", "paragliding-launch.mp4"
          )
-         paragliding = os.path.join(os.path.dirname(__file__), "videos", "paragliding.mp4")
          cat = os.path.join(os.path.dirname(__file__), "videos", "cat.mp4")
          pillow = os.path.join(os.path.dirname(__file__), "videos", "pillow.mp4")
          teddy = os.path.join(os.path.dirname(__file__), "videos", "teddy.mp4")
          backpack = os.path.join(os.path.dirname(__file__), "videos", "backpack.mp4")

-         gr.Examples(examples=[bear, apple, paragliding, paragliding_launch, cat, pillow, teddy, backpack],
-                     inputs = [
-                         video_in
-                     ],
-         )
- 
- 
-     gr.Markdown("## Second step: Simply click \"Track\" to track a grid of points or select query points on the video before clicking")
      with gr.Row():
          with gr.Column():
              with gr.Row():
                  query_frames = gr.Slider(
-                     minimum=0, maximum=100, value=0, step=1, label="Choose Frame", interactive=False)
              with gr.Row():
                  undo = gr.Button("Undo", interactive=False)
                  clear_frame = gr.Button("Clear Frame", interactive=False)
@@ -472,11 +588,9 @@

              with gr.Row():
                  current_frame = gr.Image(
-                     label="Click to add query points",
-                     type="numpy",
-                     interactive=False
                  )
- 
              with gr.Row():
                  track_button = gr.Button("Track", interactive=False)
@@ -488,12 +602,10 @@
                  loop=True,
              )

- 
- 
      submit.click(
-         fn = preprocess_video_input,
-         inputs = [video_in],
-         outputs = [
              video,
              video_preview,
              video_queried_preview,
@@ -511,97 +623,96 @@
              clear_all,
              track_button,
          ],
-         queue = False
      )

      query_frames.change(
-         fn = choose_frame,
-         inputs = [query_frames, video_queried_preview],
-         outputs = [
              current_frame,
          ],
-         queue = False
      )

      current_frame.select(
-         fn = get_point,
-         inputs = [
              query_frames,
              video_queried_preview,
              query_points,
              query_points_color,
              query_count,
-         ],
-         outputs = [
              current_frame,
              video_queried_preview,
              query_points,
              query_points_color,
-             query_count
-         ],
-         queue = False
      )
- 
      undo.click(
-         fn = undo_point,
-         inputs = [
              query_frames,
              video_preview,
              video_queried_preview,
              query_points,
              query_points_color,
-             query_count
          ],
-         outputs = [
              current_frame,
              video_queried_preview,
              query_points,
              query_points_color,
-             query_count
          ],
-         queue = False
      )

      clear_frame.click(
-         fn = clear_frame_fn,
-         inputs = [
              query_frames,
              video_preview,
              video_queried_preview,
              query_points,
              query_points_color,
-             query_count
          ],
-         outputs = [
              current_frame,
              video_queried_preview,
              query_points,
              query_points_color,
-             query_count
          ],
-         queue = False
      )

      clear_all.click(
-         fn = clear_all_fn,
-         inputs = [
              query_frames,
              video_preview,
          ],
-         outputs = [
              current_frame,
              video_queried_preview,
              query_points,
              query_points_color,
-             query_count
          ],
-         queue = False
      )

- 
      track_button.click(
-         fn = track,
-         inputs = [
              video_preview,
              video_input,
              video_fps,
@@ -609,11 +720,11 @@
              query_points_color,
              query_count,
          ],
-         outputs = [
              output_video,
          ],
-         queue = True,
      )

- 
- demo.launch(show_api=False, show_error=True, debug=False, share=False)
 
+ # This Gradio demo code is from https://github.com/cvlab-kaist/locotrack/blob/main/demo/demo.py
  # We updated it to work with CoTracker3 models. We thank authors of LocoTrack
  # for such an amazing Gradio demo.


  # Generate random colormaps for visualizing different points.
  def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
+     """Gets colormap for points."""
+     colors = []
+     for i in np.arange(0.0, 360.0, 360.0 / num_colors):
+         hue = i / 360.0
+         lightness = (50 + np.random.rand() * 10) / 100.0
+         saturation = (90 + np.random.rand() * 10) / 100.0
+         color = colorsys.hls_to_rgb(hue, lightness, saturation)
+         colors.append((int(color[0] * 255), int(color[1] * 255), int(color[2] * 255)))
+     random.shuffle(colors)
+     return colors
+ 
+ def read_video_cv2(video_path):
+     cap = cv2.VideoCapture(video_path)
+     frames = []
+ 
+     # Get FPS from video metadata
+     fps = cap.get(cv2.CAP_PROP_FPS)
+ 
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         # Convert BGR to RGB
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         frames.append(frame)
+ 
+     cap.release()
+     video_arr = np.array(frames)
+     return video_arr, fps

  def get_points_on_a_grid(
      size: int,
 
      )
      return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)

+ 
  def paint_point_track(
      frames: np.ndarray,
      point_tracks: np.ndarray,
      visibles: np.ndarray,
      colormap: Optional[List[Tuple[int, int, int]]] = None,
  ) -> np.ndarray:
+     """Converts a sequence of points to color code video.
+     Args:
+         frames: [num_frames, height, width, 3], np.uint8, [0, 255]
+         point_tracks: [num_points, num_frames, 2], np.float32, [0, width / height]
+         visibles: [num_points, num_frames], bool
+         colormap: colormap for points, each point has a different RGB color.
+     Returns:
+         video: [num_frames, height, width, 3], np.uint8, [0, 255]
+     """
+     num_points, num_frames = point_tracks.shape[0:2]
+     if colormap is None:
+         colormap = get_colors(num_colors=num_points)
+     height, width = frames.shape[1:3]
+     dot_size_as_fraction_of_min_edge = 0.015
+     radius = int(round(min(height, width) * dot_size_as_fraction_of_min_edge))
+     diam = radius * 2 + 1
+     quadratic_y = np.square(np.arange(diam)[:, np.newaxis] - radius - 1)
+     quadratic_x = np.square(np.arange(diam)[np.newaxis, :] - radius - 1)
+     icon = (quadratic_y + quadratic_x) - (radius**2) / 2.0
+     sharpness = 0.15
+     icon = np.clip(icon / (radius * 2 * sharpness), 0, 1)
+     icon = 1 - icon[:, :, np.newaxis]
+     icon1 = np.pad(icon, [(0, 1), (0, 1), (0, 0)])
+     icon2 = np.pad(icon, [(1, 0), (0, 1), (0, 0)])
+     icon3 = np.pad(icon, [(0, 1), (1, 0), (0, 0)])
+     icon4 = np.pad(icon, [(1, 0), (1, 0), (0, 0)])
+ 
+     video = frames.copy()
+     for t in range(num_frames):
+         # Pad so that points that extend outside the image frame don't crash us
+         image = np.pad(
+             video[t],
+             [
+                 (radius + 1, radius + 1),
+                 (radius + 1, radius + 1),
+                 (0, 0),
+             ],
          )
+         for i in range(num_points):
+             # The icon is centered at the center of a pixel, but the input coordinates
+             # are raster coordinates. Therefore, to render a point at (1,1) (which
+             # lies on the corner between four pixels), we need 1/4 of the icon placed
+             # centered on the 0'th row, 0'th column, etc. We need to subtract
+             # 0.5 to make the fractional position come out right.
+             x, y = point_tracks[i, t, :] + 0.5
+             x = min(max(x, 0.0), width)
+             y = min(max(y, 0.0), height)
+ 
+             if visibles[i, t]:
+                 x1, y1 = np.floor(x).astype(np.int32), np.floor(y).astype(np.int32)
+                 x2, y2 = x1 + 1, y1 + 1
+ 
+                 # bilinear interpolation
+                 patch = (
+                     icon1 * (x2 - x) * (y2 - y)
+                     + icon2 * (x2 - x) * (y - y1)
+                     + icon3 * (x - x1) * (y2 - y)
+                     + icon4 * (x - x1) * (y - y1)
+                 )
+                 x_ub = x1 + 2 * radius + 2
+                 y_ub = y1 + 2 * radius + 2
+                 image[y1:y_ub, x1:x_ub, :] = (1 - patch) * image[
+                     y1:y_ub, x1:x_ub, :
+                 ] + patch * np.array(colormap[i])[np.newaxis, np.newaxis, :]
+ 
+         # Remove the pad
+         video[t] = image[radius + 1 : -radius - 1, radius + 1 : -radius - 1].astype(
+             np.uint8
+         )
+     return video


+ PREVIEW_WIDTH = 768  # Width of the preview video
+ VIDEO_INPUT_RESO = (384, 512)  # Resolution of the input video
+ POINT_SIZE = 4  # Size of the query point in the preview video
+ FRAME_LIMIT = 600  # Limit the number of frames to process


+ def get_point(
+     frame_num,
+     video_queried_preview,
+     query_points,
+     query_points_color,
+     query_count,
+     evt: gr.SelectData,
+ ):
      print(f"You selected {(evt.index[0], evt.index[1], frame_num)}")

      current_frame = video_queried_preview[int(frame_num)]
 
      # Update the query count
      query_count += 1
      return (
+         current_frame_draw,  # Updated frame for preview
+         video_queried_preview,  # Updated preview video
+         query_points,  # Updated query points
+         query_points_color,  # Updated query points color
+         query_count,  # Updated query count
      )


+ def undo_point(
+     frame_num,
+     video_preview,
+     video_queried_preview,
+     query_points,
+     query_points_color,
+     query_count,
+ ):
      if len(query_points[int(frame_num)]) == 0:
          return (
              video_queried_preview[int(frame_num)],
              video_queried_preview,
              query_points,
              query_points_color,
+             query_count,
          )

      # Get the last point
 

      # Redraw the frame
      current_frame_draw = video_preview[int(frame_num)].copy()
+     for point, color in zip(
+         query_points[int(frame_num)], query_points_color[int(frame_num)]
+     ):
          x, y, _ = point
+         current_frame_draw = cv2.circle(
+             current_frame_draw, (x, y), POINT_SIZE, color, -1
+         )

      # Update the query count
      query_count -= 1
 
      # Update the frame
      video_queried_preview[int(frame_num)] = current_frame_draw
      return (
+         current_frame_draw,  # Updated frame for preview
+         video_queried_preview,  # Updated preview video
+         query_points,  # Updated query points
+         query_points_color,  # Updated query points color
+         query_count,  # Updated query count
      )


+ def clear_frame_fn(
+     frame_num,
+     video_preview,
+     video_queried_preview,
+     query_points,
+     query_points_color,
+     query_count,
+ ):
      query_count -= len(query_points[int(frame_num)])

      query_points[int(frame_num)] = []

      video_queried_preview[int(frame_num)] = video_preview[int(frame_num)].copy()

      return (
+         video_preview[int(frame_num)],  # Set the preview frame to the original frame
+         video_queried_preview,
+         query_points,  # Cleared query points
+         query_points_color,  # Cleared query points color
+         query_count,  # New query count
      )
 
  def clear_all_fn(frame_num, video_preview):
      return (
          video_preview[int(frame_num)],
          video_preview.copy(),
          [[] for _ in range(len(video_preview))],
          [[] for _ in range(len(video_preview))],
+         0,
      )
 
  def preprocess_video_input(video_path):
+     import time
+     start_time = time.time()
+ 
+     # Read video and get FPS
+     video_arr, video_fps = read_video_cv2(video_path)
+     end_time = time.time()
+     print(f"Time taken to read video: {end_time - start_time} seconds")
+ 
+     # Apply frame limit
      num_frames = video_arr.shape[0]
      if num_frames > FRAME_LIMIT:
+         gr.Warning(
+             f"The video is too long. Only the first {FRAME_LIMIT} frames will be used.",
+             duration=5,
+         )
          video_arr = video_arr[:FRAME_LIMIT]
          num_frames = FRAME_LIMIT

+     start_time = time.time()

+     # Resize preview video while maintaining aspect ratio
+     h, w = video_arr.shape[1:3]
+     aspect_ratio = w / h
+     if w > PREVIEW_WIDTH:
+         new_w = PREVIEW_WIDTH
+         new_h = int(new_w / aspect_ratio)
+         preview_video = np.zeros((len(video_arr), new_h, new_w, 3), dtype=np.uint8)
+         for i in range(len(video_arr)):
+             preview_video[i] = cv2.resize(video_arr[i], (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+     else:
+         preview_video = video_arr.copy()
+ 
+     # Resize input video for the model
+     input_video = np.zeros((len(video_arr), VIDEO_INPUT_RESO[0], VIDEO_INPUT_RESO[1], 3), dtype=np.uint8)
+     for i in range(len(video_arr)):
+         input_video[i] = cv2.resize(video_arr[i], (VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]), interpolation=cv2.INTER_LINEAR)
+ 
+     end_time = time.time()
+     print(f"Time taken to resize videos: {end_time - start_time} seconds")
+ 
      interactive = True

      return (
+         video_arr,  # Original video
+         preview_video,  # Preview video at resized resolution
+         preview_video.copy(),  # Copy for visualization
+         input_video,  # Resized input video for model
+         video_fps,
+         gr.update(open=False),
+         preview_video[0],
+         gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=interactive),
+         [[] for _ in range(num_frames)],
+         [[] for _ in range(num_frames)],
+         [[] for _ in range(num_frames)],
+         0,
+         gr.update(interactive=interactive),
          gr.update(interactive=interactive),
          gr.update(interactive=interactive),
          gr.update(interactive=True),
      )

+ 
  @spaces.GPU
  def track(
      video_preview,
+     video_input,
+     video_fps,
+     query_points,
+     query_points_color,
+     query_count,
  ):
+     tracking_mode = "selected"
+     if query_count == 0:
+         tracking_mode = "grid"
+ 
      device = "cuda" if torch.cuda.is_available() else "cpu"
      dtype = torch.float if device == "cuda" else torch.float

      # Convert query points to tensor, normalize to input resolution
+     if tracking_mode != "grid":
          query_points_tensor = []
          for frame_points in query_points:
              query_points_tensor.extend(frame_points)
+ 
          query_points_tensor = torch.tensor(query_points_tensor).float()
+         query_points_tensor *= torch.tensor(
+             [VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0], 1]
+         ) / torch.tensor([[video_preview.shape[2], video_preview.shape[1], 1]])
+         query_points_tensor = (
+             query_points_tensor[None].flip(-1).to(device, dtype)
+         )  # xyt -> tyx
+         query_points_tensor = query_points_tensor[:, :, [0, 2, 1]]  # tyx -> txy

      video_input = torch.tensor(video_input).unsqueeze(0)
 
      model = model.to(device)

      video_input = video_input.permute(0, 1, 4, 2, 3)
+     if tracking_mode == "grid":
          xy = get_points_on_a_grid(40, video_input.shape[3:], device=device)
          queries = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)  #
+         add_support_grid = False
          cmap = matplotlib.colormaps.get_cmap("gist_rainbow")
          query_points_color = [[]]
          query_count = queries.shape[1]
 
          # queries__ = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
          num_tracks = queries.shape[1]
          # queries = torch.cat([queries,queries__],dim=1)
+         add_support_grid = True
+ 
+     model(
+         video_chunk=video_input[:, :1].to(device, dtype),
+         is_first_step=True,
+         grid_size=0,
+         queries=queries,
+         add_support_grid=add_support_grid,
+     )
+     #
      for ind in range(0, video_input.shape[1] - model.step, model.step):
          pred_tracks, pred_visibility = model(
              video_chunk=video_input[:, ind : ind + model.step * 2].to(device, dtype),
+             grid_size=0,
+             queries=queries,
+             add_support_grid=add_support_grid,
          )  # B T N 2, B T N 1
+     tracks = (
+         (
+             pred_tracks
+             * torch.tensor([video_preview.shape[2], video_preview.shape[1]]).to(device)
+             / torch.tensor([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]]).to(device)
+         )[0]
+         .permute(1, 0, 2)
+         .cpu()
+         .numpy()
+     )
      pred_occ = torch.ones_like(pred_visibility[0]).permute(1, 0).cpu().numpy()

      # make color array
 

      # pred_tracks = torch.cat([pred_tracks[:,:1],(pred_tracks[:,:-2] + pred_tracks[:,1:-1] + pred_tracks[:,2:])/ 3, pred_tracks[:,-1:]],dim=1)
      # torch.cat([pred_tracks[:,:1],pred_tracks[:,1:]],dim=1)
+     pred_tracks = (
+         pred_tracks
+         * torch.tensor([video_preview.shape[2], video_preview.shape[1]]).to(device)
+         / torch.tensor([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]]).to(device)
+     )
+ 
+     vis = Visualizer(
+         save_dir="./saved_videos", pad_value=0, linewidth=2, tracks_leave_trace=0
+     )
      # segm_mask = torch.zeros(queries.shape[1])
      # segm_mask[:num_tracks] = 1
      # print('segm_mask',segm_mask.shape, segm_mask)
+     # segm_mask=segm_mask,
+     painted_video = (
+         vis.visualize(
+             torch.tensor(video_preview)
+             .permute(0, 3, 1, 2)[None]
+             .to(pred_tracks.device),
+             pred_tracks,
+             pred_visibility,
+             save_video=False,
+         )[0]
+         .permute(0, 2, 3, 1)
+         .cpu()
+         .numpy()
+     )
+ 
      # painted_video = paint_point_track(video_preview,tracks,pred_occ,colors)

      # save video
 
      is_tracked_query = gr.State([])
      query_count = gr.State(0)

+     gr.Markdown(
+         "# 🎨 CoTracker3: Simpler and Better Point Tracking by Pseudo-Labelling Real Videos"
+     )
+     gr.Markdown(
+         "<div style='text-align: left;'> \
      <p>Welcome to <a href='https://cotracker3.github.io/' target='_blank'>CoTracker</a>! This space demonstrates point (pixel) tracking in videos. \
      The model tracks points on a grid or points selected by you. </p> \
      <p> To get started, simply upload your <b>.mp4</b> video or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length <b>2-7 seconds</b>.</p> \
      <p style='text-align: left'>For more details, check out our <a href='https://github.com/facebookresearch/co-tracker' target='_blank'>GitHub Repo</a> ⭐. We thank the authors of LocoTrack for their interactive demo.</p> \
      </div>"
      )

+     gr.Markdown(
+         "## First step: upload your video or select an example video, and click submit."
+     )
      with gr.Row():

          with gr.Accordion("Your video input", open=True) as video_in_drawer:
              video_in = gr.Video(label="Video Input", format="mp4")
              submit = gr.Button("Submit", scale=0)

          import os
+ 
          apple = os.path.join(os.path.dirname(__file__), "videos", "apple.mp4")
          bear = os.path.join(os.path.dirname(__file__), "videos", "bear.mp4")
          paragliding_launch = os.path.join(
              os.path.dirname(__file__), "videos", "paragliding-launch.mp4"
          )
+         paragliding = os.path.join(
+             os.path.dirname(__file__), "videos", "paragliding.mp4"
+         )
          cat = os.path.join(os.path.dirname(__file__), "videos", "cat.mp4")
          pillow = os.path.join(os.path.dirname(__file__), "videos", "pillow.mp4")
          teddy = os.path.join(os.path.dirname(__file__), "videos", "teddy.mp4")
          backpack = os.path.join(os.path.dirname(__file__), "videos", "backpack.mp4")

+         gr.Examples(
+             examples=[
+                 bear,
+                 apple,
+                 paragliding,
+                 paragliding_launch,
+                 cat,
+                 pillow,
+                 teddy,
+                 backpack,
+             ],
+             inputs=[video_in],
+         )

+     gr.Markdown(
+         '## Second step: Simply click "Track" to track a grid of points or select query points on the video before clicking'
+     )
      with gr.Row():
          with gr.Column():
              with gr.Row():
                  query_frames = gr.Slider(
+                     minimum=0,
+                     maximum=100,
+                     value=0,
+                     step=1,
+                     label="Choose Frame",
+                     interactive=False,
+                 )
              with gr.Row():
                  undo = gr.Button("Undo", interactive=False)
                  clear_frame = gr.Button("Clear Frame", interactive=False)


              with gr.Row():
                  current_frame = gr.Image(
+                     label="Click to add query points", type="numpy", interactive=False
                  )
+ 
              with gr.Row():
                  track_button = gr.Button("Track", interactive=False)
 
 
                  loop=True,
              )

      submit.click(
+         fn=preprocess_video_input,
+         inputs=[video_in],
+         outputs=[
              video,
              video_preview,
              video_queried_preview,

              clear_all,
              track_button,
          ],
+         queue=False,
      )

      query_frames.change(
+         fn=choose_frame,
+         inputs=[query_frames, video_queried_preview],
+         outputs=[
              current_frame,
          ],
+         queue=False,
      )

      current_frame.select(
+         fn=get_point,
+         inputs=[
              query_frames,
              video_queried_preview,
              query_points,
              query_points_color,
              query_count,
+         ],
+         outputs=[
              current_frame,
              video_queried_preview,
              query_points,
              query_points_color,
+             query_count,
+         ],
+         queue=False,
      )
+ 
      undo.click(
+         fn=undo_point,
+         inputs=[
              query_frames,
              video_preview,
              video_queried_preview,
              query_points,
              query_points_color,
+             query_count,
          ],
+         outputs=[
              current_frame,
              video_queried_preview,
              query_points,
              query_points_color,
+             query_count,
          ],
+         queue=False,
      )

      clear_frame.click(
+         fn=clear_frame_fn,
+         inputs=[
              query_frames,
              video_preview,
              video_queried_preview,
              query_points,
              query_points_color,
+             query_count,
          ],
+         outputs=[
              current_frame,
              video_queried_preview,
              query_points,
              query_points_color,
+             query_count,
          ],
+         queue=False,
      )

      clear_all.click(
+         fn=clear_all_fn,
+         inputs=[
              query_frames,
              video_preview,
          ],
+         outputs=[
              current_frame,
              video_queried_preview,
              query_points,
              query_points_color,
+             query_count,
          ],
+         queue=False,
      )

      track_button.click(
+         fn=track,
+         inputs=[
              video_preview,
              video_input,
              video_fps,

              query_points_color,
              query_count,
          ],
+         outputs=[
              output_video,
          ],
+         queue=True,
      )

+ 
+ demo.launch(show_api=False, show_error=True, debug=False, share=False)
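For reference, the resize path in the new preprocess_video_input preallocates a uint8 buffer and fills it frame by frame with cv2.resize. A condensed sketch of that pattern (resize_video_frames is a hypothetical helper name, not part of the commit):

    import cv2
    import numpy as np

    def resize_video_frames(video_arr: np.ndarray, new_w: int, new_h: int) -> np.ndarray:
        # Allocate the output once, then let cv2.resize write each frame.
        # Note: cv2.resize takes dsize as (width, height), the reverse of
        # NumPy's (height, width) shape order.
        out = np.zeros((len(video_arr), new_h, new_w, 3), dtype=np.uint8)
        for i, frame in enumerate(video_arr):
            out[i] = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        return out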