Hukuna's picture
Upload 221 files
ce7bf5b verified
# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers for annotating hydrogen bonds in protein structures.
"""
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from chroma.layers.graph import collect_neighbors
from chroma.layers.structure import protein_graph
from chroma.layers.structure.geometry import normed_vec
class BackboneHBonds(nn.Module):
    """Compute hydrogen bonds from protein backbones.

    We use the simple electrostatic model for calling hydrogen
    bonds of DSSP, which is described at
    https://en.wikipedia.org/wiki/DSSP_(algorithm). After
    placing virtual hydrogens on all backbone nitrogens,
    we consider potential hydrogen bonds with carbonyl groups
    on the backbone with same-chain residue distance
    `|i-j| >= cutoff_gap` (i.e. `|i-j| > 2` by default). The
    picture is:

        -0.20e     +0.20e       -0.42e     +0.42e
        [N_i]-----[H_i]  :::  [O_j]=====[C_j]

    The pair energy is
    `332 * 0.42 * 0.20 * (1/r_ON - 1/r_NC + 1/r_HC - 1/r_HO)`
    and a bond is called when it is below `cutoff_energy`.

    Args:
        cutoff_energy (float, optional): Cutoff energy with
            default value -0.5 (DSSP).
        cutoff_distance (float, optional): Max distance
            between `N_i` and `O_j` with default value 3.6 angstroms.
        cutoff_gap (float, optional): Minimum tolerated residue
            distance for pairs on the same chain, i.e. `|i-j| >= cutoff_gap`.
            Pairs on different chains are not subject to this cutoff.
            Default value of 3.
        distance_eps (float, optional): Constant added to squared distances
            before the square root, for numerical stability.
            Default value of 1e-3.

    Inputs:
        X (Tensor): Backbone coordinates with shape
            `(num_batch, num_residues, num_atom_types, 3)`. Atom types are
            indexed as `[N, CA, C, O, ...]`; only the first four are used.
        C (LongTensor): Chain map tensor with shape `(num_batch, num_residues)`.
            Residues with `C <= 0` neither donate nor accept hydrogen bonds,
            and the first residue of each chain segment does not donate.
        edge_idx (LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_residues, num_neighbors)`.
        mask_ij (Tensor): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Outputs:
        hbonds (Tensor): Binary matrix annotating backbone hydrogen bonds
            with shape `(num_batch, num_nodes, num_neighbors)`. Entry
            `[b, i, k]` is 1 when the N-H of residue `i` donates to the
            C=O of residue `edge_idx[b, i, k]`.
        mask_hb_ij (Tensor): Hydrogen bond mask with shape
            `(num_batch, num_nodes, num_neighbors)`.
        H_i (Tensor): Virtual hydrogen coordinates with shape
            `(num_batch, num_nodes, 1, 3)` (the singleton neighbor axis is
            kept for broadcasting).
    """

    def __init__(
        self,
        cutoff_energy: float = -0.5,
        cutoff_distance: float = 3.6,
        cutoff_gap: float = 3,
        distance_eps: float = 1e-3,
    ) -> None:
        super(BackboneHBonds, self).__init__()
        self.cutoff_energy = cutoff_energy
        self.cutoff_distance = cutoff_distance
        self.cutoff_gap = cutoff_gap
        # DSSP prefactor: partial charges 0.42e and 0.20e times the factor 332
        self._coefficient = 0.42 * 0.2 * 332
        self._eps = distance_eps
        # N-H bond length; Lishan Yao et al. JACS 2008, NMR data
        self._length_NH = 1.015

    def forward(
        self,
        X: torch.Tensor,
        C: torch.LongTensor,
        edge_idx: torch.LongTensor,
        mask_ij: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        num_batch, num_residues, num_atoms, _ = X.shape

        def _dist(X_a: torch.Tensor, X_b: torch.Tensor) -> torch.Tensor:
            # eps avoids sqrt(0) and its unbounded gradient
            return (X_a - X_b).square().sum(-1).add(self._eps).sqrt()

        # Collect coordinates at neighbors j; use the real atom count rather
        # than assuming exactly 4 backbone atoms per residue
        X_flat = X.reshape([num_batch, num_residues, -1])
        X_j_flat = collect_neighbors(X_flat, edge_idx)
        X_j = X_j_flat.reshape([num_batch, num_residues, -1, num_atoms, 3])

        # Get amide [N-H] atoms at i by placing a virtual H along the
        # negative bisector of C_{i-1}-N-Ca
        X_prev = F.pad(X, [0, 0, 0, 0, 1, 0], mode="replicate")[:, :-1, :, :]
        C_prev_i = X_prev[:, :, 2, :]
        N_i = X[:, :, 0, :]
        Ca_i = X[:, :, 1, :]
        u_CprevN_i = normed_vec(N_i - C_prev_i)
        u_CaN_i = normed_vec(N_i - Ca_i)
        u_NH_i = normed_vec(u_CprevN_i + u_CaN_i)
        H_i = N_i + self._length_NH * u_NH_i

        # Add broadcasting dimension over neighbors
        N_i = N_i[:, :, None, :]
        H_i = H_i[:, :, None, :]

        # Get carbonyl [C=O] atoms at j
        C_j = X_j[:, :, :, 2, :]
        O_j = X_j[:, :, :, 3, :]

        # DSSP electrostatic energy; the N-O distance is reused for the cutoff
        D_NO = _dist(N_i, O_j)
        U_ij = self._coefficient * (
            D_NO.reciprocal()
            - _dist(N_i, C_j).reciprocal()
            + _dist(H_i, C_j).reciprocal()
            - _dist(H_i, O_j).reciprocal()
        )

        # Mask any bonds exceeding donor/acceptor cutoff distance
        mask_ij_cutoff_D = (D_NO < self.cutoff_distance).float()

        # Mask hbonds on same chain with |i-j| < cutoff_gap
        mask_ij_nonlocal = 1.0 - _locality_mask(C, edge_idx, cutoff=self.cutoff_gap)

        # Ignore hydrogen bonds donated by the first residue of each chain
        # segment: its virtual H was placed from a C_{i-1} that is either
        # replicated padding or belongs to a different chain. C_prev[i] must
        # be C[i-1] (pad on the left, drop the last entry); slicing `[:, 1:]`
        # would return C itself and disable this mask.
        C_prev = F.pad(C, [1, 0], "constant")[:, :-1]
        mask_i = ((C > 0) * (C == C_prev)).float()
        mask_j = collect_neighbors(C[..., None], edge_idx)[..., 0]
        mask_ij_internal = mask_i[..., None] * (mask_j > 0).float()

        mask_hb_ij = mask_ij * mask_ij_nonlocal * mask_ij_cutoff_D * mask_ij_internal

        # Call hydrogen bonds
        hbonds = mask_hb_ij * (U_ij < self.cutoff_energy).float()
        return hbonds, mask_hb_ij, H_i
class LossBackboneHBonds(nn.Module):
    """Score hydrogen bond recovery from protein backbones.

    A neighbor graph is built from `X_target`. Backbone hydrogen bonds are
    called with `BackboneHBonds` on that same graph for both `X_target` and
    `X`. Target hydrogen bonds are split into local (same chain,
    `|i-j| < cutoff_local`) and nonlocal sets, and the fraction of each set
    that is also called in `X` is reported.

    Args:
        cutoff_local (float, optional): Same-chain residue distance below
            which a hydrogen bond counts as local. Default value of 8.
        cutoff_energy (float, optional): See `BackboneHBonds`.
        cutoff_distance (float, optional): See `BackboneHBonds`.
        cutoff_gap (float, optional): See `BackboneHBonds`.
        distance_eps (float, optional): See `BackboneHBonds`.
        num_neighbors (int, optional): Passed to `protein_graph.ProteinGraph`
            as `num_neighbors`. Default value of 30.

    Inputs:
        X (Tensor): Backbone coordinates to score with shape
            `(num_batch, num_residues, 4, 3)`.
        X_target (Tensor): Reference coordinates to compare to with shape
            `(num_batch, num_residues, 4, 3)`.
        C (LongTensor): Chain map tensor with shape `(num_batch, num_residues)`.

    Outputs:
        recovery_local (Tensor): Local hydrogen bond recovery with shape
            `(num_batch)`.
        recovery_nonlocal (Tensor): Nonlocal hydrogen bond recovery with shape
            `(num_batch)`.
        error_co (Tensor): Absolute difference in hydrogen bond contact order
            (mean same-chain sequence separation of called bonds) between `X`
            and `X_target`, with shape `(num_batch)`.
    """

    def __init__(
        self,
        cutoff_local: float = 8,
        cutoff_energy: float = -0.5,
        cutoff_distance: float = 3.6,
        cutoff_gap: float = 3,
        distance_eps: float = 1e-3,
        num_neighbors: int = 30,
    ) -> None:
        super(LossBackboneHBonds, self).__init__()
        self.cutoff_local = cutoff_local
        self.cutoff_energy = cutoff_energy
        self.cutoff_distance = cutoff_distance
        self.cutoff_gap = cutoff_gap
        # Guards the recovery ratios when a complex has no target hbonds of a
        # given class; independent of `distance_eps`
        self._eps = 1e-3
        self.graph_builder = protein_graph.ProteinGraph(num_neighbors=num_neighbors)
        # Forward distance_eps (previously accepted but silently ignored)
        self.hbonds = BackboneHBonds(
            cutoff_energy=cutoff_energy,
            cutoff_distance=cutoff_distance,
            cutoff_gap=cutoff_gap,
            distance_eps=distance_eps,
        )

    def forward(
        self, X: torch.Tensor, X_target: torch.Tensor, C: torch.LongTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Build a single graph from the target so both structures are scored
        # over the same candidate pairs
        edge_idx, mask_ij = self.graph_builder(X_target, C)
        hb_target, _, _ = self.hbonds(X_target, C, edge_idx, mask_ij)
        hb_current, _, _ = self.hbonds(X, C, edge_idx, mask_ij)

        # Split into local and long range hbonds
        mask_local = _locality_mask(C, edge_idx, cutoff=self.cutoff_local)
        hb_target_local = mask_local * hb_target
        hb_target_nonlocal = (1 - mask_local) * hb_target

        # Compute per complex
        recovery_local = (hb_current * hb_target_local).sum([1, 2]) / (
            hb_target_local.sum([1, 2]) + self._eps
        )
        recovery_nonlocal = (hb_current * hb_target_nonlocal).sum([1, 2]) / (
            hb_target_nonlocal.sum([1, 2]) + self._eps
        )

        # Compute contact order
        co_target = _contact_order(hb_target, C, edge_idx)
        co_current = _contact_order(hb_current, C, edge_idx)
        error_co = (co_target - co_current).abs()
        return recovery_local, recovery_nonlocal, error_co
def _ij_distance(
    C: torch.LongTensor, edge_idx: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the sequence separation and same-chain flag for every edge.

    Args:
        C: Chain map with shape `(num_batch, num_residues)`.
        edge_idx: Neighbor indices with shape
            `(num_batch, num_residues, num_neighbors)`.

    Returns:
        A pair `(dij, mask_same_chain)`:
        - `dij`: `|j - i|` in residue positions for each edge `(i, j)`, with
          shape `(num_batch, num_residues, num_neighbors)`.
        - `mask_same_chain`: float tensor with value 1.0 where `C[i] == C[j]`,
          same shape.
    """
    batch_size, length = C.shape[0], C.shape[1]

    # Residue positions laid out as a one-feature node tensor so they can be
    # gathered onto edges exactly like any other node attribute.
    pos_i = torch.arange(length, device=C.device).view(1, length, 1)
    pos_i = pos_i.expand(batch_size, -1, -1)
    pos_j = collect_neighbors(pos_i, edge_idx)[..., 0]
    separation = torch.abs(pos_j - pos_i)

    chain_i = C.unsqueeze(-1)
    chain_j = collect_neighbors(chain_i, edge_idx)[..., 0]
    same_chain = (chain_i == chain_j).float()

    return separation, same_chain
def _contact_order(
    contacts: torch.Tensor,
    C: torch.LongTensor,
    edge_idx: torch.LongTensor,
    eps: float = 1e-3,
) -> torch.Tensor:
    """Compute contact order as the mean sequence separation of contacts.

    Only contacts between residues on the same chain contribute.

    Args:
        contacts: Per-edge contact weights (e.g. binary hydrogen bond calls)
            with shape `(num_batch, num_residues, num_neighbors)`.
        C: Chain map with shape `(num_batch, num_residues)`.
        edge_idx: Neighbor indices with shape
            `(num_batch, num_residues, num_neighbors)`.
        eps: Added once to each complex's total contact weight to avoid
            division by zero.

    Returns:
        Contact order per complex with shape `(num_batch,)`.
    """
    dij, mask_same_chain = _ij_distance(C, edge_idx)
    mask_ij = mask_same_chain * contacts
    # Add eps after summing. Adding it per edge would inflate the denominator
    # by eps * num_residues * num_neighbors and bias the result low.
    CO = (mask_ij * dij).sum([1, 2]) / (mask_ij.sum([1, 2]) + eps)
    return CO
def _locality_mask(
    C: torch.LongTensor, edge_idx: torch.LongTensor, cutoff: float,
) -> torch.Tensor:
    """Flag edges joining residues on the same chain fewer than `cutoff` apart.

    Args:
        C: Chain map with shape `(num_batch, num_residues)`.
        edge_idx: Neighbor indices with shape
            `(num_batch, num_residues, num_neighbors)`.
        cutoff: Sequence separations strictly below this value count as local.

    Returns:
        Float mask with shape `(num_batch, num_residues, num_neighbors)`:
        1.0 for local same-chain edges, 0.0 otherwise.
    """
    separation, same_chain = _ij_distance(C, edge_idx)
    is_near = separation < cutoff
    return (is_near * same_chain).float()