# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import torch
import torch.nn.functional as F

def smart_cat(tensor1, tensor2, dim):
    """Concatenate two tensors along `dim`, treating a `None` first
    argument as an empty tensor (i.e. return `tensor2` unchanged)."""
    if tensor1 is None:
        return tensor2
    return torch.cat([tensor1, tensor2], dim=dim)
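
# Illustrative note (not part of the original file): smart_cat lets callers
# accumulate per-chunk outputs without special-casing the first chunk, e.g.
#
#     acc = None
#     for chunk in chunks:
#         acc = smart_cat(acc, chunk, dim=1)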

def get_points_on_a_grid(
    size: int,
    extent: Tuple[float, ...],
    center: Optional[Tuple[float, ...]] = None,
    device: Optional[torch.device] = torch.device("cpu"),
):
r"""Get a grid of points covering a rectangular region | |
`get_points_on_a_grid(size, extent)` generates a :attr:`size` by | |
:attr:`size` grid fo points distributed to cover a rectangular area | |
specified by `extent`. | |
The `extent` is a pair of integer :math:`(H,W)` specifying the height | |
and width of the rectangle. | |
Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)` | |
specifying the vertical and horizontal center coordinates. The center | |
defaults to the middle of the extent. | |
Points are distributed uniformly within the rectangle leaving a margin | |
:math:`m=W/64` from the border. | |
It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of | |
points :math:`P_{ij}=(x_i, y_i)` where | |
.. math:: | |
P_{ij} = \left( | |
c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~ | |
c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i | |
\right) | |
Points are returned in row-major order. | |
Args: | |
size (int): grid size. | |
extent (tuple): height and with of the grid extent. | |
center (tuple, optional): grid center. | |
device (str, optional): Defaults to `"cpu"`. | |
Returns: | |
Tensor: grid. | |
""" | |
    if size == 1:
        # A single point sits at the center of the extent.
        return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]

    if center is None:
        center = [extent[0] / 2, extent[1] / 2]

    margin = extent[1] / 64
    range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
    range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(*range_y, size, device=device),
        torch.linspace(*range_x, size, device=device),
        indexing="ij",
    )
    return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
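
# A quick sanity sketch (added for illustration, not part of the original
# API): for size=2 the grid reduces to the four margin-inset corners of the
# extent, returned as (x, y) pairs in row-major order.
def _example_grid_corners():
    pts = get_points_on_a_grid(2, (480, 640))  # H=480, W=640 -> margin m = 640/64 = 10
    assert pts.shape == (1, 4, 2)
    # Top-left corner, inset by the margin on both axes:
    assert torch.allclose(pts[0, 0], torch.tensor([10.0, 10.0]))
    # Bottom-right corner at (W - m, H - m) = (630, 470):
    assert torch.allclose(pts[0, -1], torch.tensor([630.0, 470.0]))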

class CoTrackerOnlinePredictor(torch.nn.Module):
    def __init__(
        self,
        checkpoint=None,
        offline=False,
        v2=False,
        window_len=16,
    ):
        super().__init__()
        self.support_grid_size = 6
        # Load the pretrained online model from torch.hub instead of building
        # it locally via build_cotracker(checkpoint, v2=v2, offline=False,
        # window_len=window_len).
        model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online").model
        if checkpoint is not None:
            with open(checkpoint, "rb") as f:
                state_dict = torch.load(f, map_location="cpu")
                if "model" in state_dict:
                    state_dict = state_dict["model"]
            # `model` is already the inner network, so load directly into it.
            model.load_state_dict(state_dict)
            print(f"Loaded checkpoint from {checkpoint}")
        self.interp_shape = model.model_resolution
        self.step = model.window_len // 2
        self.model = model
        self.model.eval()
    def forward(
        self,
        video_chunk,
        is_first_step: bool = False,
        queries: torch.Tensor = None,
        grid_size: int = 5,
        grid_query_frame: int = 0,
        add_support_grid: bool = False,
        iters: int = 5,
    ):
        B, T, C, H, W = video_chunk.shape
        # Initialize online video processing and save queried points.
        # This needs to be done before processing *each new video*.
        if is_first_step:
            self.model.init_video_online_processing()
            if queries is not None:
                B, N, D = queries.shape
                self.N = N
                assert D == 3
                queries = queries.clone()
                # Rescale (x, y) query coordinates from the input resolution
                # to the model resolution; the frame index (dim 0) is kept.
                queries[:, :, 1:] *= queries.new_tensor(
                    [
                        (self.interp_shape[1] - 1) / (W - 1),
                        (self.interp_shape[0] - 1) / (H - 1),
                    ]
                )
                if add_support_grid:
                    grid_pts = get_points_on_a_grid(
                        self.support_grid_size, self.interp_shape, device=video_chunk.device
                    )
                    # Support points are queried at frame 0.
                    grid_pts = torch.cat(
                        [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
                    )
                    queries = torch.cat([queries, grid_pts], dim=1)
            elif grid_size > 0:
                grid_pts = get_points_on_a_grid(
                    grid_size, self.interp_shape, device=video_chunk.device
                )
                self.N = grid_size**2
                queries = torch.cat(
                    [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
                    dim=2,
                )
            self.queries = queries
            return (None, None)
        # Resize each frame to the model resolution before tracking.
        video_chunk = video_chunk.reshape(B * T, C, H, W)
        video_chunk = F.interpolate(
            video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
        )
        video_chunk = video_chunk.reshape(
            B, T, 3, self.interp_shape[0], self.interp_shape[1]
        )
        tracks, visibilities, confidence, __ = self.model(
            video=video_chunk, queries=self.queries, iters=iters, is_online=True
        )
        if add_support_grid:
            # Drop the support-grid points; keep only the user queries.
            tracks = tracks[:, :, : self.N]
            visibilities = visibilities[:, :, : self.N]
            confidence = confidence[:, :, : self.N]

        visibilities = visibilities * confidence
        thr = 0.6
        # Rescale tracks back to the input resolution and threshold the
        # confidence-weighted visibility.
        return (
            tracks
            * tracks.new_tensor(
                [
                    (W - 1) / (self.interp_shape[1] - 1),
                    (H - 1) / (self.interp_shape[0] - 1),
                ]
            ),
            visibilities > thr,
        )
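
# A minimal usage sketch (added for illustration; the dummy video and frame
# counts are assumptions, not part of the original file). The online protocol
# is: one call with is_first_step=True to register queries, then repeated
# calls on sliding chunks of the video every `step` frames.
if __name__ == "__main__":
    predictor = CoTrackerOnlinePredictor()
    # Dummy video: batch of 1, 64 RGB frames at 480x640.
    video = torch.randn(1, 64, 3, 480, 640)
    step = predictor.step  # half the model window length

    with torch.no_grad():
        # Register a 5x5 query grid on the first frame; no tracking happens yet.
        predictor(video[:, : step * 2], is_first_step=True, grid_size=5)
        tracks, visibilities = None, None
        for ind in range(0, video.shape[1] - step, step):
            # Each call consumes a sliding window of 2 * step frames.
            tracks, visibilities = predictor(video[:, ind : ind + step * 2])
    # tracks: (1, T, N, 2) points in input-resolution (x, y) coordinates;
    # visibilities: (1, T, N) boolean mask after confidence thresholding.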