# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from typing import Optional, Tuple


def smart_cat(tensor1, tensor2, dim):
    # Concatenate along `dim`, treating a `None` first operand as empty.
    if tensor1 is None:
        return tensor2
    return torch.cat([tensor1, tensor2], dim=dim)


def get_points_on_a_grid(
    size: int,
    extent: Tuple[float, ...],
    center: Optional[Tuple[float, ...]] = None,
    device: Optional[torch.device] = torch.device("cpu"),
):
    r"""Get a grid of points covering a rectangular region

    `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
    :attr:`size` grid of points distributed to cover a rectangular area
    specified by `extent`. The `extent` is a pair :math:`(H, W)` specifying
    the height and width of the rectangle.

    Optionally, the :attr:`center` can be specified as a pair
    :math:`(c_y, c_x)` giving the vertical and horizontal center
    coordinates. The center defaults to the middle of the extent.

    Points are distributed uniformly within the rectangle leaving a margin
    :math:`m = W/64` from the border.

    It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
    points :math:`P_{ij} = (x_j, y_i)` where

    .. math::
        P_{ij} = \left(
            c_x + m - \frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
            c_y + m - \frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
        \right)

    Points are returned in row-major order.

    Args:
        size (int): grid size.
        extent (tuple): height and width of the grid extent.
        center (tuple, optional): grid center.
        device (torch.device, optional): device of the returned tensor.
            Defaults to `torch.device("cpu")`.

    Returns:
        Tensor: grid of shape :math:`(1, \text{size}^2, 2)`.
    """
    if size == 1:
        # A single point: return the center of the extent.
        return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]

    if center is None:
        center = [extent[0] / 2, extent[1] / 2]

    margin = extent[1] / 64
    range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
    range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(*range_y, size, device=device),
        torch.linspace(*range_x, size, device=device),
        indexing="ij",
    )
    return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
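

# Worked example for `get_points_on_a_grid` (illustrative sketch; not part of
# the upstream module). With size=2 and extent=(480, 640), the margin is
# m = 640 / 64 = 10, so the four points sit 10 px inside the corners, in
# row-major order: (10, 10), (630, 10), (10, 470), (630, 470).
def _demo_grid_points():
    pts = get_points_on_a_grid(2, (480, 640))
    assert pts.shape == (1, 4, 2)
    expected = torch.tensor(
        [[[10.0, 10.0], [630.0, 10.0], [10.0, 470.0], [630.0, 470.0]]]
    )
    assert torch.allclose(pts, expected)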

class CoTrackerOnlinePredictor(torch.nn.Module):
    def __init__(
        self,
        checkpoint=None,
        offline=False,
        v2=False,
        window_len=16,
    ):
        super().__init__()
        self.support_grid_size = 6
        model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online").model
        # Alternative local build:
        # build_cotracker(checkpoint, v2=v2, offline=False, window_len=window_len)
        if checkpoint is not None:
            with open(checkpoint, "rb") as f:
                state_dict = torch.load(f, map_location="cpu")
                if "model" in state_dict:
                    state_dict = state_dict["model"]
            model.load_state_dict(state_dict)
            print("LOAD STATE DICT")
        self.interp_shape = model.model_resolution
        self.step = model.window_len // 2
        self.model = model
        self.model.eval()

    @torch.no_grad()
    def forward(
        self,
        video_chunk,
        is_first_step: bool = False,
        queries: Optional[torch.Tensor] = None,
        grid_size: int = 5,
        grid_query_frame: int = 0,
        add_support_grid=False,
        iters: int = 5,
    ):
        B, T, C, H, W = video_chunk.shape
        # Initialize online video processing and save queried points.
        # This needs to be done before processing *each new video*.
        if is_first_step:
            self.model.init_video_online_processing()
            if queries is not None:
                B, N, D = queries.shape
                self.N = N
                assert D == 3
                queries = queries.clone()
                # Rescale the (x, y) query coordinates from the input
                # resolution to the model's internal resolution.
                queries[:, :, 1:] *= queries.new_tensor(
                    [
                        (self.interp_shape[1] - 1) / (W - 1),
                        (self.interp_shape[0] - 1) / (H - 1),
                    ]
                )
                if add_support_grid:
                    # Append support-grid points queried at frame 0; they
                    # stabilize tracking and are stripped from the output.
                    grid_pts = get_points_on_a_grid(
                        self.support_grid_size,
                        self.interp_shape,
                        device=video_chunk.device,
                    )
                    grid_pts = torch.cat(
                        [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
                    )
                    queries = torch.cat([queries, grid_pts], dim=1)
            elif grid_size > 0:
                # No explicit queries: track a regular grid of points queried
                # at `grid_query_frame`.
                grid_pts = get_points_on_a_grid(
                    grid_size, self.interp_shape, device=video_chunk.device
                )
                self.N = grid_size**2
                queries = torch.cat(
                    [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
                    dim=2,
                )
            self.queries = queries
            return (None, None)

        # Resize the chunk to the model's working resolution.
        video_chunk = video_chunk.reshape(B * T, C, H, W)
        video_chunk = F.interpolate(
            video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
        )
        video_chunk = video_chunk.reshape(
            B, T, 3, self.interp_shape[0], self.interp_shape[1]
        )
        tracks, visibilities, confidence, __ = self.model(
            video=video_chunk, queries=self.queries, iters=iters, is_online=True
        )
        if add_support_grid:
            # Drop the support-grid points; keep only the user-queried tracks.
            tracks = tracks[:, :, : self.N]
            visibilities = visibilities[:, :, : self.N]
            confidence = confidence[:, :, : self.N]
        visibilities = visibilities * confidence
        thr = 0.6
        # Rescale tracks back to the input resolution and binarize visibility.
        return (
            tracks
            * tracks.new_tensor(
                [
                    (W - 1) / (self.interp_shape[1] - 1),
                    (H - 1) / (self.interp_shape[0] - 1),
                ]
            ),
            visibilities > thr,
        )
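

# Minimal usage sketch (illustrative; the video path "video.mp4" and the
# grid_size value are assumptions, not part of this module). The first call
# with is_first_step=True only registers the queries and returns (None, None);
# afterwards the predictor is fed overlapping windows of the most recent
# 2 * predictor.step frames, advancing predictor.step frames at a time.
if __name__ == "__main__":
    import imageio.v3 as iio
    import numpy as np

    predictor = CoTrackerOnlinePredictor()

    window_frames = []
    tracks, visibles = None, None
    is_first_step = True
    for i, frame in enumerate(iio.imiter("video.mp4", plugin="FFMPEG")):
        window_frames.append(frame)
        if i % predictor.step == 0 and i != 0:
            # Stack the last 2 * step frames into a (1, T, 3, H, W) chunk.
            chunk = (
                torch.tensor(np.stack(window_frames[-predictor.step * 2 :]))
                .float()
                .permute(0, 3, 1, 2)[None]
            )
            tracks, visibles = predictor(
                chunk, is_first_step=is_first_step, grid_size=10
            )
            is_first_step = False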