Spaces:
Running
on
L4
Running
on
L4
File size: 6,199 Bytes
0fdcb79 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rigid3Array Transformations represented by a Matrix and a Vector."""
from __future__ import annotations
import dataclasses
from typing import Union, List
import torch
from dockformerpp.utils.geometry import rotation_matrix
from dockformerpp.utils.geometry import vector
# Scalar-like factor accepted by e.g. `scale_translation`: a plain Python
# float or a torch.Tensor.
Float = Union[
    float,
    torch.Tensor,
]
@dataclasses.dataclass(frozen=True)
class Rigid3Array:
    """Rigid transformation, i.e. an element of the special Euclidean group.

    Stored as a rotation (``rotation_matrix.Rot3Array``) plus a translation
    (``vector.Vec3Array``). Applying the transform to a point ``p`` gives
    ``rotation.apply_to_point(p) + translation``. The dataclass is frozen, so
    every method returns a new ``Rigid3Array`` instead of mutating ``self``.
    """

    rotation: rotation_matrix.Rot3Array
    translation: vector.Vec3Array

    def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
        """Compose transforms: ``(self @ other)`` applies ``other`` first.

        The new rotation is ``self.rotation @ other.rotation`` and the new
        translation is ``other.translation`` mapped through ``self``, so the
        result sends ``p`` to ``self.apply_to_point(other.apply_to_point(p))``.
        """
        new_rotation = self.rotation @ other.rotation  # Rot3Array.__matmul__
        new_translation = self.apply_to_point(other.translation)
        return Rigid3Array(new_rotation, new_translation)

    def __getitem__(self, index) -> Rigid3Array:
        """Apply the same ``index`` to both the rotation and the translation."""
        return Rigid3Array(
            self.rotation[index],
            self.translation[index],
        )

    def __mul__(self, other: torch.Tensor) -> Rigid3Array:
        """Multiply both the rotation and the translation by ``other``.

        This delegates to each component's own ``__mul__``; it is NOT
        transform composition (use ``@`` / ``compose`` for that).
        """
        return Rigid3Array(
            self.rotation * other,
            self.translation * other,
        )

    def map_tensor_fn(self, fn) -> Rigid3Array:
        """Apply ``fn`` via each component's ``map_tensor_fn`` and rewrap."""
        return Rigid3Array(
            self.rotation.map_tensor_fn(fn),
            self.translation.map_tensor_fn(fn),
        )

    def inverse(self) -> Rigid3Array:
        """Return the inverse transform.

        The rotation is ``R^-1`` and the translation is ``R^-1 (-t)``, so that
        ``self.inverse() @ self`` is the identity.
        """
        inv_rotation = self.rotation.inverse()
        inv_translation = inv_rotation.apply_to_point(-self.translation)
        return Rigid3Array(inv_rotation, inv_translation)

    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Apply the transform to ``point``: rotate, then add the translation."""
        return self.rotation.apply_to_point(point) + self.translation

    def apply(self, point: torch.Tensor) -> torch.Tensor:
        """Tensor-in/tensor-out variant of ``apply_to_point``.

        ``point`` is converted with ``Vec3Array.from_array`` and the result is
        converted back with ``to_tensor`` (presumably a ``(..., 3)`` layout —
        confirm against ``Vec3Array``).
        """
        return self.apply_to_point(vector.Vec3Array.from_array(point)).to_tensor()

    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Apply the inverse transform to ``point``: ``R^-1 (point - t)``."""
        new_point = point - self.translation
        return self.rotation.apply_inverse_to_point(new_point)

    def invert_apply(self, point: torch.Tensor) -> torch.Tensor:
        """Tensor-in/tensor-out variant of ``apply_inverse_to_point``."""
        return self.apply_inverse_to_point(
            vector.Vec3Array.from_array(point)
        ).to_tensor()

    def compose_rotation(self, other_rotation) -> Rigid3Array:
        """Right-multiply the rotation by ``other_rotation``.

        The translation is kept unchanged, but cloned so that the result does
        not share the translation storage with ``self``.
        """
        rot = self.rotation @ other_rotation
        return Rigid3Array(rot, self.translation.clone())

    def compose(self, other_rigid: Rigid3Array) -> Rigid3Array:
        """Alias for ``self @ other_rigid`` (``other_rigid`` is applied first)."""
        return self @ other_rigid

    def unsqueeze(self, dim: int) -> Rigid3Array:
        """Unsqueeze both components at dimension ``dim``."""
        return Rigid3Array(
            self.rotation.unsqueeze(dim),
            self.translation.unsqueeze(dim),
        )

    @property
    def shape(self) -> torch.Size:
        """Shape of the transform, read from the rotation's ``xx`` entry."""
        return self.rotation.xx.shape

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the transform, read from the rotation's ``xx`` entry."""
        return self.rotation.xx.dtype

    @property
    def device(self) -> torch.device:
        """Device of the transform, read from the rotation's ``xx`` entry."""
        return self.rotation.xx.device

    @classmethod
    def identity(cls, shape, device) -> Rigid3Array:
        """Return the identity transform with the given ``shape`` on ``device``.

        Built from ``Rot3Array.identity`` and a zero ``Vec3Array``.
        """
        return cls(
            rotation_matrix.Rot3Array.identity(shape, device),
            vector.Vec3Array.zeros(shape, device),
        )

    @classmethod
    def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array:
        """Concatenate ``rigids`` along ``dim``, component by component."""
        return cls(
            rotation_matrix.Rot3Array.cat(
                [r.rotation for r in rigids], dim=dim
            ),
            vector.Vec3Array.cat(
                [r.translation for r in rigids], dim=dim
            ),
        )

    def scale_translation(self, factor: Float) -> Rigid3Array:
        """Return a copy whose translation is multiplied by ``factor``.

        The rotation object is reused as-is.
        """
        return Rigid3Array(self.rotation, self.translation * factor)

    def to_tensor(self) -> torch.Tensor:
        """Return the homogeneous ``(..., 4, 4)`` matrix form of the transform.

        The top-left 3x3 block holds the rotation and column 3 of the first
        three rows holds the translation. The bottom row is ``[0, 0, 0, 1]``.
        The output uses the rotation tensor's dtype and device.
        """
        rot_array = self.rotation.to_tensor()
        vec_array = self.translation.to_tensor()
        # rot_array's last two dims are the 3x3 matrix; keep the leading dims.
        array = torch.zeros(
            rot_array.shape[:-2] + (4, 4),
            device=rot_array.device,
            dtype=rot_array.dtype,
        )
        array[..., :3, :3] = rot_array
        array[..., :3, 3] = vec_array
        array[..., 3, 3] = 1.
        return array

    def to_tensor_4x4(self) -> torch.Tensor:
        """Alias for ``to_tensor``."""
        return self.to_tensor()

    def reshape(self, new_shape) -> Rigid3Array:
        """Reshape both components to ``new_shape``."""
        rots = self.rotation.reshape(new_shape)
        trans = self.translation.reshape(new_shape)
        return Rigid3Array(rots, trans)

    def stop_rot_gradient(self) -> Rigid3Array:
        """Return a copy with ``Rot3Array.stop_gradient`` applied to the rotation.

        The translation is passed through unchanged, so gradients presumably
        still flow through it.
        """
        return Rigid3Array(
            self.rotation.stop_gradient(),
            self.translation,
        )

    @classmethod
    def from_array(cls, array) -> Rigid3Array:
        """Build from a tensor whose last two dims hold a homogeneous matrix.

        Only ``array[..., :3, :3]`` (rotation) and ``array[..., :3, 3]``
        (translation) are read. Any fourth row is ignored.
        """
        rot = rotation_matrix.Rot3Array.from_array(
            array[..., :3, :3],
        )
        vec = vector.Vec3Array.from_array(array[..., :3, 3])
        return cls(rot, vec)

    @classmethod
    def from_tensor_4x4(cls, array) -> Rigid3Array:
        """Alias for ``from_array``."""
        return cls.from_array(array)

    @classmethod
    def from_array4x4(cls, array: torch.Tensor) -> Rigid3Array:
        """Construct a Rigid3Array from a homogeneous 4x4 array.

        Reads each rotation entry and translation entry individually from
        ``array[..., i, j]`` (i in 0..2). The fourth row is never read.
        """
        rotation = rotation_matrix.Rot3Array(
            array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
            array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
            array[..., 2, 0], array[..., 2, 1], array[..., 2, 2],
        )
        translation = vector.Vec3Array(
            array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
        )
        return cls(rotation, translation)

    def cuda(self) -> Rigid3Array:
        """Return a copy of this transform on the current CUDA device.

        Implemented as a round trip through the 4x4 tensor form. Because
        ``to_tensor`` allocates with the rotation's dtype, the returned
        translation also ends up in the rotation's dtype.
        """
        return Rigid3Array.from_tensor_4x4(self.to_tensor_4x4().cuda())
|