File size: 3,305 Bytes
5359939
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
from utils.geometry_utils import CameraPose
from einops import rearrange, repeat
import math
import roma

class ControllableCameraPose(CameraPose):
    """Camera-pose sequence that can be serialized and extended along a steered path.

    This class relies on the ``_R``, ``_T`` and ``_K`` tensors and on the
    ``_normalize_by`` / ``replace_with_interpolation`` helpers. None of them are
    defined here; they are presumably provided by ``CameraPose`` (verify there).
    """

    def to_vectors(self) -> torch.Tensor:
        """
        Returns the raw camera poses as flat per-frame vectors.

        Each frame's vector is ``_K`` followed by the row-major flattening of
        the ``[R | T]`` matrix (``_T`` appended to ``_R`` as an extra column).

        Returns:
            torch.Tensor: The raw camera poses. Shape (B, T, 4 + 12).
        """
        RT = torch.cat([self._R, rearrange(self._T, "b t i -> b t i 1")], dim=-1)
        return torch.cat([self._K, rearrange(RT, "b t i j -> b t (i j)")], dim=-1)

    def extend(
        self,
        num_frames: int,
        x_angle: float = 0.0,
        y_angle: float = 0.0,
        distance: float = 100.0,
    ) -> None:
        """
        Extends the camera poses in place by ``num_frames`` frames.

        Let's say 0 degree is the direction of the last camera pose.
        A target pose is computed from the given angles and distance and stored
        as the final frame; the new frames before it are marked for
        ``replace_with_interpolation`` so the camera moves and rotates smoothly
        towards that target.

        The existing poses are first passed through ``_normalize_by`` with the
        last frame's rotation and translation (presumably re-expressing every
        pose relative to the last frame -- confirm in the base class). This
        happens even when ``num_frames`` is 0.

        Args:
            num_frames (int): The number of frames to extend. Must be >= 0;
                with 0 no frames are appended.
            x_angle (float): Rotation about the x axis, in degrees.
            y_angle (float): Rotation about the y axis, in degrees.
            distance (float): Travel distance as a percentage; 100 corresponds
                to a step of 0.5 translation units per appended frame.

        Raises:
            ValueError: If ``num_frames`` is negative.
        """
        if num_frames < 0:
            raise ValueError(f"num_frames must be non-negative, got {num_frames}")

        step_size = 0.5 * distance / 100
        self._normalize_by(self._R[:, -1], self._T[:, -1])

        # Nothing to append. Returning here also avoids the degenerate masks
        # that negative-index slicing with num_frames == 0 would produce.
        if num_frames == 0:
            return

        batch_size, num_existing = self._T.shape[:2]
        travel = step_size * num_frames

        # Rotation of the final (target) frame. roma builds it in float32; cast
        # to the stored dtype so concatenation cannot change the dtype of _R.
        R_final = roma.euler_to_rotmat(
            convention="xyz",
            angles=torch.tensor(
                [-x_angle, -y_angle, 0], device=self._R.device, dtype=torch.float32
            ),
            degrees=True,
            dtype=torch.float32,
            device=self._R.device,
        ).to(self._R.dtype)

        # Translation of the final (target) frame.
        T_final = torch.tensor(
            [
                -travel * math.sin(math.radians(y_angle)),
                travel * math.sin(math.radians(x_angle)),
                -travel * math.cos(math.radians(y_angle)),
            ],
            device=self._T.device,
            dtype=self._T.dtype,
        )

        # Broadcast the target pose over the real batch size (not a hard-coded
        # batch of 1) so that torch.cat along the time axis works for B > 1.
        # torch.cat allocates a new tensor, so no explicit clone is needed.
        self._R = torch.cat(
            [self._R, repeat(R_final, "i j -> b t i j", b=batch_size, t=num_frames)],
            dim=1,
        )
        self._T = torch.cat(
            [self._T, repeat(T_final, "i -> b t i", b=batch_size, t=num_frames)],
            dim=1,
        )
        # Intrinsics of the appended frames copy the last existing frame.
        self._K = torch.cat(
            [self._K, repeat(self._K[:, -1], "b i -> b t i", t=num_frames)],
            dim=1,
        )

        # Interpolate all new frames between the last existing frame and the
        # final frame: existing frames and the final target stay fixed (False),
        # the new frames in between are marked True.
        interpolate_mask = torch.zeros(
            self._T.shape[:2], dtype=torch.bool, device=self._T.device
        )
        interpolate_mask[:, num_existing:-1] = True
        self.replace_with_interpolation(interpolate_mask)

def extend_poses(
    conditions: torch.Tensor,
    n: int,
    x_angle: float = 0.0,
    y_angle: float = 0.0,
    distance: float = 0.0,
) -> torch.Tensor:
    """
    Extends pose vectors by ``n`` frames and returns them as vectors again.

    The vectors are parsed with ``ControllableCameraPose.from_vectors``,
    extended with ``ControllableCameraPose.extend`` and serialized back with
    ``to_vectors``.

    Note that ``distance`` defaults to 0.0 here, whereas ``extend`` itself
    defaults to 100.0. With the default, the appended target pose has zero
    translation and only the rotation angles take effect.

    Args:
        conditions (torch.Tensor): Pose vectors accepted by ``from_vectors``
            (presumably the (B, T, 4 + 12) layout that ``to_vectors`` emits --
            confirm against ``from_vectors``).
        n (int): The number of frames to extend.
        x_angle (float): The x rotation angle in degrees.
        y_angle (float): The y rotation angle in degrees.
        distance (float): Travel distance forwarded to ``extend``.

    Returns:
        torch.Tensor: The extended pose vectors from ``to_vectors``.
    """
    camera = ControllableCameraPose.from_vectors(conditions)
    camera.extend(
        num_frames=n,
        x_angle=x_angle,
        y_angle=y_angle,
        distance=distance,
    )
    return camera.to_vectors()