# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from typing import Tuple

from cotracker.models.core.cotracker.cotracker import CoTracker2
from cotracker.models.core.model_utils import get_points_on_a_grid


class EvaluationPredictor(torch.nn.Module):
    """Wraps a CoTracker2 model for evaluation: resizes input videos to
    interp_shape, optionally tracks each query point independently with
    support grids, and rescales trajectories back to the input resolution."""

    def __init__(
        self,
        cotracker_model: CoTracker2,
        interp_shape: Tuple[int, int] = (384, 512),
        grid_size: int = 5,
        local_grid_size: int = 8,
        single_point: bool = True,
        n_iters: int = 6,
    ) -> None:
        super().__init__()
        self.grid_size = grid_size
        self.local_grid_size = local_grid_size
        self.single_point = single_point
        self.interp_shape = interp_shape
        self.n_iters = n_iters

        self.model = cotracker_model
        self.model.eval()

    def forward(self, video, queries):
        queries = queries.clone()
        B, T, C, H, W = video.shape
        B, N, D = queries.shape
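        # Each query row is (t, x, y): the index of the first frame to track
        # from, followed by pixel coordinates at the input resolution.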

        assert D == 3

        # Resize every frame to the model's working resolution, interp_shape.
        video = video.reshape(B * T, C, H, W)
        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
        video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])

        device = video.device

        # Rescale query coordinates from the input resolution to interp_shape.
        queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
        queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)

        if self.single_point:
            # Track each query point independently; entries before a point's
            # query frame t remain zero.
            traj_e = torch.zeros((B, T, N, 2), device=device)
            vis_e = torch.zeros((B, T, N), device=device)
            for pind in range(N):
                query = queries[:, pind : pind + 1]

                t = query[0, 0, 0].long()

                traj_e_pind, vis_e_pind = self._process_one_point(video, query)
                traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
                vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
        else:
            if self.grid_size > 0:
                # Append a regular support grid (time index 0 prepended) to
                # the user-provided queries.
                xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)
                queries = torch.cat([queries, xy], dim=1)

            traj_e, vis_e, __ = self.model(
                video=video,
                queries=queries,
                iters=self.n_iters,
            )

        # Map estimated tracks from interp_shape back to the input resolution.
        traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
        traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
        return traj_e, vis_e

    def _process_one_point(self, video, query):
        """Track a single query point: augment it with optional local and
        global support grids, crop the video to start at the query frame,
        and run the model on the cropped clip."""
        t = query[0, 0, 0].long()

        device = query.device
        if self.local_grid_size > 0:
            # Add a dense local support grid over a 50x50 px neighborhood
            # centered on the query point (the center is passed as (y, x)).
            xy_target = get_points_on_a_grid(
                self.local_grid_size,
                (50, 50),
                [query[0, 0, 2].item(), query[0, 0, 1].item()],
            )
            xy_target = torch.cat(
                [torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2
            ).to(device)
            query = torch.cat([query, xy_target], dim=1)

        if self.grid_size > 0:
            # Also add a global support grid covering the whole frame.
            xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)
            query = torch.cat([query, xy], dim=1)

        # The video passed to the model below is cropped to start at frame t,
        # so reset the query's time index to 0 within the cropped clip.
        query[0, 0, 0] = 0
        traj_e_pind, vis_e_pind, __ = self.model(
            video=video[:, t:], queries=query, iters=self.n_iters
        )

        return traj_e_pind, vis_e_pind
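

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hypothetical example
# of running the predictor on dummy data. It assumes CoTracker2 is
# constructible with its default arguments; in practice one would load
# pretrained weights first, e.g. model.load_state_dict(torch.load(ckpt_path)).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = CoTracker2()  # assumption: default constructor arguments suffice
    predictor = EvaluationPredictor(
        model, grid_size=0, local_grid_size=0, single_point=False
    )

    video = torch.randn(1, 16, 3, 480, 640)  # (B, T, C, H, W) dummy video
    # Two queries in (t, x, y) format at the original 640x480 resolution.
    queries = torch.tensor([[[0.0, 320.0, 240.0], [2.0, 100.0, 50.0]]])

    with torch.no_grad():
        tracks, visibility = predictor(video, queries)
    # tracks: (1, 16, 2, 2) pixel trajectories; visibility: (1, 16, 2)
    print(tracks.shape, visibility.shape)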