# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for measuring and building atomic geometries in proteins.
This module contains PyTorch layers for computing common geometric features of
protein backbones in a differentiable way and for converting between internal
and Cartesian coordinate representations.
"""
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class Distances(nn.Module):
"""Euclidean distance layer (pairwise).
    This layer computes batched pairwise Euclidean distances. The input tensor
    is treated as a batch of vectors with the final dimension as the feature
    dimension; the dimension along which to expand pairwise can be specified.
    Args:
        distance_eps (float, optional): Small parameter to add to squared
            distances to make gradients smooth near 0.
Inputs:
X (tensor): Input coordinates with shape `([...], length, [...], 3)`.
dim (int, optional): Dimension upon which to expand to pairwise
distances. Defaults to -2.
mask (tensor, optional): Masking tensor with shape
`([...], length, [...])`.
Outputs:
D (tensor): Distances with shape `([...], length, length, [...])`
"""
def __init__(self, distance_eps=1e-3):
super(Distances, self).__init__()
self.distance_eps = distance_eps
def forward(
        self, X: torch.Tensor, mask: Optional[torch.Tensor] = None, dim: int = -2
) -> torch.Tensor:
dim_expand = dim if dim < 0 else dim + 1
dX = X.unsqueeze(dim_expand - 1) - X.unsqueeze(dim_expand)
D_square = torch.sum(dX ** 2, -1)
D = torch.sqrt(D_square + self.distance_eps)
if mask is not None:
mask_expand = mask.unsqueeze(dim) * mask.unsqueeze(dim + 1)
D = mask_expand * D
return D
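# Usage sketch (illustrative only, not part of the original module): pairwise
# C-alpha distances for a batch of backbones. The tensor names and sizes are
# placeholders; shapes follow the `Distances` docstring above.
#
#   distance_layer = Distances(distance_eps=1e-3)
#   X_ca = torch.randn(2, 128, 3)                # (num_batch, num_residues, 3)
#   mask = torch.ones(2, 128)                    # 1 for observed residues
#   D = distance_layer(X_ca, mask=mask, dim=-2)  # (num_batch, num_residues, num_residues)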
class VirtualAtomsCA(nn.Module):
"""Virtual atoms layer, branching from backbone C-alpha carbons.
This layer places virtual atom coordinates relative to backbone coordinates
in a differentiable way.
Args:
virtual_type (str, optional): Type of virtual atom to place. Currently
supported types are `dicons`, a virtual placement that was
optimized to predict potential rotamer interactions, and `cbeta`
which places a virtual C-beta carbon assuming ideal geometry.
distance_eps (float, optional): Small parameter to add to squared
distances to make gradients smooth near 0.
Inputs:
X (Tensor): Backbone coordinates with shape
`(num_batch, num_residues, num_atom_types, 3)`.
C (Tensor): Chain map tensor with shape `(num_batch, num_residues)`.
Outputs:
X_virtual (Tensor): Virtual coordinates with shape
`(num_batch, num_residues, 3)`.
"""
def __init__(self, virtual_type="dicons", distance_eps=1e-3):
super(VirtualAtomsCA, self).__init__()
self.distance_eps = distance_eps
"""
Geometry specifications
dicons
Length CA-X: 2.3866
Angle N-CA-X: 111.0269
Dihedral C-N-CA-X: -138.886412
cbeta
Length CA-X: 1.532 (Engh and Huber, 2001)
Angle N-CA-X: 109.5 (tetrahedral geometry)
Dihedral C-N-CA-X: -125.25 (109.5 / 2 - 180)
"""
self.virtual_type = virtual_type
virtual_geometries = {
"dicons": [2.3866, 111.0269, -138.8864122],
"cbeta": [1.532, 109.5, -125.25],
}
self.virtual_geometries = virtual_geometries
self.distance_eps = distance_eps
def geometry(self):
bond, angle, dihedral = self.virtual_geometries[self.virtual_type]
return bond, angle, dihedral
def forward(self, X: torch.Tensor, C: torch.LongTensor) -> torch.Tensor:
bond, angle, dihedral = self.geometry()
ones = torch.ones([1, 1], device=X.device)
bonds = bond * ones
angles = angle * ones
dihedrals = dihedral * ones
# Build reference frame
# 1.C -> 2.N -> 3.CA -> 4.X
X_N, X_CA, X_C, X_O = X.unbind(2)
X_virtual = extend_atoms(
X_C,
X_N,
X_CA,
bonds,
angles,
dihedrals,
degrees=True,
distance_eps=self.distance_eps,
)
# Mask missing positions
mask = (C > 0).type(torch.float32).unsqueeze(-1)
X_virtual = mask * X_virtual
return X_virtual
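# Usage sketch (illustrative only): placing ideal-geometry virtual C-beta atoms
# from a batch of backbones. Tensor names and sizes are assumptions that follow
# the `VirtualAtomsCA` docstring above.
#
#   virtual_layer = VirtualAtomsCA(virtual_type="cbeta")
#   X = torch.randn(2, 128, 4, 3)       # backbone N, CA, C, O coordinates
#   C = torch.ones(2, 128).long()       # chain map; 0 marks missing residues
#   X_cb = virtual_layer(X, C)          # (2, 128, 3) virtual atom coordinates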
def normed_vec(V: torch.Tensor, distance_eps: float = 1e-3) -> torch.Tensor:
"""Normalized vectors with distance smoothing.
This normalization is computed as `U = V / sqrt(|V|^2 + eps)` to avoid cusps
and gradient discontinuities.
Args:
V (Tensor): Batch of vectors with shape `(..., num_dims)`.
distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
Default: 1E-3.
Returns:
U (Tensor): Batch of normalized vectors with shape `(..., num_dims)`.
"""
# Unit vector from i to j
mag_sq = (V ** 2).sum(dim=-1, keepdim=True)
mag = torch.sqrt(mag_sq + distance_eps)
U = V / mag
return U
def normed_cross(
V1: torch.Tensor, V2: torch.Tensor, distance_eps: float = 1e-3
) -> torch.Tensor:
"""Normalized cross product between vectors.
This normalization is computed as `U = V / sqrt(|V|^2 + eps)` to avoid cusps
and gradient discontinuities.
Args:
V1 (Tensor): Batch of vectors with shape `(..., 3)`.
V2 (Tensor): Batch of vectors with shape `(..., 3)`.
distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
Default: 1E-3.
Returns:
C (Tensor): Batch of cross products `v_1 x v_2` with shape `(..., 3)`.
"""
C = normed_vec(torch.cross(V1, V2, dim=-1), distance_eps=distance_eps)
return C
def lengths(
atom_i: torch.Tensor, atom_j: torch.Tensor, distance_eps: float = 1e-3
) -> torch.Tensor:
"""Batched bond lengths given batches of atom i and j.
Args:
atom_i (Tensor): Atom `i` coordinates with shape `(..., 3)`.
atom_j (Tensor): Atom `j` coordinates with shape `(..., 3)`.
distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
Default: 1E-3.
Returns:
L (Tensor): Elementwise bond lengths `||x_i - x_j||` with shape `(...)`.
"""
# Bond length of i-j
dX = atom_j - atom_i
L = torch.sqrt((dX ** 2).sum(dim=-1) + distance_eps)
return L
def angles(
atom_i: torch.Tensor,
atom_j: torch.Tensor,
atom_k: torch.Tensor,
distance_eps: float = 1e-3,
degrees: bool = False,
) -> torch.Tensor:
"""Batched bond angles given atoms `i-j-k`.
Args:
atom_i (Tensor): Atom `i` coordinates with shape `(..., 3)`.
atom_j (Tensor): Atom `j` coordinates with shape `(..., 3)`.
atom_k (Tensor): Atom `k` coordinates with shape `(..., 3)`.
distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
Default: 1E-3.
degrees (bool, optional): If True, convert to degrees. Default: False.
Returns:
A (Tensor): Elementwise bond angles with shape `(...)`.
"""
# Bond angle of i-j-k
U_ji = normed_vec(atom_i - atom_j, distance_eps=distance_eps)
U_jk = normed_vec(atom_k - atom_j, distance_eps=distance_eps)
    # Inner product over the coordinate dimension; supports arbitrary batch dims,
    # matching the `(..., 3)` shapes promised in the docstring
    inner_prod = (U_ji * U_jk).sum(-1)
inner_prod = torch.clamp(inner_prod, -1, 1)
A = torch.acos(inner_prod)
if degrees:
A = A * 180.0 / np.pi
return A
def dihedrals(
atom_i: torch.Tensor,
atom_j: torch.Tensor,
atom_k: torch.Tensor,
atom_l: torch.Tensor,
distance_eps: float = 1e-3,
degrees: bool = False,
) -> torch.Tensor:
"""Batched bond dihedrals given atoms `i-j-k-l`.
Args:
atom_i (Tensor): Atom `i` coordinates with shape `(..., 3)`.
atom_j (Tensor): Atom `j` coordinates with shape `(..., 3)`.
atom_k (Tensor): Atom `k` coordinates with shape `(..., 3)`.
atom_l (Tensor): Atom `l` coordinates with shape `(..., 3)`.
distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
Default: 1E-3.
degrees (bool, optional): If True, convert to degrees. Default: False.
Returns:
D (Tensor): Elementwise bond dihedrals with shape `(...)`.
"""
U_ij = normed_vec(atom_j - atom_i, distance_eps=distance_eps)
U_jk = normed_vec(atom_k - atom_j, distance_eps=distance_eps)
U_kl = normed_vec(atom_l - atom_k, distance_eps=distance_eps)
normal_ijk = normed_cross(U_ij, U_jk, distance_eps=distance_eps)
normal_jkl = normed_cross(U_jk, U_kl, distance_eps=distance_eps)
# _inner_product = lambda a, b: torch.einsum("bix,bix->bi", a, b)
_inner_product = lambda a, b: (a * b).sum(-1)
cos_dihedrals = _inner_product(normal_ijk, normal_jkl)
angle_sign = _inner_product(U_ij, normal_jkl)
cos_dihedrals = torch.clamp(cos_dihedrals, -1, 1)
D = torch.sign(angle_sign) * torch.acos(cos_dihedrals)
if degrees:
D = D * 180.0 / np.pi
return D
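# Usage sketch (illustrative only): measuring internal coordinates from four
# atom sets of shape `(num_batch, num_residues, 3)`. The random tensors are
# placeholders standing in for real backbone atoms.
#
#   X_n, X_ca, X_c, X_o = [torch.randn(2, 128, 3) for _ in range(4)]
#   L = lengths(X_ca, X_c)                            # (2, 128) bond lengths
#   A = angles(X_n, X_ca, X_c, degrees=True)          # (2, 128) bond angles
#   D = dihedrals(X_n, X_ca, X_c, X_o, degrees=True)  # (2, 128) dihedrals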
def extend_atoms(
X_1: torch.Tensor,
X_2: torch.Tensor,
X_3: torch.Tensor,
lengths: torch.Tensor,
angles: torch.Tensor,
dihedrals: torch.Tensor,
distance_eps: float = 1e-3,
degrees: bool = False,
) -> torch.Tensor:
"""Place atom `X_4` given `X_1`, `X_2`, `X_3` and internal coordinates.
___________________
| X_1 - X_2 |
| | |
| X_3 - [X_4] |
|___________________|
This uses a similar approach as NERF:
Parsons et al, Computational Chemistry (2005).
https://doi.org/10.1002/jcc.20237
See the reference for further explanation about converting from internal
coordinates to Cartesian coordinates.
Args:
X_1 (Tensor): First atom coordinates with shape `(..., 3)`.
X_2 (Tensor): Second atom coordinates with shape `(..., 3)`.
X_3 (Tensor): Third atom coordinates with shape `(..., 3)`.
lengths (Tensor): Bond lengths `X_3-X_4` with shape `(...)`.
angles (Tensor): Bond angles `X_2-X_3-X_4` with shape `(...)`.
dihedrals (Tensor): Bond dihedrals `X_1-X_2-X_3-X_4` with shape `(...)`.
distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
This preserves differentiability for zero distances. Default: 1E-3.
degrees (bool, optional): If True, inputs are treated as degrees.
Default: False.
Returns:
X_4 (Tensor): Placed atom with shape `(..., 3)`.
"""
    if degrees:
        # Convert to radians without modifying the caller's tensors in place
        angles = angles * (np.pi / 180.0)
        dihedrals = dihedrals * (np.pi / 180.0)
r_32 = X_2 - X_3
r_12 = X_2 - X_1
n_1 = normed_vec(r_32, distance_eps=distance_eps)
n_2 = normed_cross(n_1, r_12, distance_eps=distance_eps)
n_3 = normed_cross(n_1, n_2, distance_eps=distance_eps)
lengths = lengths.unsqueeze(-1)
cos_angle = torch.cos(angles).unsqueeze(-1)
sin_angle = torch.sin(angles).unsqueeze(-1)
cos_dihedral = torch.cos(dihedrals).unsqueeze(-1)
sin_dihedral = torch.sin(dihedrals).unsqueeze(-1)
X_4 = X_3 + lengths * (
cos_angle * n_1
+ (sin_angle * sin_dihedral) * n_2
+ (sin_angle * cos_dihedral) * n_3
)
return X_4
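# Usage sketch (illustrative only): extending a chain by one atom from ideal
# internal coordinates. Because `n_1`, `n_2`, `n_3` form an orthonormal frame,
# the placed atom sits at distance `lengths` from `X_3` (up to the eps
# smoothing in the normalizations).
#
#   X_1, X_2, X_3 = [torch.randn(2, 128, 3) for _ in range(3)]
#   L = 1.5 * torch.ones(2, 128)
#   A = 109.5 * torch.ones(2, 128)
#   D = 180.0 * torch.ones(2, 128)
#   X_4 = extend_atoms(X_1, X_2, X_3, L, A, D, degrees=True)
#   # lengths(X_3, X_4) should be close to 1.5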
class InternalCoords(nn.Module):
"""Internal coordinates layer.
This layer computes internal coordinates (ICs) from a batch of protein
backbones. To make the ICs differentiable everywhere, this layer replaces
    distance calculations of the form `sqrt(sum_sq)` with the smooth, non-cusped
approximation `sqrt(sum_sq + eps)`.
Args:
distance_eps (float, optional): Small parameter to add to squared
distances to make gradients smooth near 0.
Inputs:
X (Tensor): Backbone coordinates with shape
`(num_batch, num_residues, num_atom_types, 3)`.
C (Tensor): Chain map tensor with shape
`(num_batch, num_residues)`.
Outputs:
        dihedrals (Tensor): Backbone dihedral angles with shape
            `(num_batch, num_residues, 4)`.
        angles (Tensor): Backbone bond angles with shape
            `(num_batch, num_residues, 4)`.
        lengths (Tensor): Backbone bond lengths with shape
            `(num_batch, num_residues, 4)`.
"""
def __init__(self, distance_eps=1e-3):
super(InternalCoords, self).__init__()
self.distance_eps = distance_eps
def forward(
self,
X: torch.Tensor,
C: Optional[torch.Tensor] = None,
return_masks: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
mask = (C > 0).float()
X_chain = X[:, :, :3, :]
num_batch, num_residues, _, _ = X_chain.shape
X_chain = X_chain.reshape(num_batch, 3 * num_residues, 3)
# This function historically returns the angle complement
_lengths = lambda Xi, Xj: lengths(Xi, Xj, distance_eps=self.distance_eps)
_angles = lambda Xi, Xj, Xk: np.pi - angles(
Xi, Xj, Xk, distance_eps=self.distance_eps
)
_dihedrals = lambda Xi, Xj, Xk, Xl: dihedrals(
Xi, Xj, Xk, Xl, distance_eps=self.distance_eps
)
# Compute internal coordinates associated with -[N]-[CA]-[C]-
NCaC_L = _lengths(X_chain[:, 1:, :], X_chain[:, :-1, :])
NCaC_A = _angles(X_chain[:, :-2, :], X_chain[:, 1:-1, :], X_chain[:, 2:, :])
NCaC_D = _dihedrals(
X_chain[:, :-3, :],
X_chain[:, 1:-2, :],
X_chain[:, 2:-1, :],
X_chain[:, 3:, :],
)
# Compute internal coordinates associated with [C]=[O]
_, X_CA, X_C, X_O = X.unbind(dim=2)
X_N_next = X[:, 1:, 0, :]
O_L = _lengths(X_C, X_O)
O_A = _angles(X_CA, X_C, X_O)
O_D = _dihedrals(X_N_next, X_CA[:, :-1, :], X_C[:, :-1, :], X_O[:, :-1, :])
if C is None:
C = torch.zeros_like(mask)
# Mask nonphysical bonds and angles
        # Note: this could probably also be expressed as a Conv; it is unclear
        # which is faster, and this is probably not rate-limiting.
C = C * (mask.type(torch.long))
ii = torch.stack(3 * [C], dim=-1).view([num_batch, -1])
L0, L1 = ii[:, :-1], ii[:, 1:]
A0, A1, A2 = ii[:, :-2], ii[:, 1:-1], ii[:, 2:]
D0, D1, D2, D3 = ii[:, :-3], ii[:, 1:-2], ii[:, 2:-1], ii[:, 3:]
# Mask for linear backbone
mask_L = torch.eq(L0, L1)
mask_A = torch.eq(A0, A1) * torch.eq(A0, A2)
mask_D = torch.eq(D0, D1) * torch.eq(D0, D2) * torch.eq(D0, D3)
mask_L = mask_L.type(torch.float32)
mask_A = mask_A.type(torch.float32)
mask_D = mask_D.type(torch.float32)
# Masks for branched oxygen
mask_O_D = torch.eq(C[:, :-1], C[:, 1:])
mask_O_D = mask_O_D.type(torch.float32)
mask_O_A = mask
mask_O_L = mask
def _pad_pack(D, A, L, O_D, O_A, O_L):
# Pad and pack together the components
D = F.pad(D, (1, 2))
A = F.pad(A, (0, 2))
L = F.pad(L, (0, 1))
O_D = F.pad(O_D, (0, 1))
D, A, L = [x.reshape(num_batch, num_residues, 3) for x in [D, A, L]]
_pack = lambda a, b: torch.cat([a, b.unsqueeze(-1)], dim=-1)
L = _pack(L, O_L)
A = _pack(A, O_A)
D = _pack(D, O_D)
return D, A, L
D, A, L = _pad_pack(NCaC_D, NCaC_A, NCaC_L, O_D, O_A, O_L)
mask_D, mask_A, mask_L = _pad_pack(
mask_D, mask_A, mask_L, mask_O_D, mask_O_A, mask_O_L
)
mask_expand = mask.unsqueeze(-1)
mask_D = mask_expand * mask_D
mask_A = mask_expand * mask_A
mask_L = mask_expand * mask_L
D = mask_D * D
A = mask_A * A
L = mask_L * L
if not return_masks:
return D, A, L
else:
return D, A, L, mask_D, mask_A, mask_L
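# Usage sketch (illustrative only): computing backbone internal coordinates and
# their validity masks for a single-chain batch. Shapes follow the
# `InternalCoords` docstring above.
#
#   ic_layer = InternalCoords(distance_eps=1e-3)
#   X = torch.randn(2, 128, 4, 3)       # backbone N, CA, C, O coordinates
#   C = torch.ones(2, 128).long()       # single-chain map
#   D, A, L = ic_layer(X, C)            # each (2, 128, 4)
#   D, A, L, mask_D, mask_A, mask_L = ic_layer(X, C, return_masks=True)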
def quaternions_from_rotations(R: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
"""Convert a batch of rotation matrices to quaternions.
See en.wikipedia.org/wiki/Quaternions_and_spatial_rotation for further
details on converting between quaternions and rotation matrices.
Args:
R (tensor): Batch of rotation matrices with shape `(..., 3, 3)`.
Returns:
q (tensor): Batch of quaternion vectors with shape `(..., 4)`. Quaternion
is in the order `[angle, axis_x, axis_y, axis_z]`.
"""
batch_dims = list(R.shape)[:-2]
R_flat = R.reshape(batch_dims + [9])
R00, R01, R02, R10, R11, R12, R20, R21, R22 = R_flat.unbind(-1)
# Quaternion possesses both an axis and angle of rotation
_sqrt = lambda r: torch.sqrt(F.relu(r) + eps)
q_angle = _sqrt(1 + R00 + R11 + R22).unsqueeze(-1)
magnitudes = _sqrt(
1 + torch.stack([R00 - R11 - R22, -R00 + R11 - R22, -R00 - R11 + R22], -1)
)
signs = torch.sign(torch.stack([R21 - R12, R02 - R20, R10 - R01], -1))
q_axis = signs * magnitudes
# Normalize (for safety and a missing factor of 2)
q_unc = torch.cat((q_angle, q_axis), -1)
q = normed_vec(q_unc, distance_eps=eps)
return q
def rotations_from_quaternions(
q: torch.Tensor, normalize: bool = False, eps: float = 1e-3
) -> torch.Tensor:
"""Convert a batch of quaternions to rotation matrices.
See en.wikipedia.org/wiki/Quaternions_and_spatial_rotation for further
details on converting between quaternions and rotation matrices.
    Args:
        q (tensor): Batch of quaternion vectors with shape `(..., 4)`. Quaternion
            is in the order `[angle, axis_x, axis_y, axis_z]`.
        normalize (bool, optional): If True, normalize the quaternion before
            conversion. Default: False.
        eps (float, optional): Smoothing parameter used for normalization.
            Default: 1E-3.
    Returns:
        R (tensor): Batch of rotation matrices with shape `(..., 3, 3)`.
    """
batch_dims = list(q.shape)[:-1]
if normalize:
q = normed_vec(q, distance_eps=eps)
a, b, c, d = q.unbind(-1)
a2, b2, c2, d2 = a ** 2, b ** 2, c ** 2, d ** 2
R = torch.stack(
[
a2 + b2 - c2 - d2,
2 * b * c - 2 * a * d,
2 * b * d + 2 * a * c,
2 * b * c + 2 * a * d,
a2 - b2 + c2 - d2,
2 * c * d - 2 * a * b,
2 * b * d - 2 * a * c,
2 * c * d + 2 * a * b,
a2 - b2 - c2 + d2,
],
dim=-1,
)
R = R.view(batch_dims + [3, 3])
return R
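# Usage sketch (illustrative only): a quaternion <-> rotation matrix round trip.
# The recovered quaternion matches the input up to an overall sign and the eps
# smoothing used in the normalizations.
#
#   q = normed_vec(torch.randn(2, 128, 4))     # roughly unit quaternions
#   R = rotations_from_quaternions(q)          # (2, 128, 3, 3)
#   q_back = quaternions_from_rotations(R)     # approximately +/- q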
def frames_from_backbone(
    X: torch.Tensor, distance_eps: float = 1e-3
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert a backbone into local reference frames.
Args:
X (Tensor): Backbone coordinates with shape `(..., 4, 3)`.
distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
Default: 1E-3.
Returns:
R (Tensor): Reference frames with shape `(..., 3, 3)`.
X_CA (Tensor): C-alpha coordinates with shape `(..., 3)`
"""
X_N, X_CA, X_C, X_O = X.unbind(-2)
u_CA_N = normed_vec(X_N - X_CA, distance_eps)
u_CA_C = normed_vec(X_C - X_CA, distance_eps)
n_1 = u_CA_N
n_2 = normed_cross(n_1, u_CA_C, distance_eps)
n_3 = normed_cross(n_1, n_2, distance_eps)
R = torch.stack([n_1, n_2, n_3], -1)
return R, X_CA
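# Usage sketch (illustrative only): per-residue orientation frames from a
# backbone, optionally summarized as quaternions via the converters above.
#
#   X = torch.randn(2, 128, 4, 3)       # backbone N, CA, C, O coordinates
#   R, X_ca = frames_from_backbone(X)   # R: (2, 128, 3, 3), X_ca: (2, 128, 3)
#   q = quaternions_from_rotations(R)   # (2, 128, 4) frame orientations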
def hat(omega: torch.Tensor) -> torch.Tensor:
"""
Maps [x,y,z] to [[0,-z,y], [z,0,-x], [-y, x, 0]]
Args:
omega (torch.tensor): of size (*, 3)
Returns:
hat{omega} (torch.tensor): of size (*, 3, 3) skew symmetric element in so(3)
"""
target = torch.zeros(*omega.size()[:-1], 9, device=omega.device)
index1 = torch.tensor([7, 2, 3], device=omega.device).expand(
*target.size()[:-1], -1
)
index2 = torch.tensor([5, 6, 1], device=omega.device).expand(
*target.size()[:-1], -1
)
return (
target.scatter(-1, index1, omega)
.scatter(-1, index2, -omega)
.reshape(*target.size()[:-1], 3, 3)
)
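# Usage sketch (illustrative only): `hat` maps axis vectors to skew-symmetric
# matrices, so `hat(omega) @ v` reproduces the cross product `omega x v`.
#
#   omega = torch.randn(5, 3)
#   v = torch.randn(5, 3)
#   W = hat(omega)                          # (5, 3, 3); equals -W.transpose(-1, -2)
#   Wv = (W @ v.unsqueeze(-1)).squeeze(-1)  # matches torch.cross(omega, v, dim=-1)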
def V(omega: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
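    """
    Builds `I + ((1 - cos t) / t^2) W + ((t - sin t) / t^3) W^2`, where
    `t = |omega|` (smoothed by `eps`) and `W = hat(omega)`.
    Args:
        omega (torch.tensor): of size (*, 3)
        eps (float, optional): Smoothing added to `|omega|^2` before the square
            root. Default: 1E-6.
    Returns:
        V (torch.tensor): of size (*, 3, 3)
    """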
I = torch.eye(3, device=omega.device).expand(*omega.size()[:-1], 3, 3)
theta = omega.pow(2).sum(dim=-1, keepdim=True).add(eps).sqrt()[..., None]
omega_hat = hat(omega)
M1 = ((1 - theta.cos()) / theta.pow(2)) * (omega_hat)
M2 = ((theta - theta.sin()) / theta.pow(3)) * (omega_hat @ omega_hat)
return I + M1 + M2
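# Usage sketch (illustrative only): `V` combines powers of `hat(omega)` with
# trigonometric coefficients (it appears to match the standard SO(3)
# left-Jacobian matrix); for small `omega` it reduces to the identity.
#
#   omega = 1e-4 * torch.randn(5, 3)
#   J = V(omega)                        # (5, 3, 3), approximately torch.eye(3)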