Spaces:
Sleeping
Sleeping
# Copyright Generate Biomedicines, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Layers for measuring and building atomic geometries in proteins. | |
This module contains pytorch layers for computing common geometric features of | |
protein backbones in a differentiable way and for converting between internal | |
and Cartesian coordinate representations. | |
""" | |
from typing import Optional, Tuple | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class Distances(nn.Module):
    """Euclidean distance layer (pairwise).

    This layer computes batched pairwise Euclidean distances, where the input
    tensor is treated as a batch of vectors with the final dimension as the
    feature dimension and the dimension for pairwise expansion can be specified.

    Args:
        distance_eps (float, optional): Small parameter to add to squared
            distances to make gradients smooth near 0.

    Inputs:
        X (tensor): Input coordinates with shape `([...], length, [...], 3)`.
        mask (tensor, optional): Masking tensor with shape
            `([...], length, [...])`.
        dim (int, optional): Dimension upon which to expand to pairwise
            distances. Defaults to -2.

    Outputs:
        D (tensor): Distances with shape `([...], length, length, [...])`
    """

    def __init__(self, distance_eps=1e-3):
        super(Distances, self).__init__()
        self.distance_eps = distance_eps

    def forward(
        self, X: torch.Tensor, mask: Optional[torch.Tensor] = None, dim: int = -2
    ) -> torch.Tensor:
        # `dim` is a dimension index, so it must be an int (it is passed to
        # `unsqueeze`). A non-negative index shifts by one once the first
        # singleton axis has been inserted in front of it.
        dim_expand = dim if dim < 0 else dim + 1
        dX = X.unsqueeze(dim_expand - 1) - X.unsqueeze(dim_expand)
        D_square = torch.sum(dX ** 2, -1)
        # Smoothed square root: sqrt(d^2 + eps) stays differentiable at d = 0
        D = torch.sqrt(D_square + self.distance_eps)
        if mask is not None:
            # A pair is kept only when both of its members are unmasked
            mask_expand = mask.unsqueeze(dim) * mask.unsqueeze(dim + 1)
            D = mask_expand * D
        return D
class VirtualAtomsCA(nn.Module):
    """Virtual atoms layer, branching from backbone C-alpha carbons.

    Places one virtual atom per residue relative to the backbone coordinates
    in a differentiable way.

    Args:
        virtual_type (str, optional): Type of virtual atom to place. Currently
            supported types are `dicons`, a virtual placement that was
            optimized to predict potential rotamer interactions, and `cbeta`
            which places a virtual C-beta carbon assuming ideal geometry.
        distance_eps (float, optional): Small parameter to add to squared
            distances to make gradients smooth near 0.

    Inputs:
        X (Tensor): Backbone coordinates with shape
            `(num_batch, num_residues, num_atom_types, 3)`.
        C (Tensor): Chain map tensor with shape `(num_batch, num_residues)`.

    Outputs:
        X_virtual (Tensor): Virtual coordinates with shape
            `(num_batch, num_residues, 3)`.
    """

    def __init__(self, virtual_type="dicons", distance_eps=1e-3):
        super().__init__()
        # Geometry specifications, stored as [length, angle, dihedral]
        #   dicons
        #       Length CA-X: 2.3866
        #       Angle N-CA-X: 111.0269
        #       Dihedral C-N-CA-X: -138.8864122
        #   cbeta
        #       Length CA-X: 1.532 (Engh and Huber, 2001)
        #       Angle N-CA-X: 109.5 (tetrahedral geometry)
        #       Dihedral C-N-CA-X: -125.25 (109.5 / 2 - 180)
        self.virtual_geometries = {
            "dicons": [2.3866, 111.0269, -138.8864122],
            "cbeta": [1.532, 109.5, -125.25],
        }
        self.virtual_type = virtual_type
        self.distance_eps = distance_eps

    def geometry(self):
        """Return `(bond, angle, dihedral)` for the configured virtual type."""
        length, bend, torsion = self.virtual_geometries[self.virtual_type]
        return length, bend, torsion

    def forward(self, X: torch.Tensor, C: torch.LongTensor) -> torch.Tensor:
        length, bend, torsion = self.geometry()

        # Broadcastable (1, 1) tensors holding the internal coordinates
        unit = torch.ones([1, 1], device=X.device)

        # Build reference frame
        # 1.C -> 2.N -> 3.CA -> 4.X
        X_N, X_CA, X_C, X_O = X.unbind(2)
        placed = extend_atoms(
            X_C,
            X_N,
            X_CA,
            length * unit,
            bend * unit,
            torsion * unit,
            degrees=True,
            distance_eps=self.distance_eps,
        )

        # Zero out positions whose chain label is not positive
        present = (C > 0).type(torch.float32).unsqueeze(-1)
        return present * placed
def normed_vec(V: torch.Tensor, distance_eps: float = 1e-3) -> torch.Tensor:
    """Normalize vectors using a smoothed magnitude.

    The result is `U = V / sqrt(|V|^2 + eps)`, which avoids cusps and gradient
    discontinuities at zero-length vectors.

    Args:
        V (Tensor): Batch of vectors with shape `(..., num_dims)`.
        distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
            Default: 1E-3.

    Returns:
        U (Tensor): Batch of normalized vectors with shape `(..., num_dims)`.
    """
    # keepdim retains the trailing axis so the division broadcasts over it
    smoothed_norm = torch.sqrt(V.pow(2).sum(dim=-1, keepdim=True) + distance_eps)
    return V / smoothed_norm
def normed_cross(
    V1: torch.Tensor, V2: torch.Tensor, distance_eps: float = 1e-3
) -> torch.Tensor:
    """Compute the normalized cross product of two batches of vectors.

    The cross product is rescaled with `normed_vec`, i.e. divided by
    `sqrt(|V|^2 + eps)`, to avoid cusps and gradient discontinuities.

    Args:
        V1 (Tensor): Batch of vectors with shape `(..., 3)`.
        V2 (Tensor): Batch of vectors with shape `(..., 3)`.
        distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
            Default: 1E-3.

    Returns:
        C (Tensor): Batch of cross products `v_1 x v_2` with shape `(..., 3)`.
    """
    raw_cross = torch.cross(V1, V2, dim=-1)
    return normed_vec(raw_cross, distance_eps=distance_eps)
def lengths(
    atom_i: torch.Tensor, atom_j: torch.Tensor, distance_eps: float = 1e-3
) -> torch.Tensor:
    """Compute batched bond lengths between atoms `i` and `j`.

    Args:
        atom_i (Tensor): Atom `i` coordinates with shape `(..., 3)`.
        atom_j (Tensor): Atom `j` coordinates with shape `(..., 3)`.
        distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
            Default: 1E-3.

    Returns:
        L (Tensor): Elementwise smoothed bond lengths
            `sqrt(||x_i - x_j||^2 + eps)` with shape `(...)`.
    """
    offset = atom_j - atom_i
    squared_distance = offset.pow(2).sum(dim=-1)
    return torch.sqrt(squared_distance + distance_eps)
def angles(
    atom_i: torch.Tensor,
    atom_j: torch.Tensor,
    atom_k: torch.Tensor,
    distance_eps: float = 1e-3,
    degrees: bool = False,
) -> torch.Tensor:
    """Batched bond angles given atoms `i-j-k`.

    Args:
        atom_i (Tensor): Atom `i` coordinates with shape `(..., 3)`.
        atom_j (Tensor): Atom `j` coordinates with shape `(..., 3)`.
        atom_k (Tensor): Atom `k` coordinates with shape `(..., 3)`.
        distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
            Default: 1E-3.
        degrees (bool, optional): If True, convert to degrees. Default: False.

    Returns:
        A (Tensor): Elementwise bond angles with shape `(...)`.
    """
    # Bond angle of i-j-k, measured at the central atom j
    U_ji = normed_vec(atom_i - atom_j, distance_eps=distance_eps)
    U_jk = normed_vec(atom_k - atom_j, distance_eps=distance_eps)

    # Dot product over the coordinate axis. A plain multiply-and-sum accepts
    # any number of leading batch dimensions, as documented above; the previous
    # `einsum("bix,bix->bi")` only accepted rank-3 inputs.
    inner_prod = (U_ji * U_jk).sum(-1)

    # Guard acos against values marginally outside [-1, 1] from round-off
    inner_prod = torch.clamp(inner_prod, -1, 1)
    A = torch.acos(inner_prod)
    if degrees:
        A = A * 180.0 / np.pi
    return A
def dihedrals(
    atom_i: torch.Tensor,
    atom_j: torch.Tensor,
    atom_k: torch.Tensor,
    atom_l: torch.Tensor,
    distance_eps: float = 1e-3,
    degrees: bool = False,
) -> torch.Tensor:
    """Compute batched signed dihedral angles for atoms `i-j-k-l`.

    Args:
        atom_i (Tensor): Atom `i` coordinates with shape `(..., 3)`.
        atom_j (Tensor): Atom `j` coordinates with shape `(..., 3)`.
        atom_k (Tensor): Atom `k` coordinates with shape `(..., 3)`.
        atom_l (Tensor): Atom `l` coordinates with shape `(..., 3)`.
        distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
            Default: 1E-3.
        degrees (bool, optional): If True, convert to degrees. Default: False.

    Returns:
        D (Tensor): Elementwise bond dihedrals with shape `(...)`.
    """

    def _dot(a, b):
        # Inner product over the trailing coordinate axis
        return (a * b).sum(-1)

    # Smoothed unit vectors along the three consecutive bonds
    bond_ij = normed_vec(atom_j - atom_i, distance_eps=distance_eps)
    bond_jk = normed_vec(atom_k - atom_j, distance_eps=distance_eps)
    bond_kl = normed_vec(atom_l - atom_k, distance_eps=distance_eps)

    # Normals of the planes spanned by (i, j, k) and (j, k, l)
    plane_ijk = normed_cross(bond_ij, bond_jk, distance_eps=distance_eps)
    plane_jkl = normed_cross(bond_jk, bond_kl, distance_eps=distance_eps)

    # Magnitude comes from the angle between the normals; the sign comes from
    # which side of the j-k-l plane the first bond points to. torch.sign
    # returns 0 for an exactly zero argument, which zeroes the output there.
    cos_torsion = torch.clamp(_dot(plane_ijk, plane_jkl), -1, 1)
    handedness = torch.sign(_dot(bond_ij, plane_jkl))
    D = handedness * torch.acos(cos_torsion)
    if degrees:
        D = D * 180.0 / np.pi
    return D
def extend_atoms(
    X_1: torch.Tensor,
    X_2: torch.Tensor,
    X_3: torch.Tensor,
    lengths: torch.Tensor,
    angles: torch.Tensor,
    dihedrals: torch.Tensor,
    distance_eps: float = 1e-3,
    degrees: bool = False,
) -> torch.Tensor:
    """Place atom `X_4` given `X_1`, `X_2`, `X_3` and internal coordinates.

         ___________________
        |  X_1 - X_2        |
        |        |          |
        |       X_3 - [X_4] |
        |___________________|

    This uses a similar approach as NERF:
        Parsons et al, Computational Chemistry (2005).
        https://doi.org/10.1002/jcc.20237
    See the reference for further explanation about converting from internal
    coordinates to Cartesian coordinates.

    The input tensors are never modified; the degree-to-radian conversion is
    applied to fresh copies.

    Args:
        X_1 (Tensor): First atom coordinates with shape `(..., 3)`.
        X_2 (Tensor): Second atom coordinates with shape `(..., 3)`.
        X_3 (Tensor): Third atom coordinates with shape `(..., 3)`.
        lengths (Tensor): Bond lengths `X_3-X_4` with shape `(...)`.
        angles (Tensor): Bond angles `X_2-X_3-X_4` with shape `(...)`.
        dihedrals (Tensor): Bond dihedrals `X_1-X_2-X_3-X_4` with shape `(...)`.
        distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
            This preserves differentiability for zero distances. Default: 1E-3.
        degrees (bool, optional): If True, inputs are treated as degrees.
            Default: False.

    Returns:
        X_4 (Tensor): Placed atom with shape `(..., 3)`.
    """
    if degrees:
        # Out-of-place on purpose: `angles *= ...` would overwrite the caller's
        # tensors (so a reused input gets converted twice) and raises for leaf
        # tensors that require grad or for integer-typed inputs.
        angles = angles * (np.pi / 180.0)
        dihedrals = dihedrals * (np.pi / 180.0)

    # Local frame at X_3: n_1 along X_3 -> X_2, n_2 normal to the plane through
    # X_1, X_2, X_3, and n_3 completing the (smoothed) orthonormal triad.
    r_32 = X_2 - X_3
    r_12 = X_2 - X_1
    n_1 = normed_vec(r_32, distance_eps=distance_eps)
    n_2 = normed_cross(n_1, r_12, distance_eps=distance_eps)
    n_3 = normed_cross(n_1, n_2, distance_eps=distance_eps)

    # Add a trailing axis so the scalars broadcast against 3-vectors
    lengths = lengths.unsqueeze(-1)
    cos_angle = torch.cos(angles).unsqueeze(-1)
    sin_angle = torch.sin(angles).unsqueeze(-1)
    cos_dihedral = torch.cos(dihedrals).unsqueeze(-1)
    sin_dihedral = torch.sin(dihedrals).unsqueeze(-1)

    # Spherical-coordinate offset from X_3 expressed in the local frame
    X_4 = X_3 + lengths * (
        cos_angle * n_1
        + (sin_angle * sin_dihedral) * n_2
        + (sin_angle * cos_dihedral) * n_3
    )
    return X_4
class InternalCoords(nn.Module):
    """Internal coordinates layer.

    This layer computes internal coordinates (ICs) from a batch of protein
    backbones. To make the ICs differentiable everywhere, this layer replaces
    distance calculations of the form `sqrt(sum_sq)` with smooth, non-cusped
    approximation `sqrt(sum_sq + eps)`.

    Args:
        distance_eps (float, optional): Small parameter to add to squared
            distances to make gradients smooth near 0.

    Inputs:
        X (Tensor): Backbone coordinates with shape
            `(num_batch, num_residues, num_atom_types, 3)`. The atom axis is
            unpacked into exactly four entries (N, CA, C, O).
        C (Tensor, optional): Chain map tensor with shape
            `(num_batch, num_residues)`. Positions with `C <= 0` are masked.
            If None, every residue is treated as present and as belonging to
            a single chain.
        return_masks (bool, optional): If True, also return the masks that
            were applied to each output. Default: False.

    Outputs:
        dihedrals (Tensor): Backbone dihedral angles with shape
            `(num_batch, num_residues, 4)`
        angles (Tensor): Backbone bond angles with shape
            `(num_batch, num_residues, 4)`
        lengths (Tensor): Backbone bond lengths with shape
            `(num_batch, num_residues, 4)`
        mask_D, mask_A, mask_L (Tensor): Only when `return_masks` is True,
            masks for the three outputs above with matching shapes.

        In each output the first three channels come from the linear
        N-CA-C chain and the fourth from the branched carbonyl oxygen.
    """

    def __init__(self, distance_eps=1e-3):
        super(InternalCoords, self).__init__()
        self.distance_eps = distance_eps

    def forward(
        self,
        X: torch.Tensor,
        C: Optional[torch.Tensor] = None,
        return_masks: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        X_chain = X[:, :, :3, :]
        num_batch, num_residues, _, _ = X_chain.shape

        # Resolve the optional chain map before anything dereferences it;
        # evaluating `C > 0` first made the documented default (None) crash.
        if C is None:
            mask = torch.ones(
                [num_batch, num_residues], dtype=torch.float32, device=X.device
            )
            C = torch.zeros(
                [num_batch, num_residues], dtype=torch.long, device=X.device
            )
        else:
            mask = (C > 0).float()

        # Flatten residues and atoms into one N, CA, C, N, CA, C, ... chain
        X_chain = X_chain.reshape(num_batch, 3 * num_residues, 3)

        _lengths = lambda Xi, Xj: lengths(Xi, Xj, distance_eps=self.distance_eps)
        # This function historically returns the angle complement
        _angles = lambda Xi, Xj, Xk: np.pi - angles(
            Xi, Xj, Xk, distance_eps=self.distance_eps
        )
        _dihedrals = lambda Xi, Xj, Xk, Xl: dihedrals(
            Xi, Xj, Xk, Xl, distance_eps=self.distance_eps
        )

        # Compute internal coordinates associated with -[N]-[CA]-[C]-
        NCaC_L = _lengths(X_chain[:, 1:, :], X_chain[:, :-1, :])
        NCaC_A = _angles(X_chain[:, :-2, :], X_chain[:, 1:-1, :], X_chain[:, 2:, :])
        NCaC_D = _dihedrals(
            X_chain[:, :-3, :],
            X_chain[:, 1:-2, :],
            X_chain[:, 2:-1, :],
            X_chain[:, 3:, :],
        )

        # Compute internal coordinates associated with [C]=[O]
        _, X_CA, X_C, X_O = X.unbind(dim=2)
        X_N_next = X[:, 1:, 0, :]
        O_L = _lengths(X_C, X_O)
        O_A = _angles(X_CA, X_C, X_O)
        O_D = _dihedrals(X_N_next, X_CA[:, :-1, :], X_C[:, :-1, :], X_O[:, :-1, :])

        # Mask nonphysical bonds and angles
        # Note: this could probably also be expressed as a Conv, unclear
        # which is faster and this probably not rate-limiting.
        C = C * (mask.type(torch.long))
        # Repeat each residue's chain label once per backbone atom
        ii = torch.stack(3 * [C], dim=-1).view([num_batch, -1])
        L0, L1 = ii[:, :-1], ii[:, 1:]
        A0, A1, A2 = ii[:, :-2], ii[:, 1:-1], ii[:, 2:]
        D0, D1, D2, D3 = ii[:, :-3], ii[:, 1:-2], ii[:, 2:-1], ii[:, 3:]

        # Mask for linear backbone: every atom involved shares one chain label
        mask_L = torch.eq(L0, L1)
        mask_A = torch.eq(A0, A1) * torch.eq(A0, A2)
        mask_D = torch.eq(D0, D1) * torch.eq(D0, D2) * torch.eq(D0, D3)
        mask_L = mask_L.type(torch.float32)
        mask_A = mask_A.type(torch.float32)
        mask_D = mask_D.type(torch.float32)

        # Masks for branched oxygen
        mask_O_D = torch.eq(C[:, :-1], C[:, 1:])
        mask_O_D = mask_O_D.type(torch.float32)
        mask_O_A = mask
        mask_O_L = mask

        def _pad_pack(D, A, L, O_D, O_A, O_L):
            # Pad each chain quantity back to 3 * num_residues entries, fold
            # into (num_batch, num_residues, 3), then append the oxygen
            # quantity as a fourth channel.
            D = F.pad(D, (1, 2))
            A = F.pad(A, (0, 2))
            L = F.pad(L, (0, 1))
            O_D = F.pad(O_D, (0, 1))
            D, A, L = [x.reshape(num_batch, num_residues, 3) for x in [D, A, L]]
            _pack = lambda a, b: torch.cat([a, b.unsqueeze(-1)], dim=-1)
            L = _pack(L, O_L)
            A = _pack(A, O_A)
            D = _pack(D, O_D)
            return D, A, L

        D, A, L = _pad_pack(NCaC_D, NCaC_A, NCaC_L, O_D, O_A, O_L)
        mask_D, mask_A, mask_L = _pad_pack(
            mask_D, mask_A, mask_L, mask_O_D, mask_O_A, mask_O_L
        )

        # Additionally drop every entry that belongs to a masked residue
        mask_expand = mask.unsqueeze(-1)
        mask_D = mask_expand * mask_D
        mask_A = mask_expand * mask_A
        mask_L = mask_expand * mask_L
        D = mask_D * D
        A = mask_A * A
        L = mask_L * L

        if not return_masks:
            return D, A, L
        else:
            return D, A, L, mask_D, mask_A, mask_L
class VirtualAtomsCA(nn.Module):
    """Virtual atoms layer, branching from backbone C-alpha carbons.

    Places one virtual atom per residue relative to the backbone coordinates
    in a differentiable way.

    Args:
        virtual_type (str, optional): Type of virtual atom to place. Currently
            supported types are `dicons`, a virtual placement that was
            optimized to predict potential rotamer interactions, and `cbeta`
            which places a virtual C-beta carbon assuming ideal geometry.
        distance_eps (float, optional): Small parameter to add to squared
            distances to make gradients smooth near 0.

    Inputs:
        X (Tensor): Backbone coordinates with shape
            `(num_batch, num_residues, num_atom_types, 3)`.
        C (Tensor): Chain map tensor with shape `(num_batch, num_residues)`.

    Outputs:
        X_virtual (Tensor): Virtual coordinates with shape
            `(num_batch, num_residues, 3)`.
    """

    def __init__(self, virtual_type="dicons", distance_eps=1e-3):
        super().__init__()
        self.virtual_type = virtual_type
        self.distance_eps = distance_eps

        # Geometry specifications, stored as [length, angle, dihedral]
        #   dicons
        #       Length CA-X: 2.3866
        #       Angle N-CA-X: 111.0269
        #       Dihedral C-N-CA-X: -138.8864122
        #   cbeta
        #       Length CA-X: 1.532 (Engh and Huber, 2001)
        #       Angle N-CA-X: 109.5 (tetrahedral geometry)
        #       Dihedral C-N-CA-X: -125.25 (109.5 / 2 - 180)
        self.virtual_geometries = {
            "dicons": [2.3866, 111.0269, -138.8864122],
            "cbeta": [1.532, 109.5, -125.25],
        }

    def geometry(self):
        """Return `(bond, angle, dihedral)` for the configured virtual type."""
        bond, angle, dihedral = self.virtual_geometries[self.virtual_type]
        return (bond, angle, dihedral)

    def forward(self, X: torch.Tensor, C: torch.LongTensor) -> torch.Tensor:
        bond, angle, dihedral = self.geometry()

        # Lift the scalar internal coordinates to broadcastable (1, 1) tensors
        template = torch.ones([1, 1], device=X.device)
        bond_t = bond * template
        angle_t = angle * template
        dihedral_t = dihedral * template

        # Build reference frame
        # 1.C -> 2.N -> 3.CA -> 4.X
        X_N, X_CA, X_C, X_O = X.unbind(2)
        X_virtual = extend_atoms(
            X_C,
            X_N,
            X_CA,
            bond_t,
            angle_t,
            dihedral_t,
            degrees=True,
            distance_eps=self.distance_eps,
        )

        # Zero out positions whose chain label is not positive
        keep = (C > 0).type(torch.float32).unsqueeze(-1)
        X_virtual = keep * X_virtual
        return X_virtual
def quaternions_from_rotations(R: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """Convert a batch of rotation matrices to quaternions.

    See en.wikipedia.org/wiki/Quaternions_and_spatial_rotation for further
    details on converting between quaternions and rotation matrices.

    Args:
        R (tensor): Batch of rotation matrices with shape `(..., 3, 3)`.
        eps (float, optional): Smoothing constant added inside every square
            root and used for the final normalization. Default: 1E-3.

    Returns:
        q (tensor): Batch of quaternion vectors with shape `(..., 4)`. Quaternion
            is in the order `[angle, axis_x, axis_y, axis_z]`.
    """

    def _smooth_sqrt(r):
        # Clip negatives from round-off, then smooth the root near zero
        return torch.sqrt(F.relu(r) + eps)

    flat_shape = list(R.shape)[:-2] + [9]
    R00, R01, R02, R10, R11, R12, R20, R21, R22 = R.reshape(flat_shape).unbind(-1)

    # Quaternion possesses both an axis and angle of rotation
    q_angle = _smooth_sqrt(1 + R00 + R11 + R22).unsqueeze(-1)
    axis_magnitudes = _smooth_sqrt(
        1 + torch.stack([R00 - R11 - R22, -R00 + R11 - R22, -R00 - R11 + R22], -1)
    )
    # Axis signs are read off the antisymmetric part of R
    axis_signs = torch.sign(torch.stack([R21 - R12, R02 - R20, R10 - R01], -1))

    # Normalize (for safety and a missing factor of 2)
    q_unnormalized = torch.cat((q_angle, axis_signs * axis_magnitudes), -1)
    return normed_vec(q_unnormalized, distance_eps=eps)
def rotations_from_quaternions(
    q: torch.Tensor, normalize: bool = False, eps: float = 1e-3
) -> torch.Tensor:
    """Convert a batch of quaternions to rotation matrices.

    See en.wikipedia.org/wiki/Quaternions_and_spatial_rotation for further
    details on converting between quaternions and rotation matrices.

    Args:
        q (tensor): Batch of quaternion vectors with shape `(..., 4)`. Quaternion
            is in the order `[angle, axis_x, axis_y, axis_z]`.
        normalize (boolean, optional): Option to normalize the quaternion before
            conversion.
        eps (float, optional): Smoothing constant passed to the normalization
            when `normalize` is True. Default: 1E-3.

    Returns:
        R (tensor): Batch of rotation matrices with shape `(..., 3, 3)`.
    """
    out_shape = list(q.shape)[:-1] + [3, 3]
    if normalize:
        q = normed_vec(q, distance_eps=eps)

    a, b, c, d = q.unbind(-1)
    a2, b2, c2, d2 = a ** 2, b ** 2, c ** 2, d ** 2

    # Matrix entries, listed row by row
    row_0 = [a2 + b2 - c2 - d2, 2 * b * c - 2 * a * d, 2 * b * d + 2 * a * c]
    row_1 = [2 * b * c + 2 * a * d, a2 - b2 + c2 - d2, 2 * c * d - 2 * a * b]
    row_2 = [2 * b * d - 2 * a * c, 2 * c * d + 2 * a * b, a2 - b2 - c2 + d2]

    flat = torch.stack(row_0 + row_1 + row_2, dim=-1)
    return flat.view(out_shape)
def frames_from_backbone(X: torch.Tensor, distance_eps: float = 1e-3):
    """Build a local reference frame for each residue of a backbone.

    Args:
        X (Tensor): Backbone coordinates with shape `(..., 4, 3)`.
        distance_eps (float, optional): Distance smoothing parameter for
            computing distances as `sqrt(sum_sq) -> sqrt(sum_sq + eps)`.
            Default: 1E-3.

    Returns:
        R (Tensor): Reference frames with shape `(..., 3, 3)`. The three axes
            are stacked along the last dimension.
        X_CA (Tensor): C-alpha coordinates with shape `(..., 3)`
    """
    X_N, X_CA, X_C, _ = X.unbind(-2)

    # First axis points from CA to N; the second is normal to the N-CA-C
    # plane; the third is the normalized cross product of the first two.
    axis_1 = normed_vec(X_N - X_CA, distance_eps)
    toward_C = normed_vec(X_C - X_CA, distance_eps)
    axis_2 = normed_cross(axis_1, toward_C, distance_eps)
    axis_3 = normed_cross(axis_1, axis_2, distance_eps)

    R = torch.stack([axis_1, axis_2, axis_3], -1)
    return R, X_CA
def hat(omega: torch.Tensor) -> torch.Tensor:
    """Map vectors to their skew-symmetric (cross-product) matrices.

    Maps [x,y,z] to [[0,-z,y], [z,0,-x], [-y, x, 0]]

    Args:
        omega (torch.tensor): of size (*, 3)

    Returns:
        hat{omega} (torch.tensor): of size (*, 3, 3) skew symmetric element in
            so(3), with the same dtype and device as `omega`.
    """
    batch_shape = omega.size()[:-1]

    # Flattened 3x3 buffer. It must share omega's dtype: `scatter` requires
    # the destination and source dtypes to match, so the default-dtype buffer
    # used previously raised for e.g. float64 input.
    target = torch.zeros(*batch_shape, 9, dtype=omega.dtype, device=omega.device)

    # Row-major flat positions receiving +omega: (2,1)<-x, (0,2)<-y, (1,0)<-z
    index1 = torch.tensor([7, 2, 3], device=omega.device).expand(*batch_shape, -1)
    # Row-major flat positions receiving -omega: (1,2)<-x, (2,0)<-y, (0,1)<-z
    index2 = torch.tensor([5, 6, 1], device=omega.device).expand(*batch_shape, -1)

    return (
        target.scatter(-1, index1, omega)
        .scatter(-1, index2, -omega)
        .reshape(*batch_shape, 3, 3)
    )
def V(omega: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Evaluate `I + a * hat(omega) + b * hat(omega)^2` for a batch of vectors.

    With `theta = sqrt(|omega|^2 + eps)` the coefficients are
    `a = (1 - cos(theta)) / theta^2` and `b = (theta - sin(theta)) / theta^3`.

    NOTE(review): this has the form of the `V` matrix that appears in the
    SE(3) exponential map; confirm the intended use against the callers.

    Args:
        omega (torch.tensor): of size (*, 3)
        eps (float, optional): Added to the squared norm before the square
            root so that `theta` is never exactly zero. Default: 1E-6.

    Returns:
        V (torch.tensor): of size (*, 3, 3)
    """
    batch_shape = omega.size()[:-1]
    identity = torch.eye(3, device=omega.device).expand(*batch_shape, 3, 3)

    # Two trailing singleton axes so theta broadcasts against (*, 3, 3)
    theta = omega.pow(2).sum(dim=-1, keepdim=True).add(eps).sqrt()[..., None]

    omega_hat = hat(omega)
    linear_term = ((1 - theta.cos()) / theta.pow(2)) * (omega_hat)
    quadratic_term = ((theta - theta.sin()) / theta.pow(3)) * (omega_hat @ omega_hat)
    return identity + linear_term + quadratic_term