nikkar committed
Commit 0af8fd1 · verified · 1 Parent(s): 0e0df35

Create predictor.py

Files changed (1):
  1. predictor.py +167 -0
predictor.py ADDED
@@ -0,0 +1,167 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+
+ def smart_cat(tensor1, tensor2, dim):
+     # Concatenate along `dim`, treating a None first tensor as empty.
+     if tensor1 is None:
+         return tensor2
+     return torch.cat([tensor1, tensor2], dim=dim)
+
+
+ def get_points_on_a_grid(
+     size: int,
+     extent: Tuple[float, ...],
+     center: Optional[Tuple[float, ...]] = None,
+     device: Optional[torch.device] = torch.device("cpu"),
+ ):
+     r"""Get a grid of points covering a rectangular region.
+
+     `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
+     :attr:`size` grid of points distributed to cover a rectangular area
+     specified by `extent`.
+
+     The `extent` is a pair :math:`(H,W)` specifying the height and width
+     of the rectangle.
+
+     Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
+     specifying the vertical and horizontal center coordinates. The center
+     defaults to the middle of the extent.
+
+     Points are distributed uniformly within the rectangle, leaving a margin
+     :math:`m=W/64` from the border.
+
+     It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
+     points :math:`P_{ij}=(x_i, y_i)` where
+
+     .. math::
+         P_{ij} = \left(
+             c_x + m - \frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
+             c_y + m - \frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
+         \right)
+
+     Points are returned in row-major order.
+
+     Args:
+         size (int): grid size.
+         extent (tuple): height and width of the grid extent.
+         center (tuple, optional): grid center.
+         device (torch.device, optional): Defaults to `torch.device("cpu")`.
+
+     Returns:
+         Tensor: grid.
+     """
+     if size == 1:
+         return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
+
+     if center is None:
+         center = [extent[0] / 2, extent[1] / 2]
+
+     margin = extent[1] / 64
+     range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
+     range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
+     grid_y, grid_x = torch.meshgrid(
+         torch.linspace(*range_y, size, device=device),
+         torch.linspace(*range_x, size, device=device),
+         indexing="ij",
+     )
+     return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
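+
+ # Example: get_points_on_a_grid(2, (8, 8)) uses margin = 8 / 64 = 0.125 and
+ # returns the four corner points (0.125, 0.125), (7.875, 0.125),
+ # (0.125, 7.875), (7.875, 7.875) as a (1, 4, 2) tensor of (x, y) pairs.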
+
+
+ class CoTrackerOnlinePredictor(torch.nn.Module):
+     def __init__(
+         self,
+         checkpoint="./checkpoints/scaled_online.pth",
+         offline=False,
+         v2=False,
+         window_len=16,
+     ):
+         super().__init__()
+         self.support_grid_size = 6
+         # Load the pretrained online model via torch.hub rather than from the
+         # local checkpoint:
+         # model = build_cotracker(checkpoint, v2=v2, offline=False, window_len=window_len)
+         model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online").model
+         self.interp_shape = model.model_resolution
+         # The online model advances through the video in steps of half a window.
+         self.step = model.window_len // 2
+         self.model = model
+         self.model.eval()
+
+     @torch.no_grad()
+     def forward(
+         self,
+         video_chunk,
+         is_first_step: bool = False,
+         queries: torch.Tensor = None,
+         grid_size: int = 5,
+         grid_query_frame: int = 0,
+         add_support_grid=False,
+         iters: int = 5,
+     ):
+         B, T, C, H, W = video_chunk.shape
+         # Initialize online video processing and save the queried points.
+         # This needs to be done before processing *each new video*.
+         if is_first_step:
+             self.model.init_video_online_processing()
+             if queries is not None:
+                 B, N, D = queries.shape
+                 self.N = N
+                 assert D == 3
+                 queries = queries.clone()
+                 # Rescale (x, y) query coordinates from the input resolution
+                 # to the model resolution.
+                 queries[:, :, 1:] *= queries.new_tensor(
+                     [
+                         (self.interp_shape[1] - 1) / (W - 1),
+                         (self.interp_shape[0] - 1) / (H - 1),
+                     ]
+                 )
+                 if add_support_grid:
+                     grid_pts = get_points_on_a_grid(
+                         self.support_grid_size, self.interp_shape, device=video_chunk.device
+                     )
+                     # Support points are queried from frame 0.
+                     grid_pts = torch.cat(
+                         [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
+                     )
+                     queries = torch.cat([queries, grid_pts], dim=1)
+             elif grid_size > 0:
+                 grid_pts = get_points_on_a_grid(
+                     grid_size, self.interp_shape, device=video_chunk.device
+                 )
+                 self.N = grid_size**2
+                 # Prepend the query frame index to each (x, y) grid point.
+                 queries = torch.cat(
+                     [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
+                     dim=2,
+                 )
+
+             self.queries = queries
+             return (None, None)
+
+         # Resize each frame to the model resolution before tracking.
+         video_chunk = video_chunk.reshape(B * T, C, H, W)
+         video_chunk = F.interpolate(
+             video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
+         )
+         video_chunk = video_chunk.reshape(
+             B, T, 3, self.interp_shape[0], self.interp_shape[1]
+         )
+
+         tracks, visibilities, confidence, __ = self.model(
+             video=video_chunk, queries=self.queries, iters=iters, is_online=True
+         )
+         if add_support_grid:
+             # Drop the support points, keeping only the user-queried tracks.
+             tracks = tracks[:, :, : self.N]
+             visibilities = visibilities[:, :, : self.N]
+             confidence = confidence[:, :, : self.N]
+
+         visibilities = visibilities * confidence
+         thr = 0.6
+         # Rescale tracks back to the input resolution and binarize visibility.
+         return (
+             tracks
+             * tracks.new_tensor(
+                 [
+                     (W - 1) / (self.interp_shape[1] - 1),
+                     (H - 1) / (self.interp_shape[0] - 1),
+                 ]
+             ),
+             visibilities > thr,
+         )
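
A minimal sketch of how the predictor might be driven over a full video, assuming a `video` tensor of shape (1, T, 3, H, W); the variable names and the 64-frame placeholder input are illustrative. The first call only registers the query grid and initializes online processing; subsequent calls feed overlapping windows advanced by `predictor.step` (half a window) frames:

import torch

# Placeholder input: batch of 1, 64 RGB frames at 384x512, values in [0, 255].
video = torch.rand(1, 64, 3, 384, 512) * 255

predictor = CoTrackerOnlinePredictor()

# First call: register a 10x10 query grid; no tracks are returned yet.
predictor(video_chunk=video[:, :1], is_first_step=True, grid_size=10)

# Stream the video in overlapping windows of predictor.step * 2 frames.
for ind in range(0, video.shape[1] - predictor.step, predictor.step):
    pred_tracks, pred_visibility = predictor(
        video_chunk=video[:, ind : ind + predictor.step * 2]
    )

# pred_tracks: (1, T, 100, 2) point coordinates in input-resolution pixels;
# pred_visibility: (1, T, 100) booleans from thresholding visibility * confidence at 0.6.

Rescaling queries into the model resolution on the way in and scaling tracks back on the way out lets callers work entirely in the original pixel coordinates.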