# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for generating protein structure.
This module contains pytorch layers for parametrically generating and
manipulating protein backbones. These can be used in tandem with loss functions
to generate and optimize protein structure (e.g. folding from predictions) or
used as an intermediate layer in a learned structure generation model.
"""
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from chroma.layers.structure import geometry, transforms
class ProteinBackbone(nn.Module):
"""Protein backbone layer with optimizable geometry (batch form).
This layer stores the parameters for a protein backbone, which can be based
on either internal coordinate or Cartesian parameterizations.
It outputs coordinates in Cartesian form as 4D tensors with indices
`[batch, position, atom_type, xyz]`. The `atom_type` index runs over the
heavy atoms of a protein backbone in PDB order, i.e. `[N, CA, C, O]`.
The resulting coordinates can be directly optimized with pytorch optimizers.
Args:
num_residues (int): Number of residues.
num_batch (int): Batch size.
        init_state (str): Initialization state; `alpha` or `beta` initialize
            ideal helix or sheet dihedrals, and any other value initializes
            the dihedral angles uniformly at random.
        use_internal_coords (Boolean): Use a phi,psi parameterization.
            Default is True.
X_init (torch.Tensor, optional): Initialize with pre-specified coordinates.
Requires that use_internal_coords=False.
Outputs:
X (torch.Tensor): Backbone coordinates with shape
`(batch_size, num_residues, 4, 3)`.
"""
def __init__(
self,
num_residues: int,
num_batch: int = 1,
init_state: str = "alpha",
use_internal_coords: bool = True,
X_init: Optional[torch.Tensor] = None,
):
super(ProteinBackbone, self).__init__()
# Dimensions
self.num_batch = num_batch
self.num_residues = num_residues
# Rigid body translation and rotation
self.transform = RigidTransform(num_batch=num_batch, keep_centered=True)
self.use_internal_coords = use_internal_coords
if self.use_internal_coords:
# Internal coordinate parameterization
self.phi = nn.Parameter(torch.zeros(num_batch, num_residues))
self.psi = nn.Parameter(torch.zeros(num_batch, num_residues))
# Initializer
phi_psi = {
"alpha": (np.radians(-60.0), np.radians(-45.0)),
"beta": (np.radians(-140.0), np.radians(135.0)),
}
if init_state in phi_psi:
torch.nn.init.constant_(self.phi, phi_psi[init_state][0])
torch.nn.init.constant_(self.psi, phi_psi[init_state][1])
else:
torch.nn.init.uniform_(self.phi, a=-np.pi, b=np.pi)
torch.nn.init.uniform_(self.psi, a=-np.pi, b=np.pi)
self.backbone_geometry = BackboneBuilder()
else:
# Use a Cartesian parameterization
if X_init is not None:
assert not use_internal_coords
else:
X_init = ProteinBackbone(
num_residues=num_residues,
num_batch=num_batch,
init_state=init_state,
use_internal_coords=True,
)()
self.X = nn.Parameter(X_init)
def forward(self) -> torch.Tensor:
if self.use_internal_coords:
X = self.backbone_geometry(self.phi, self.psi)
else:
X = self.X
# Apply rotation and translation
X = self.transform(X)
return X
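# A minimal usage sketch (illustrative, not part of the module): because the
# dihedrals are nn.Parameters, the generated coordinates can be optimized
# directly with any PyTorch optimizer against a user-supplied loss (`my_loss`
# below is hypothetical).
#
#     backbone = ProteinBackbone(num_residues=64, init_state="alpha")
#     optimizer = torch.optim.Adam(backbone.parameters(), lr=1e-2)
#     for _ in range(100):
#         optimizer.zero_grad()
#         X = backbone()        # (1, 64, 4, 3)
#         loss = my_loss(X)     # hypothetical objective on coordinates
#         loss.backward()
#         optimizer.step()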
class RigidTransform(nn.Module):
"""Rigid-body rotation and translation (batch form).
This layer stores the parameters for a rigid body rotation and translation.
It can be composed with other generative geometry layers to optimize over
poses.
Args:
num_batch (int): Number of poses to store parameters for.
        keep_centered (Boolean): If True, leave the coordinates centered at
            the origin after rotation instead of restoring their original
            center of mass.
scale_dX (float): Scale factor which affects the rate of change of
translation.
scale_q (float): Scale factor which affects the rate of change of
rotation.
Inputs:
X (torch.Tensor): Input coordinates with shape `(batch_size, ..., 3)`.
Outputs:
X_t (torch.Tensor): Transformed coordinates with shape:
`(batch_size, ..., 3)`.
"""
def __init__(
self,
num_batch: int = 1,
keep_centered: bool = False,
scale_dX: float = 1.0,
scale_q: float = 1.0,
):
super(RigidTransform, self).__init__()
self.num_batch = num_batch
# Cartesian offset initialized to 0
self.dX = nn.Parameter(torch.zeros(self.num_batch, 3))
self.scale_dX = scale_dX
# Unconstrained quaternion initialized to identity
self.scale_q = scale_q
q_init = np.asarray([[1.0, 0, 0, 0]] * self.num_batch)
q_init = torch.tensor(q_init, dtype=torch.float32) / self.scale_q
self.q_unc = nn.Parameter(q_init)
self.rigid_transform = RigidTransformer(keep_centered=keep_centered)
def forward(self, X: torch.Tensor) -> torch.Tensor:
dX = self.scale_dX * self.dX
q_unc = self.scale_q * self.q_unc
X_t = self.rigid_transform(X, dX, q_unc)
return X_t
class RigidTransformer(nn.Module):
"""Rigid-body rotation and translation (batch form).
This layer applies a rigid body rotation and translation,
and can be composed with other generative geometry layers to modify poses.
Internally, the coordinates are centered before rotation and translation.
The rotation itself is parameterized as a quaternion to prevent
Gimbal lock (https://en.wikipedia.org/wiki/Gimbal_lock).
    Args:
        center_rotation (Boolean): If True, rotate about the center of mass
            rather than about the origin (default: True).
        keep_centered (Boolean): If True, leave the coordinates centered at
            the origin after rotation instead of restoring their original
            center of mass (default: False).
Inputs:
X (torch.Tensor): Input coordinates with shape `(batch_size, ..., 3)`.
dX (torch.Tensor): Translation vector with shape `(batch_size, 3)`.
q (torch.Tensor): Rotation vector (quaternion) with shape `(batch_size, 4)`.
It can be any 4-element real vector, but will internally be
normalized to a unit quaternion.
        mask (torch.Tensor, optional): Mask with shape `(batch_size, ...)`
            matching the leading dimensions of `X`.
Outputs:
X_t (torch.Tensor): Transformed coordinates with shape `(batch_size, ..., 3)`.
"""
def __init__(self, center_rotation: bool = True, keep_centered: bool = False):
super(RigidTransformer, self).__init__()
self.center_rotation = center_rotation
self.keep_centered = keep_centered
self.dist_eps = 1e-5
def _rotation_matrix(self, q_unc: torch.Tensor) -> torch.Tensor:
"""Build rotation matrix from quaternion parameters.
See en.wikipedia.org/wiki/Quaternions_and_spatial_rotation for further
details on converting between quaternions and rotation matrices.
Args:
            q_unc (torch.Tensor): Unnormalized quaternion representing rotation with
                shape `(batch_size, 4)`.
        Returns:
            R (torch.Tensor): Rotation matrix with shape `(batch_size, 3, 3)`.
"""
num_batch = q_unc.shape[0]
q = F.normalize(q_unc, dim=-1)
# fmt: off
a,b,c,d = q.unbind(-1)
a2,b2,c2,d2 = a**2, b**2, c**2, d**2
R = torch.stack([
a2 + b2 - c2 - d2, 2*b*c - 2*a*d, 2*b*d + 2*a*c,
2*b*c + 2*a*d, a2 - b2 + c2 - d2, 2*c*d - 2*a*b,
2*b*d - 2*a*c, 2*c*d + 2*a*b, a2 - b2 - c2 + d2
], dim=-1)
# fmt: on
R = R.view([num_batch, 3, 3])
return R
def forward(
self,
X: torch.Tensor,
dX: torch.Tensor,
q: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
num_batch = X.shape[0]
X_flat = X.reshape([num_batch, -1, 3])
# Flatten mask
if mask is not None:
shape_mask = list(mask.size())
shape_X = list(X.size())
shape_mask_expand = shape_mask + [
1 for i in range(len(shape_X) - 1 - len(shape_mask))
]
shape_mask_flat = list(X_flat.size())[:-1] + [1]
mask_flat = mask.reshape(shape_mask_expand).expand(shape_X[:-1])
mask_flat = mask_flat.reshape(shape_mask_flat)
# Compute center
X_mean = torch.sum(mask_flat * X_flat, 1, keepdims=True) / (
torch.sum(mask_flat, 1, keepdims=True) + self.dist_eps
)
else:
X_mean = torch.mean(X_flat, 1, keepdims=True)
# Rotate around center of mass
if self.center_rotation:
X_centered = X_flat - X_mean
else:
X_centered = X_flat
R = self._rotation_matrix(q)
X_rotate = torch.einsum("bxr,bir->bix", R, X_centered)
# Optionally preserve original centering
if self.center_rotation and not self.keep_centered:
X_rotate = X_rotate + X_mean
# Translate
X_transform = X_rotate + dX.unsqueeze(1)
if mask is not None:
X_transform = mask_flat * X_transform + (1 - mask_flat) * X_flat
X_transform = X_transform.view(X.shape)
return X_transform
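# Sketch of applying RigidTransformer (assumed shapes only): a unit quaternion
# for angle theta about unit axis (x, y, z) is
# [cos(theta/2), x sin(theta/2), y sin(theta/2), z sin(theta/2)]; any nonzero
# 4-vector works since q is normalized internally.
#
#     transformer = RigidTransformer()
#     X = torch.randn(2, 10, 3)                      # (batch, points, xyz)
#     q = torch.tensor([[np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)]] * 2)
#     dX = torch.zeros(2, 3)                         # no translation
#     X_rot = transformer(X, dX, q.float())          # 90 degrees about z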
class BackboneBuilder(nn.Module):
"""Protein backbone builder from dihedral angles (batch form).
See ProteinBackbone() for further explanation of output coordinates.
    When only partial information is given, such as phi & psi angles, this
    module will default to the ideal geometries given in
    Engh & Huber, International Tables for Crystallography (2001).
https://doi.org/10.1107/97809553602060000857
Todo:
* Add shifting and padding logic to associate phis and psis with their
'natural' residue indices rather than the child atoms that they
create during NERF recurrence
* Add control over the bond lengths and angles for Oxygen
Inputs:
phi (torch.Tensor): Phi dihedral angles with shape `(batch_size, length)`.
psi (torch.Tensor): Psi dihedral angles with shape `(batch_size, length)`.
omega (torch.Tensor, optional): Omega dihedral angles with shape
`(batch_size, length)`. Defaults to ideal geometry.
        angles (torch.Tensor, optional): Bond angles with shape
            `(num_batch, 3*length)`. Defaults to ideal geometry.
        lengths (torch.Tensor, optional): Bond lengths with shape
            `(num_batch, 3*length)`. Defaults to ideal geometry.
        add_O (bool, optional): If True, also build the backbone carbonyl
            oxygens (default: True).
Outputs:
X (torch.Tensor): Backbone coordinates with shape
`(batch_size, num_residues, 4, 3)`.
"""
def __init__(self):
super(BackboneBuilder, self).__init__()
# From "Structure Quality and Target Parameters", Engh & Huber, 2001
# fmt: off
self.lengths = {
'N_CA': 1.459,
'CA_C': 1.525,
'C_N': 1.336,
'C_O': 1.229
}
angles = {
'N_CA_C': 111.0,
'CA_C_N': 117.2,
'C_N_CA': 121.7,
'omega': 179.3
}
self.angles = {
k: v * np.pi / 180. for k,v in angles.items()
}
# fmt: on
return
def forward(
self,
phi: torch.Tensor,
psi: torch.Tensor,
omega: Optional[torch.Tensor] = None,
angles: Optional[torch.Tensor] = None,
lengths: Optional[torch.Tensor] = None,
add_O: bool = True,
) -> torch.Tensor:
N_batch, N_residues = phi.shape[0], phi.shape[1]
linear_shape = [N_batch, N_residues]
device = phi.device
"""
This uses a similar (but not identical) approach as NERF:
            Parsons et al, Journal of Computational Chemistry (2005).
https://doi.org/10.1002/jcc.20237
See the reference for further explanation about converting from internal
coordinates to Cartesian coordinates.
____________________________________________________________________
| N-to-C backbone geometry for NERF |
| i.e. which internal coords create which Cartesian coords |
| [% indicates preceding residue] |
|______________________ _________________________________________ ___|
|i-1 |Residue i |i+1|
| | | |
|Atom: [C%]--omega%--[N]----phi----[CA]----psi---[C]---omega---[N>]|
| | | | |
|Parents | | | |
| Bond: C%_N N_CA CA_C |
| Angle: CA%_C%_N C%_N_CA N_CA_C |
|Dihedral: N%_CA%_C%_N CA%_C%_N_CA C%_N_CA_C |
|--------------------------------------------------------------------|
|Bond: [C_N]% [N_CA] [CA_C] |
|Dihedral: psi% omega% phi |
|____________________________________________________________________|
"""
if lengths is None:
lengths = torch.tensor(
[[self.lengths[key] for key in ["C_N", "N_CA", "CA_C"]]],
dtype=torch.float32,
).to(device)
lengths = lengths.repeat(N_batch, N_residues)
if angles is None:
angles = torch.tensor(
[[self.angles[key] for key in ["CA_C_N", "C_N_CA", "N_CA_C"]]],
dtype=torch.float32,
).to(device)
angles = angles.repeat(N_batch, N_residues)
if omega is None:
omega = self.angles["omega"] * torch.ones(linear_shape).to(device)
# Compute un-rotated Cartesian coordinates in batch
dihedrals = torch.stack([psi, omega, phi], -1)
dihedrals = dihedrals.view([N_batch, 3 * N_residues])
angles_comp = np.pi - angles
v = torch.stack(
[
torch.cos(angles_comp),
torch.sin(angles_comp) * torch.cos(dihedrals),
torch.sin(angles_comp) * torch.sin(dihedrals),
],
-1,
)
# Lengths
lengths_list = list(lengths.unsqueeze(-1).unbind(1))
v_list = list(v.unbind(1))
if add_O:
# Build one extra appended residue
lengths_list += lengths_list[-3:]
v_list += v_list[-3:]
def _build_x_i(v_i, l_i, x, u_minus_1, u_minus_2):
"""Recurrence relation for placing atoms (NERF)"""
# Build matrix encoding local reference frame
            n_a_unnorm = torch.cross(u_minus_2, u_minus_1, dim=-1)
            n_a = F.normalize(n_a_unnorm, dim=-1)
            n_b = torch.cross(n_a, u_minus_1, dim=-1)
# Matrix multiply version
R = torch.stack([u_minus_1, n_b, n_a], 2)
u_new = torch.matmul(R, v_i.unsqueeze(-1)).squeeze(-1)
x_new = x + l_i * u_new
return x_new, u_new, u_minus_1
# Initialization
x_i = torch.zeros([N_batch, 3]).to(device)
u_i_minus_2 = torch.tensor([[1.0, 0, 0]] * N_batch, dtype=torch.float32).to(
device
)
u_i_minus_1 = torch.tensor([[0, 1.0, 0]] * N_batch, dtype=torch.float32).to(
device
)
# Build chain via NERF recurrence
X = []
for i, (v_i, l_i) in enumerate(zip(v_list, lengths_list)):
x_i, u_i_minus_1, u_i_minus_2 = _build_x_i(
v_i, l_i, x_i, u_i_minus_1, u_i_minus_2
)
X.append(x_i)
X = torch.stack(X, 1)
# [N,AL,3] => [N,L,A,3]
X = X.view([N_batch, -1, 3, 3])
if add_O:
# Build the oxygen vector using symmetry
u_1 = F.normalize(X[:, :-1, 2, :] - X[:, :-1, 1, :], dim=-1) # CA->C
u_2 = F.normalize(X[:, :-1, 2, :] - X[:, 1:, 0, :], dim=-1) # C<-N*
u = self.lengths["C_O"] * F.normalize(u_1 + u_2, dim=-1)
X = X[:, :-1, :, :]
X_O = X[:, :, 2, :] + u
X = torch.cat([X, X_O.unsqueeze(2)], 2)
X = X - X.mean([1, 2, 3], keepdim=True)
return X
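# Sketch: building an idealized alpha helix from constant (phi, psi) angles,
# the same dihedrals ProteinBackbone uses for its `alpha` initialization.
#
#     builder = BackboneBuilder()
#     phi = torch.full((1, 20), np.radians(-60.0), dtype=torch.float32)
#     psi = torch.full((1, 20), np.radians(-45.0), dtype=torch.float32)
#     X = builder(phi, psi)                          # (1, 20, 4, 3)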
class FrameBuilder(nn.Module):
"""Build protein backbones from rigid residue poses.
Inputs:
        R (torch.Tensor): Rotation of residue orientations
            with shape `(num_batch, num_residues, 3, 3)`. If `None`,
            then `q` must be provided instead.
        t (torch.Tensor): Translation of residue orientations
            with shape `(num_batch, num_residues, 3)`. This is the
            location of the C-alpha coordinates.
        C (torch.Tensor): Chain map with shape `(num_batch, num_residues)`.
        q (torch.Tensor, optional): Quaternions representing residue orientations
            with shape `(num_batch, num_residues, 4)`.
Outputs:
X (torch.Tensor): All-atom protein coordinates with shape
`(num_batch, num_residues, 4, 3)`
"""
def __init__(self, distance_eps: float = 1e-3):
super().__init__()
# Build idealized backbone fragment
t = torch.tensor(
[
[1.459, 0.0, 0.0], # N-C via Engh & Huber is 1.459
[0.0, 0.0, 0.0], # CA is origin
[-0.547, 0.0, -1.424], # C is placed 1.525 A @ 111 degrees from N
],
dtype=torch.float32,
).reshape([1, 1, 3, 3])
R = torch.eye(3).reshape([1, 1, 1, 3, 3])
self.register_buffer("_t_atom", t)
self.register_buffer("_R_atom", R)
# Carbonyl geometry from CHARMM all36_prot ALA definition
self._length_C_O = 1.2297
self._angle_CA_C_O = 122.5200
self._dihedral_Np_CA_C_O = 180
self.distance_eps = distance_eps
def _build_O(self, X_chain: torch.Tensor, C: torch.LongTensor):
"""Build backbone carbonyl oxygen."""
        # Build carbonyl groups
        X_N, X_CA, X_C = X_chain.unbind(-2)
        # TODO: fix this behavior for termini
        mask_next = (C > 0).float()[:, 1:].unsqueeze(-1)
        X_N_next = F.pad(mask_next * X_N[:, 1:], (0, 0, 0, 1))
num_batch, num_residues = C.shape
ones = torch.ones(list(C.shape), dtype=torch.float32, device=C.device)
X_O = geometry.extend_atoms(
X_N_next,
X_CA,
X_C,
self._length_C_O * ones,
self._angle_CA_C_O * ones,
self._dihedral_Np_CA_C_O * ones,
degrees=True,
)
mask = (C > 0).float().reshape(list(C.shape) + [1, 1])
X = mask * torch.stack([X_N, X_CA, X_C, X_O], dim=-2)
return X
def forward(
self,
R: torch.Tensor,
t: torch.Tensor,
C: torch.LongTensor,
q: Optional[torch.Tensor] = None,
):
assert q is None or R is None
if R is None:
# (B,N,1,3,3) and (B,N,1,3)
R = geometry.rotations_from_quaternions(
q, normalize=True, eps=self.distance_eps
)
R = R.unsqueeze(-3)
t_frame = t.unsqueeze(-2)
X_chain = transforms.compose_translation(R, t_frame, self._t_atom)
X = self._build_O(X_chain, C)
return X
def inverse(
self, X: torch.Tensor, C: torch.LongTensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Reconstruct transformations from poses.
Inputs:
X (torch.Tensor): All-atom protein coordinates with shape
`(num_batch, num_residues, 4, 3)`
C (torch.Tensor): Chain map with shape `(num_batch, num_residues)`.
Outputs:
            R (torch.Tensor): Rotation of residue orientations
                with shape `(num_batch, num_residues, 3, 3)`.
            t (torch.Tensor): Translation of residue orientations
                with shape `(num_batch, num_residues, 3)`. This is the
                location of the C-alpha coordinates.
            q (torch.Tensor): Quaternions representing residue orientations
                with shape `(num_batch, num_residues, 4)`.
"""
X_bb = X[:, :, :4, :]
R, t = geometry.frames_from_backbone(X_bb, distance_eps=self.distance_eps)
q = geometry.quaternions_from_rotations(R, eps=self.distance_eps)
mask = (C > 0).float().unsqueeze(-1)
R = mask.unsqueeze(-1) * R
t = mask * t
q = mask * q
return R, t, q
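# Sketch of a round trip (X is any backbone tensor, e.g. from BackboneBuilder
# above): `inverse` recovers per-residue frames from coordinates, and the
# forward pass rebuilds an idealized backbone from those frames.
#
#     frame_builder = FrameBuilder()
#     C = torch.ones(1, 20, dtype=torch.long)        # single-chain map
#     R, t, q = frame_builder.inverse(X, C)          # X: (1, 20, 4, 3)
#     X_ideal = frame_builder(R, t, C)               # re-idealized backbone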
class GraphBackboneUpdate(nn.Module):
"""Layer for updating backbone coordinates given graph embeddings.
Args:
dim_nodes (int): Node dimension of graph input.
dim_edges (int): Edge dimension of graph input.
        distance_scale (float): Coordinate scaling factor in angstroms. Default
            is 10 angstroms per unit, which corresponds to nanometers.
        distance_eps (float): Small parameter to regularize distance and
            quaternion calculations. Default is 1E-3.
method (str): Method used for predicting coordinates. Options include
* `local`: Node-based relative transformations.
* `neighbor`: Inter-residue geometry.
* `neighbor_global`: Inter-residue geometry with virtual global edge.
* `neighbor_global_affine`: Inter-residue geometry with virtual
global edge, parameterized as a residual update.
* `none`: No transformation-based updates.
        iterations (int): Number of update iterations.
unconstrained (bool): If True, update sub-pose geometries beyond ideal
coordinates.
num_transform_weights (int): Number of uncertainty dimensions per residue
for neighbor-based updates.
        black_hole_init (bool): If True, ignore initial geometry and initialize
            poses at the origin as in AlphaFold2 (Jumper et al. 2021).
Inputs:
X (torch.Tensor): Backbone coordinates with shape
`(num_batch, num_residues, num_atoms, 3)`.
C (torch.LongTensor): Chain map with shape `(num_batch, num_residues)`.
node_h (torch.Tensor): Node features with shape
`(num_batch, num_residues, dim_nodes)`.
edge_h (torch.Tensor): Edge features with shape
`(num_batch, num_residues, num_neighbors, dim_edges)`.
edge_idx (torch.LongTensor): Edge indices for neighbors with shape
`(num_batch, num_residues, num_neighbors)`.
mask_i (torch.Tensor): Node mask with shape `(num_batch, num_residues)`.
mask_ij (torch.Tensor): Edge mask with shape
`(num_batch, num_nodes, num_neighbors)`.
    Outputs:
        X_update (torch.Tensor): Updated backbone coordinates with shape
            `(num_batch, num_residues, num_atoms, 3)`.
        R_ji (torch.Tensor): Predicted neighbor-to-self rotations with shape
            `(num_batch, num_residues, num_neighbors, 3, 3)`, or `None` for
            methods that do not predict them.
        t_ji (torch.Tensor): Predicted neighbor-to-self translations with
            shape `(num_batch, num_residues, num_neighbors, 3)`, or `None`
            for methods that do not predict them.
        logit_ji (torch.Tensor): Transform confidence logits with shape
            `(num_batch, num_residues, num_neighbors, num_transform_weights)`,
            or `None` for methods that do not predict them.
"""
def __init__(
self,
dim_nodes: int,
dim_edges: int,
distance_scale: float = 10.0,
distance_eps: float = 1e-3,
method: str = "neighbor",
iterations: int = 1,
unconstrained: bool = True,
num_transform_weights: int = 1,
black_hole_init: bool = False,
):
super(GraphBackboneUpdate, self).__init__()
self.distance_scale = distance_scale
self.distance_eps = distance_eps
self._eps = 1e-5
self.frame_builder = FrameBuilder(distance_eps=distance_eps)
self.method = method
self.iterations = iterations
self.unconstrained = unconstrained
self.num_transform_weights = num_transform_weights
self.black_hole_init = black_hole_init
if self.method == "local":
self.W_q = nn.Linear(dim_nodes, 4)
self.W_t = nn.Linear(dim_nodes, 3)
elif self.method == "neighbor":
self.W_q = nn.Linear(dim_edges, 4)
self.W_t = nn.Linear(dim_edges, 3)
self.W_w = nn.Linear(dim_edges, self.num_transform_weights)
elif self.method == "neighbor_global":
self.W_q = nn.Linear(dim_edges, 4)
self.W_t = nn.Linear(dim_edges, 3)
self.W_w = nn.Linear(dim_edges, self.num_transform_weights)
self.W_q_global = nn.Linear(dim_nodes, 4)
self.W_t_global = nn.Linear(dim_nodes, 3)
self.W_w_global = nn.Linear(dim_nodes, self.num_transform_weights)
elif self.method == "neighbor_global_affine":
self.W_s_node = nn.Linear(dim_nodes, 2)
self.W_s_edge = nn.Linear(dim_edges, 2)
self.W_q = nn.Linear(dim_edges, 4)
self.W_t = nn.Linear(dim_edges, 3)
self.W_w = nn.Linear(dim_edges, self.num_transform_weights)
self.W_q_global = nn.Linear(dim_nodes, 4)
self.W_t_global = nn.Linear(dim_nodes, 3)
self.W_w_global = nn.Linear(dim_nodes, self.num_transform_weights)
if self.method == "none":
# None does no frame based updates
assert self.unconstrained
if self.unconstrained:
self.W_t_local = nn.Linear(dim_nodes, 12)
return
def _init_black_hole(self, X):
R = (
torch.eye(3, device=X.device, dtype=X.dtype)
.reshape(1, 1, 3, 3)
.repeat(X.size(0), X.size(1), 1, 1)
)
t = torch.zeros(X.size(0), X.size(1), 3, dtype=X.dtype, device=X.device)
return R, t
def _update_local_transform(self, X, C, node_h, edge_h, edge_idx, mask_i, mask_ij):
"""Update residue frames via transformation from self."""
R_i, t_i, _ = self.frame_builder.inverse(X, C)
if self.black_hole_init:
R_i, t_i = self._init_black_hole(X)
# Predict transforms
R = geometry.rotations_from_quaternions(
self.W_q(node_h), normalize=True, eps=self.distance_eps
)
t = self.distance_scale * self.W_t(node_h)
# Apply transformations
R_i_pred, t_i_pred = transforms.compose_transforms(R_i, t_i, R, t)
X_update = self.frame_builder(R_i_pred, t_i_pred, C)
return X_update, None, None, None
def _update_neighbor_transform(
self, X, C, node_h, edge_h, edge_idx, mask_i, mask_ij
):
"""Update residue frames via weighted average transformation from neighbors."""
# Predict relative transformations from neighbors to self
R_ji = geometry.rotations_from_quaternions(
self.W_q(edge_h), normalize=True, eps=self.distance_eps
)
t_ji = self.distance_scale * self.W_t(edge_h)
logit_ij = self.W_w(edge_h)
# Compute predicted self locations from each neighbor
R_i, t_i, _ = self.frame_builder.inverse(X, C)
if self.black_hole_init:
R_i, t_i = self._init_black_hole(X)
R_i, t_i = transforms.equilibrate_transforms(
R_i,
t_i,
R_ji,
t_ji,
logit_ij,
mask_ij,
edge_idx,
iterations=self.iterations,
)
X_update = self.frame_builder(R_i, t_i, C)
return X_update, R_ji, t_ji, None
def _update_neighbor_global_transform(
self, X, C, node_h, edge_h, edge_idx, mask_i, mask_ij
):
"""Update residue frames via weighted average transformation from neighbors."""
# Predict relative transformations from neighbors to self
R_ji = geometry.rotations_from_quaternions(
self.W_q(edge_h), normalize=True, eps=self.distance_eps
)
t_ji = self.distance_scale * self.W_t(edge_h)
logit_ji = self.W_w(edge_h)
# Predict relative transformations to global frame
R_global_i = geometry.rotations_from_quaternions(
self.W_q_global(node_h), normalize=True, eps=self.distance_eps
)
t_global_i = self.distance_scale * self.W_t_global(node_h)
logit_global_i = self.W_w_global(node_h)
# Initialize global frame equivariantly
R_i, t_i, _ = self.frame_builder.inverse(X, C)
if self.black_hole_init:
R_i, t_i = self._init_black_hole(X)
R_global, t_global = transforms.average_transforms(
R_i, t_i, mask_i[..., None], mask_i, dim=1, dither_eps=0.0
)
# Compute predicted self locations from each neighbor
R_i, t_i = transforms.equilibrate_transforms(
R_i,
t_i,
R_ji,
t_ji,
logit_ji,
mask_ij,
edge_idx,
iterations=self.iterations,
R_global=R_global,
t_global=t_global,
R_global_i=R_global_i,
t_global_i=t_global_i,
logit_global_i=logit_global_i,
)
X_update = self.frame_builder(R_i, t_i, C)
return X_update, R_ji, t_ji, logit_ji
def _update_neighbor_global_affine_transform(
self, X, C, node_h, edge_h, edge_idx, mask_i, mask_ij
):
"""Update residue frames via weighted average transformation from neighbors."""
# Compute interresidue geometries for current system
R_i_init, t_i_init, _ = self.frame_builder.inverse(X, C)
if self.black_hole_init:
R_i_init, t_i_init = self._init_black_hole(X)
R_j_init, t_j_init = transforms.collect_neighbor_transforms(
R_i_init, t_i_init, edge_idx
)
R_global, t_global = transforms.average_transforms(
R_i_init, t_i_init, mask_i[..., None], mask_i, dim=1, dither_eps=0.0
)
R_ji_init, t_ji_init = transforms.compose_inner_transforms(
R_j_init, t_j_init, R_i_init.unsqueeze(-3), t_i_init.unsqueeze(-2)
)
R_gi_init, t_gi_init = transforms.compose_inner_transforms(
R_global.unsqueeze(1), t_global.unsqueeze(1), R_i_init, t_i_init
)
q_ji_init = geometry.quaternions_from_rotations(R_ji_init)
q_gi_init = geometry.quaternions_from_rotations(R_gi_init)
# Scale factor
s_node = torch.sigmoid(self.W_s_node(node_h)[..., None]).unbind(-2)
s_edge = torch.sigmoid(self.W_s_edge(edge_h)[..., None]).unbind(-2)
d_scale = self.distance_scale
# Use edges to predict relative transformations from neighbors to self
q_ji = s_edge[0] * q_ji_init + (1.0 - s_edge[0]) * self.W_q(edge_h)
t_ji = s_edge[1] * t_ji_init + (1.0 - s_edge[1]) * d_scale * self.W_t(edge_h)
logit_ji = self.W_w(edge_h)
# Use nodes to predict relative transformations to global frame
q_gi = s_node[0] * q_gi_init + (1.0 - s_node[0]) * self.W_q_global(node_h)
t_gi = s_node[1] * t_gi_init + (1.0 - s_node[1]) * d_scale * self.W_t_global(
node_h
)
logit_gi = self.W_w_global(node_h)
R_ji = geometry.rotations_from_quaternions(
q_ji, normalize=True, eps=self.distance_eps
)
R_gi = geometry.rotations_from_quaternions(
q_gi, normalize=True, eps=self.distance_eps
)
# Compute predicted self locations from each neighbor
R_i, t_i = transforms.equilibrate_transforms(
R_i_init,
t_i_init,
R_ji,
t_ji,
logit_ji,
mask_ij,
edge_idx,
iterations=self.iterations,
R_global=R_global,
t_global=t_global,
R_global_i=R_gi,
t_global_i=t_gi,
logit_global_i=logit_gi,
)
X_update = self.frame_builder(R_i, t_i, C)
return X_update, R_ji, t_ji, logit_ji
def _inner_transforms(self, X, C, edge_idx):
R_i, t_i, _ = self.frame_builder.inverse(X, C)
R_ji, t_ji = transforms.collect_neighbor_inner_transforms(R_i, t_i, edge_idx)
return R_ji, t_ji
def _transform_loss(self, R_ij_predict, t_ij_predict, X, C, edge_idx, mask_ij):
"""Compute loss between transforms"""
R_ij, t_ij = self._inner_transforms(X, C, edge_idx)
R_ij_error = (R_ij_predict - R_ij).square().sum([-1, -2])
t_ij_error = (t_ij_predict - t_ij).square().sum([-1])
R_ij_mse = (mask_ij * R_ij_error).sum([1, 2]) / (
mask_ij.sum([1, 2]) + self._eps
)
t_ij_mse = (mask_ij * t_ij_error).sum([1, 2]) / (
mask_ij.sum([1, 2]) + self._eps
)
return R_ij_mse, t_ij_mse
def forward(
self,
X: torch.Tensor,
C: torch.LongTensor,
node_h: torch.Tensor,
edge_h: torch.Tensor,
edge_idx: torch.LongTensor,
mask_i: torch.Tensor,
mask_ij: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
methods = {
"local": self._update_local_transform,
"neighbor": self._update_neighbor_transform,
"neighbor_global": self._update_neighbor_global_transform,
"neighbor_global_affine": self._update_neighbor_global_affine_transform,
"none": lambda *args: args[0],
}
method = methods[self.method]
# Update frames with ideal geometry
X_update, R_ji, t_ji, logit_ji = method(
X, C, node_h, edge_h, edge_idx, mask_i, mask_ij
)
if self.unconstrained:
# Predict atomic updates as delta from ideal geometry
# R_i, t_i, _ = self.frame_builder.inverse(X, C) # NOTE: Old models did this which was a typo
R_i, t_i, _ = self.frame_builder.inverse(X_update, C)
t_local = self.W_t_local(node_h).reshape(X.shape)
# Rotate atomic updates into local frame
R_i = R_i.unsqueeze(-3)
t_i = torch.zeros_like(t_i).unsqueeze(-2)
dX = transforms.compose_translation(R_i, t_i, t_local)
if self.training:
                # Randomly drop the atomic deltas at train time so the model
                # also learns to predict pure ideal-geometry outputs
mask_drop = (
torch.randint(
low=0, high=2, size=[C.shape[0], 1, 1, 1], device=X.device
)
> 0
).float()
dX = mask_drop * dX
X_update = X_update + dX
return X_update, R_ji, t_ji, logit_ji
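# Sketch with placeholder features (all tensors below are made up; in practice
# node_h, edge_h, and edge_idx come from an upstream graph encoder):
#
#     update = GraphBackboneUpdate(dim_nodes=128, dim_edges=64)
#     B, L, K = 1, 20, 8
#     X = torch.randn(B, L, 4, 3)                    # placeholder backbone
#     C = torch.ones(B, L, dtype=torch.long)
#     node_h, edge_h = torch.randn(B, L, 128), torch.randn(B, L, K, 64)
#     edge_idx = torch.randint(0, L, (B, L, K))
#     mask_i, mask_ij = torch.ones(B, L), torch.ones(B, L, K)
#     X_new, R_ji, t_ji, logit_ji = update(
#         X, C, node_h, edge_h, edge_idx, mask_i, mask_ij
#     )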
class LossBackboneResidueDistance(nn.Module):
"""Compute losses for training denoising diffusion models.
Inputs:
X_mobile (torch.Tensor): Mobile coordinates with shape
`(num_source, num_atoms, 4, 3)`.
X_target (torch.Tensor): Target coordinates with shape
`(num_target, num_atoms, 4, 3)`.
C (torch.Tensor): Chain map with shape `(num_batch, num_residues)`.
Outputs:
D_error (Tensor, optional): Per-site average distance errors with shape
`(num_batch)`.
"""
def __init__(self, dist_eps: float = 1e-3):
super(LossBackboneResidueDistance, self).__init__()
self.dist_eps = dist_eps
def _D(self, X):
"""Compute distance matrix between center of mass"""
X_mean = X.mean(2)
D = (
(X_mean[:, :, None, :] - X_mean[:, None, :, :])
.square()
.sum(-1)
.add(self.dist_eps)
.sqrt()
)
return D
def forward(
self, X_mobile: torch.Tensor, X_target: torch.Tensor, C: torch.LongTensor
) -> torch.Tensor:
mask = (C > 0).float()
mask_2D = mask[:, :, None] * mask[:, None, :]
D_error = (self._D(X_mobile) - self._D(X_target)).square()
D_error = (mask_2D * D_error).sum(-1) / (mask_2D.sum(-1) + self.dist_eps)
return D_error
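# Sketch: scoring a perturbed backbone against a reference. The output is a
# masked per-residue mean of squared distance-matrix errors, which can be
# reduced to a scalar training loss.
#
#     loss_fn = LossBackboneResidueDistance()
#     X = ProteinBackbone(num_residues=20)()         # (1, 20, 4, 3)
#     C = torch.ones(1, 20, dtype=torch.long)
#     D_error = loss_fn(X + 0.1 * torch.randn_like(X), X, C)
#     loss = D_error.mean()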
def center_X(X: torch.Tensor, C: torch.LongTensor) -> torch.Tensor:
"""Center each protein system at the origin.
Args:
X (torch.Tensor): Backbone coordinates with shape
`(batch_size, num_residues, num_atoms, 3)`.
C (torch.LongTensor): Chain map with shape
`(num_batch, num_residues)`.
Returns:
X_centered (torch.Tensor): Centered backbone coordinates with shape
`(batch_size, num_residues, num_atoms, 3)`.
"""
mask_expand = (
(C > 0).float().reshape(list(C.shape) + [1, 1]).expand([-1, -1, 4, -1])
)
X_mean = (mask_expand * X).sum([1, 2], keepdims=True) / (
mask_expand.sum([1, 2], keepdims=True)
)
X_centered = mask_expand * (X - X_mean)
return X_centered
def atomic_mean(
X_flat: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute the mean across all 4 atom types.
Args:
        X_flat (torch.Tensor): Flattened backbone coordinates with shape
            `(batch_size, num_residues * num_atoms, 3)`.
        mask (torch.Tensor): Mask with shape `(num_batch, num_residues)`.
    Returns:
        X_mean (torch.Tensor): System centers with shape `(batch_size, 1, 3)`.
        mask_atomic (torch.Tensor): Atomic mask with shape
            `(batch_size, num_residues * num_atoms, 1)`.
"""
mask_expand = mask.unsqueeze(-1).expand(-1, -1, 4)
mask_atomic = mask_expand.reshape(mask.shape[0], -1).unsqueeze(-1)
X_mean = torch.sum(mask_atomic * X_flat, 1, keepdims=True) / (
torch.sum(mask_atomic, 1, keepdims=True)
)
return X_mean, mask_atomic
def scale_around_mean(
    X: torch.Tensor, C: torch.LongTensor, scale: torch.Tensor
) -> torch.Tensor:
"""Scale coordinates around mean.
Args:
X (torch.Tensor): Backbone coordinates with shape
`(batch_size, num_residues, num_atoms, 3)`.
C (LongTensor): Chain map with shape
`(num_batch, num_residues)`.
        scale (torch.Tensor): Per-batch scale factors with shape
            `(num_batch,)` by which to rescale the coordinates.
Returns:
X_scaled (torch.Tensor): Backbone coordinates with shape
`(batch_size, num_residues, num_atoms, 3)`.
"""
num_atoms = X.size(2)
mask_expand = (
(C > 0).float().reshape(list(C.shape) + [1, 1]).expand([-1, -1, num_atoms, -1])
)
X_mean = (mask_expand * X).sum([1, 2], keepdims=True) / (
mask_expand.sum([1, 2], keepdims=True)
)
X_rescale = mask_expand * (scale[:, None, None, None] * (X - X_mean) + X_mean)
return X_rescale
def impute_masked_X(X: torch.Tensor, C: torch.LongTensor) -> torch.Tensor:
"""Impute missing structure data to enforce chain contiguity.
The posterior mean under a Brownian bridge is simply either the
nearest unclamped state or a linear interpolant between the two
nearest clamped endpoints along the chain.
Args:
X (torch.Tensor): Backbone coordinates with shape
`(batch_size, num_residues, num_atoms, 3)`.
C (LongTensor): Chain map with shape
`(num_batch, num_residues)`.
Returns:
X (torch.Tensor): Imputed backbone coordinates with shape
`(batch_size, num_residues, num_atoms, 3)`.
"""
X_flat = X.reshape(X.shape[0], -1, 3)
mask = (C > 0).type(torch.float32)
X_mean, mask_atomic = atomic_mean(X_flat, mask)
# Expand chain map into atomic level masking
C_expand = C.unsqueeze(-1).expand(-1, -1, 4)
C_atomic = C_expand.reshape(C.shape[0], -1)
# Find nearest unmasked positions to the left and right
ix = torch.arange(C_atomic.shape[1], device=X.device).reshape([1, -1])
mask_atomic_extend = mask_atomic.squeeze(-1)
ix_mask = mask_atomic_extend * ix - (1 - mask_atomic_extend)
ix_left, _ = torch.cummax(ix_mask, dim=1)
ix_flip = torch.flip(
mask_atomic_extend * ix_mask + (1 - mask_atomic_extend) * C_atomic.shape[1],
[1],
)
ix_right, _ = torch.cummin(ix_flip, dim=1)
ix_right = torch.flip(ix_right, [1])
ix_left = ix_left.long()
ix_right = ix_right.long()
clamped_left = ix_left >= 0
clamped_right = ix_right < C_atomic.shape[1]
ix_left[ix_left < 0] = 0
ix_right[ix_right == C_atomic.shape[1]] = 0
X_left = torch.gather(X_flat, 1, ix_left.unsqueeze(-1).expand(-1, -1, 3))
X_right = torch.gather(X_flat, 1, ix_right.unsqueeze(-1).expand(-1, -1, 3))
    # Enforce that adjacent residues are in the same chain
C_abs = torch.abs(C_atomic)
C_left = torch.gather(C_abs, 1, ix_left)
C_right = torch.gather(C_abs, 1, ix_right)
clamped_left = clamped_left * (C_left == C_abs)
clamped_right = clamped_right * (C_right == C_abs)
# Build linear interpolant
D_left = torch.abs(ix - ix_left)
D_right = torch.abs(ix_right - ix)
interp_theta = (D_right / (D_left + D_right + 1e-5)).unsqueeze(-1)
X_interp = interp_theta * X_left + (1 - interp_theta) * X_right
clamped_left = clamped_left.unsqueeze(-1)
clamped_right = clamped_right.unsqueeze(-1)
X_imputed_flat = mask_atomic * X_flat + (1 - mask_atomic) * (
clamped_left * clamped_right * X_interp
+ clamped_right * (~clamped_left) * X_right
+ (~clamped_right) * clamped_left * X_left
)
X_imputed = X_imputed_flat.reshape(X.shape)
return X_imputed
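# Sketch: residues with nonpositive chain indices are treated as missing and
# are filled by linear interpolation between the nearest present residues of
# the same chain (or copied from the nearest present endpoint at termini).
#
#     X = torch.randn(1, 6, 4, 3)
#     C = torch.tensor([[1, 1, -1, -1, 1, 1]])       # residues 3-4 missing
#     X_imputed = impute_masked_X(X, C)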
def expand_chain_map(C: torch.LongTensor) -> torch.Tensor:
"""Expand an integer chain map into a binary chain mask.
Args:
C (LongTensor): Chain map with shape
`(num_batch, num_residues)`.
Returns:
        mask_C (torch.Tensor): Expanded binary chain map with shape
            `(num_batch, num_residues, num_chains)`.
"""
    # Count chains from the maximum chain index
num_batch, num_residues = list(C.shape)
num_chains = int(torch.max(C).item())
# Build a position == chain expanded mask (B,L,C)
C_expand = C.unsqueeze(-1).expand(-1, -1, num_chains)
idx = torch.arange(num_chains, device=C.device) + 1
idx_expand = idx.view(1, 1, -1)
mask_C = (idx_expand == C_expand).type(torch.float32)
return mask_C
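# Sketch: a two-chain map expands into per-chain indicator columns.
#
#     C = torch.tensor([[1, 1, 2, 2]])
#     mask_C = expand_chain_map(C)                   # (1, 4, 2)
#     # mask_C[0, :, 0] == [1., 1., 0., 0.]          # chain 1
#     # mask_C[0, :, 1] == [0., 0., 1., 1.]          # chain 2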