import torch
import torch.nn.functional as F

from typing import Optional, Tuple


def smart_cat(tensor1, tensor2, dim):
    """Concatenate two tensors along ``dim``; if ``tensor1`` is None, return ``tensor2``."""
    if tensor1 is None:
        return tensor2
    return torch.cat([tensor1, tensor2], dim=dim)
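
# Illustrative use of smart_cat (a hypothetical streaming loop, not part of the
# original code): the running buffer starts as None and grows along dim=1, so
# the first chunk needs no special case.
#
#   cache = None
#   for chunk in chunks:  # each chunk: (B, T, C, H, W)
#       cache = smart_cat(cache, chunk, dim=1)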


def get_points_on_a_grid(
    size: int,
    extent: Tuple[float, ...],
    center: Optional[Tuple[float, ...]] = None,
    device: Optional[torch.device] = torch.device("cpu"),
):
    r"""Get a grid of points covering a rectangular region

    `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
    :attr:`size` grid of points distributed to cover a rectangular area
    specified by `extent`.

    The `extent` is a pair of numbers :math:`(H,W)` specifying the height
    and width of the rectangle.

    Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
    specifying the vertical and horizontal center coordinates. The center
    defaults to the middle of the extent.

    Points are distributed uniformly within the rectangle, leaving a margin
    :math:`m=W/64` from the border.

    It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
    points :math:`P_{ij}=(x_j, y_i)` where

    .. math::
        P_{ij} = \left(
            c_x + m - \frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
            c_y + m - \frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
        \right)

    Points are returned in row-major order.

    Args:
        size (int): grid size.
        extent (tuple): height and width of the grid extent.
        center (tuple, optional): grid center.
        device (torch.device, optional): device of the returned tensor.
            Defaults to `torch.device("cpu")`.

    Returns:
        Tensor: grid.
    """
    if size == 1:
        return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]

    if center is None:
        center = [extent[0] / 2, extent[1] / 2]

    margin = extent[1] / 64
    range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
    range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(*range_y, size, device=device),
        torch.linspace(*range_x, size, device=device),
        indexing="ij",
    )
    return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
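
# Quick sanity check for the grid helper (illustrative): a 3x3 grid over a
# (480, 640) extent keeps a margin of 640 / 64 = 10 pixels, so the first
# (x, y) point is (10.0, 10.0) and the last is (630.0, 470.0), in row-major
# order.
#
#   pts = get_points_on_a_grid(3, (480, 640))
#   assert pts.shape == (1, 9, 2)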


class CoTrackerOnlinePredictor(torch.nn.Module):
    def __init__(
        self,
        checkpoint=None,
        offline=False,
        v2=False,
        window_len=16,
    ):
        super().__init__()
        self.support_grid_size = 6
        # NOTE: offline, v2 and window_len are accepted for interface
        # compatibility but unused here; the model is always the online
        # CoTracker3 from torch.hub.
        model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_online").model

        if checkpoint is not None:
            # Optionally override the hub weights with a local checkpoint.
            with open(checkpoint, "rb") as f:
                state_dict = torch.load(f, map_location="cpu")
                if "model" in state_dict:
                    state_dict = state_dict["model"]
            model.load_state_dict(state_dict)
            print('LOAD STATE DICT')

        self.interp_shape = model.model_resolution
        self.step = model.window_len // 2
        self.model = model
        self.model.eval()

    @torch.no_grad()
    def forward(
        self,
        video_chunk,
        is_first_step: bool = False,
        queries: Optional[torch.Tensor] = None,
        grid_size: int = 5,
        grid_query_frame: int = 0,
        add_support_grid=False,
        iters: int = 5,
    ):
        B, T, C, H, W = video_chunk.shape

        # On the first step, initialize the online state and cache the queries;
        # no frames are tracked yet.
        if is_first_step:
            self.model.init_video_online_processing()
            if queries is not None:
                B, N, D = queries.shape
                self.N = N
                assert D == 3
                queries = queries.clone()
                # Rescale query coordinates from input to model resolution.
                queries[:, :, 1:] *= queries.new_tensor(
                    [
                        (self.interp_shape[1] - 1) / (W - 1),
                        (self.interp_shape[0] - 1) / (H - 1),
                    ]
                )
                if add_support_grid:
                    grid_pts = get_points_on_a_grid(
                        self.support_grid_size, self.interp_shape, device=video_chunk.device
                    )
                    grid_pts = torch.cat(
                        [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
                    )
                    queries = torch.cat([queries, grid_pts], dim=1)
            elif grid_size > 0:
                grid_pts = get_points_on_a_grid(
                    grid_size, self.interp_shape, device=video_chunk.device
                )
                self.N = grid_size**2
                queries = torch.cat(
                    [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
                    dim=2,
                )

            self.queries = queries
            return (None, None)

        # Resize the chunk to the model resolution before tracking.
        video_chunk = video_chunk.reshape(B * T, C, H, W)
        video_chunk = F.interpolate(
            video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
        )
        video_chunk = video_chunk.reshape(
            B, T, 3, self.interp_shape[0], self.interp_shape[1]
        )

        tracks, visibilities, confidence, __ = self.model(
            video=video_chunk, queries=self.queries, iters=iters, is_online=True
        )
        # Drop the auxiliary support-grid points from the outputs.
        if add_support_grid:
            tracks = tracks[:, :, : self.N]
            visibilities = visibilities[:, :, : self.N]
            confidence = confidence[:, :, : self.N]

        visibilities = visibilities * confidence
        thr = 0.6
        # Rescale tracks back to the input resolution and binarize visibility.
        return (
            tracks
            * tracks.new_tensor(
                [
                    (W - 1) / (self.interp_shape[1] - 1),
                    (H - 1) / (self.interp_shape[0] - 1),
                ]
            ),
            visibilities > thr,
        )
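

# Minimal usage sketch of the online protocol (illustrative; `video` is assumed
# to be a (1, T, 3, H, W) float tensor on the same device as the model). The
# first call only registers the queries and returns (None, None); each later
# call slides a window of 2 * predictor.step frames over the video and returns
# the tracks and boolean visibilities for all frames processed so far.
#
#   predictor = CoTrackerOnlinePredictor()
#   predictor(video[:, : predictor.step * 2], is_first_step=True, grid_size=10)
#   for ind in range(0, video.shape[1] - predictor.step, predictor.step):
#       tracks, visibility = predictor(video[:, ind : ind + predictor.step * 2])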