File size: 6,199 Bytes
0fdcb79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rigid3Array Transformations represented by a Matrix and a Vector."""

from __future__ import annotations
import dataclasses
from typing import Union, List

import torch

from dockformerpp.utils.geometry import rotation_matrix
from dockformerpp.utils.geometry import vector


# Alias for a value that is either a Python float or a torch.Tensor; used to
# annotate the `factor` argument of Rigid3Array.scale_translation.
Float = Union[float, torch.Tensor]


@dataclasses.dataclass(frozen=True)
class Rigid3Array:
    """Rigid Transformation, i.e. element of special euclidean group.

    A transform is stored as a rotation (``Rot3Array``) plus a translation
    (``Vec3Array``) and maps a point ``p`` to ``rotation(p) + translation``
    (see ``apply_to_point``). The dataclass is frozen, so every method returns
    a new ``Rigid3Array`` rather than modifying ``self``.
    """

    rotation: rotation_matrix.Rot3Array
    translation: vector.Vec3Array

    def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
        """Compose this transform with ``other``.

        The new rotation is ``self.rotation @ other.rotation`` and the new
        translation is ``self`` applied to ``other.translation``.
        NOTE(review): this corresponds to ``self(other(p))`` provided
        ``Rot3Array.__matmul__`` composes rotations in the same order --
        confirm against the rotation_matrix module.
        """
        new_rotation = self.rotation @ other.rotation
        new_translation = self.apply_to_point(other.translation)
        return Rigid3Array(new_rotation, new_translation)

    def __getitem__(self, index) -> Rigid3Array:
        """Index rotation and translation with the same ``index``."""
        return Rigid3Array(
            self.rotation[index],
            self.translation[index],
        )

    def __mul__(self, other: torch.Tensor) -> Rigid3Array:
        """Multiply both the rotation and the translation by ``other``.

        NOTE(review): presumably ``other`` is a broadcastable mask/scale
        tensor -- confirm against callers.
        """
        return Rigid3Array(
            self.rotation * other,
            self.translation * other,
        )

    def map_tensor_fn(self, fn) -> Rigid3Array:
        """Apply ``fn`` through ``map_tensor_fn`` of rotation and translation."""
        return Rigid3Array(
            self.rotation.map_tensor_fn(fn),
            self.translation.map_tensor_fn(fn),
        )

    def inverse(self) -> Rigid3Array:
        """Return Rigid3Array corresponding to inverse transform."""
        inv_rotation = self.rotation.inverse()
        # The inverse translation is the inverse rotation applied to the
        # negated translation.
        inv_translation = inv_rotation.apply_to_point(-self.translation)
        return Rigid3Array(inv_rotation, inv_translation)

    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Apply Rigid3Array transform to point."""
        return self.rotation.apply_to_point(point) + self.translation

    def apply(self, point: torch.Tensor) -> torch.Tensor:
        """Tensor front-end for ``apply_to_point``.

        ``point`` is wrapped with ``Vec3Array.from_array`` and the transformed
        point is converted back with ``to_tensor``.
        """
        return self.apply_to_point(vector.Vec3Array.from_array(point)).to_tensor()

    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Apply inverse Rigid3Array transform to point."""
        new_point = point - self.translation
        return self.rotation.apply_inverse_to_point(new_point)

    def invert_apply(self, point: torch.Tensor) -> torch.Tensor:
        """Tensor front-end for ``apply_inverse_to_point``.

        ``point`` is wrapped with ``Vec3Array.from_array`` and the result is
        converted back with ``to_tensor``.
        """
        return self.apply_inverse_to_point(vector.Vec3Array.from_array(point)).to_tensor()

    def compose_rotation(
        self, other_rotation: rotation_matrix.Rot3Array
    ) -> Rigid3Array:
        """Right-multiply the rotation by ``other_rotation``.

        The translation of the result is a clone of ``self.translation``.
        """
        rot = self.rotation @ other_rotation
        return Rigid3Array(rot, self.translation.clone())

    def compose(self, other_rigid: Rigid3Array) -> Rigid3Array:
        """Named alias for ``self @ other_rigid``."""
        return self @ other_rigid

    def unsqueeze(self, dim: int) -> Rigid3Array:
        """Call ``unsqueeze(dim)`` on both rotation and translation."""
        return Rigid3Array(
            self.rotation.unsqueeze(dim),
            self.translation.unsqueeze(dim),
        )

    @property
    def shape(self) -> torch.Size:
        """Shape of the ``xx`` component of the rotation."""
        return self.rotation.xx.shape

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the ``xx`` component of the rotation."""
        return self.rotation.xx.dtype

    @property
    def device(self) -> torch.device:
        """Device of the ``xx`` component of the rotation."""
        return self.rotation.xx.device

    @classmethod
    def identity(cls, shape, device) -> Rigid3Array:
        """Return identity Rigid3Array of given shape."""
        return cls(
            rotation_matrix.Rot3Array.identity(shape, device),
            vector.Vec3Array.zeros(shape, device),
        )

    @classmethod
    def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array:
        """Concatenate the rotations and translations of ``rigids`` along ``dim``."""
        return cls(
            rotation_matrix.Rot3Array.cat(
                [r.rotation for r in rigids], dim=dim
            ),
            vector.Vec3Array.cat(
                [r.translation for r in rigids], dim=dim
            ),
        )

    def scale_translation(self, factor: Float) -> Rigid3Array:
        """Scale translation in Rigid3Array by 'factor'."""
        return Rigid3Array(self.rotation, self.translation * factor)

    def to_tensor(self) -> torch.Tensor:
        """Pack the transform into a tensor with trailing dimensions (4, 4).

        The leading dimensions are those of ``rotation.to_tensor()`` without
        its last two. The ``[:3, :3]`` sub-block holds the rotation tensor,
        ``[:3, 3]`` holds the translation tensor, ``[3, 3]`` is 1 and every
        other entry is 0.
        """
        rot_array = self.rotation.to_tensor()
        vec_array = self.translation.to_tensor()
        array = torch.zeros(
            rot_array.shape[:-2] + (4, 4),
            device=rot_array.device,
            dtype=rot_array.dtype,
        )
        array[..., :3, :3] = rot_array
        array[..., :3, 3] = vec_array
        array[..., 3, 3] = 1.
        return array

    def to_tensor_4x4(self) -> torch.Tensor:
        """Alias for ``to_tensor``."""
        return self.to_tensor()

    def reshape(self, new_shape) -> Rigid3Array:
        """Call ``reshape(new_shape)`` on both rotation and translation."""
        rots = self.rotation.reshape(new_shape)
        trans = self.translation.reshape(new_shape)
        return Rigid3Array(rots, trans)

    def stop_rot_gradient(self) -> Rigid3Array:
        """Return a transform whose rotation went through ``stop_gradient()``.

        The translation is passed through unchanged.
        """
        return Rigid3Array(
            self.rotation.stop_gradient(),
            self.translation,
        )

    @classmethod
    def from_array(cls, array: torch.Tensor) -> Rigid3Array:
        """Build a transform from a tensor indexed as ``[..., row, col]``.

        The rotation is read from ``array[..., :3, :3]`` via
        ``Rot3Array.from_array`` and the translation from ``array[..., :3, 3]``
        via ``Vec3Array.from_array``.
        """
        rot = rotation_matrix.Rot3Array.from_array(
            array[..., :3, :3],
        )
        vec = vector.Vec3Array.from_array(array[..., :3, 3])
        return cls(rot, vec)

    @classmethod
    def from_tensor_4x4(cls, array: torch.Tensor) -> Rigid3Array:
        """Alias for ``from_array``."""
        return cls.from_array(array)

    @classmethod
    def from_array4x4(cls, array: torch.Tensor) -> Rigid3Array:
        """Construct Rigid3Array from homogeneous 4x4 array."""
        # Pass the nine rotation entries to Rot3Array in row-major order.
        rotation = rotation_matrix.Rot3Array(
            array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
            array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
            array[..., 2, 0], array[..., 2, 1], array[..., 2, 2],
        )
        translation = vector.Vec3Array(
            array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
        )
        return cls(rotation, translation)

    def cuda(self) -> Rigid3Array:
        """Return a CUDA copy of the transform.

        The transform is packed with ``to_tensor_4x4``, moved with
        ``Tensor.cuda()`` and rebuilt with ``from_tensor_4x4``.
        """
        return Rigid3Array.from_tensor_4x4(self.to_tensor_4x4().cuda())