# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Layers for annotating hydrogen bonds in protein structures.
"""

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from chroma.layers.graph import collect_neighbors
from chroma.layers.structure import protein_graph
from chroma.layers.structure.geometry import normed_vec


class BackboneHBonds(nn.Module):
    """Compute hydrogen bonds from protein backbones.

    We use the simple electrostatic model for calling hydrogen bonds of DSSP,
    which is described at https://en.wikipedia.org/wiki/DSSP_(algorithm). After
    placing virtual hydrogens on all backbone nitrogens, we consider potential
    hydrogen bonds with carbonyl groups on the backbone with residue distance
    |i-j| > 2. The picture is:

          -0.20e    +0.20e        -0.42e    +0.42e
           [N_i]-----[H_i]   :::   [O_j]=====[C_j]

    The interaction energy evaluated for every edge `(i, j)` is

        U_ij = 0.42 * 0.20 * 332 * (1/r_NO - 1/r_NC + 1/r_HC - 1/r_HO)

    where every distance is smoothed as `sqrt(|x_a - x_b|^2 + distance_eps)`.

    Args:
        cutoff_energy (float, optional): Cutoff energy with default value -0.5
            (DSSP). An edge is called a hydrogen bond when `U_ij` is strictly
            below this value.
        cutoff_distance (float, optional): Max distance between `N_i` and `O_j`
            with default value 3.6 angstroms.
        cutoff_gap (float, optional): Minimum tolerated residue distance, i.e.
            `|i-j| >= cutoff_gap`, for pairs on the same chain. Default value
            of 3.
        distance_eps (float, optional): Small constant added to squared
            distances before the square root. Default value of 1e-3.

    Inputs:
        X (Tensor): Backbone coordinates with shape
            `(num_batch, num_residues, num_atom_types, 3)`. The atom axis is
            indexed as N=0, CA=1, C=2, O=3 and is reshaped assuming exactly
            4 atom types.
        C (LongTensor): Chain map tensor with shape `(num_batch, num_residues)`.
            Only residues with `C > 0` can take part in a hydrogen bond.
        edge_idx (LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_residues, num_neighbors)`.
        mask_ij (Tensor): Edge mask with shape
            `(num_batch, num_nodes, num_neighbors)`.

    Outputs:
        hbonds (Tensor): Binary matrix annotating backbone hydrogen bonds with
            shape `(num_batch, num_nodes, num_neighbors)`.
        mask_hb_ij (Tensor): Hydrogen bond mask with shape
            `(num_batch, num_nodes, num_neighbors)`.
        H_i (Tensor): Virtual hydrogen coordinates, returned with the
            broadcasting axis used internally, i.e. with shape
            `(num_batch, num_nodes, 1, 3)`.
    """

    def __init__(
        self,
        cutoff_energy: float = -0.5,
        cutoff_distance: float = 3.6,
        cutoff_gap: float = 3,
        distance_eps: float = 1e-3,
    ) -> None:
        super().__init__()
        self.cutoff_energy = cutoff_energy
        self.cutoff_distance = cutoff_distance
        self.cutoff_gap = cutoff_gap

        # Product of the two partial charges and the electrostatic constant
        # of the DSSP energy model.
        self._coefficient = 0.42 * 0.2 * 332
        self._eps = distance_eps

        # Lishan Yao et al. JACS 2008, NMR data
        self._length_NH = 1.015

    def _distance(self, X_a: torch.Tensor, X_b: torch.Tensor) -> torch.Tensor:
        """Return the eps-smoothed Euclidean distance over the last axis."""
        return (X_a - X_b).square().sum(-1).add(self._eps).sqrt()

    def forward(
        self,
        X: torch.Tensor,
        C: torch.LongTensor,
        edge_idx: torch.LongTensor,
        mask_ij: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        num_batch, num_residues, _, _ = X.shape

        # Collect coordinates at i and j
        X_flat = X.reshape([num_batch, num_residues, -1])
        X_j_flat = collect_neighbors(X_flat, edge_idx)
        X_j = X_j_flat.reshape([num_batch, num_residues, -1, 4, 3])

        # Get amide [N-H] atoms at i by placing a virtual H along the
        # negative bisector of C_{i-1}-N-Ca. Residue 0 has no predecessor, so
        # its own coordinates are replicated there; that placement is not
        # meaningful and is excluded by `mask_i` below.
        X_prev = F.pad(X, [0, 0, 0, 0, 1, 0], mode="replicate")[:, :-1, :, :]
        C_prev_i = X_prev[:, :, 2, :]
        N_i = X[:, :, 0, :]
        Ca_i = X[:, :, 1, :]
        u_CprevN_i = normed_vec(N_i - C_prev_i)
        u_CaN_i = normed_vec(N_i - Ca_i)
        u_NH_i = normed_vec(u_CprevN_i + u_CaN_i)
        H_i = N_i + self._length_NH * u_NH_i

        # Add broadcasting dimensions
        N_i = N_i[:, :, None, :]
        H_i = H_i[:, :, None, :]

        # Get carbonyl [C=O] atoms at j
        O_j = X_j[:, :, :, 3, :]
        C_j = X_j[:, :, :, 2, :]

        # Electrostatic energy between the N-H and C=O dipoles
        D_NO = self._distance(N_i, O_j)
        D_NC = self._distance(N_i, C_j)
        D_HC = self._distance(H_i, C_j)
        D_HO = self._distance(H_i, O_j)
        U_ij = self._coefficient * (
            D_NO.reciprocal()
            - D_NC.reciprocal()
            + D_HC.reciprocal()
            - D_HO.reciprocal()
        )

        # Mask any bonds exceeding donor/acceptor cutoff distance
        mask_ij_cutoff_D = (D_NO < self.cutoff_distance).float()

        # Mask hbonds on same chain with |i-j| < gap_cutoff
        mask_ij_nonlocal = 1.0 - _locality_mask(C, edge_idx, cutoff=self.cutoff_gap)

        # Ignore N terminal hydrogen bonding because of ambiguous hydrogen
        # placement. `C_prev[:, i]` is the chain label of residue i-1 (0 for
        # the first residue), so a donor is only kept when its predecessor
        # lies on the same chain.
        C_prev = F.pad(C, [1, 0], mode="constant", value=0)[:, :-1]
        mask_i = ((C > 0) * (C == C_prev)).float()

        # Acceptors only need to be present residues
        C_at_j = collect_neighbors(C[..., None], edge_idx)[..., 0]
        mask_ij_internal = mask_i[..., None] * (C_at_j > 0).float()

        mask_hb_ij = mask_ij * mask_ij_nonlocal * mask_ij_cutoff_D * mask_ij_internal

        # Call hydrogen bonds
        hbonds = mask_hb_ij * (U_ij < self.cutoff_energy).float()
        return hbonds, mask_hb_ij, H_i


class LossBackboneHBonds(nn.Module):
    """Score hydrogen bond recovery from protein backbones.

    The neighbor graph is built once from `X_target` and reused for `X`, so
    both structures are annotated on the same set of edges.

    Args:
        cutoff_local (float, optional): Same-chain pairs with
            `|i-j| < cutoff_local` are counted as local hydrogen bonds, all
            other pairs as nonlocal. Default value of 8.
        cutoff_energy (float, optional): See `BackboneHBonds`.
        cutoff_distance (float, optional): See `BackboneHBonds`.
        cutoff_gap (float, optional): See `BackboneHBonds`.
        distance_eps (float, optional): See `BackboneHBonds`; forwarded to the
            hydrogen bond annotation layer.
        num_neighbors (int, optional): Number of neighbors requested from the
            graph builder. Default value of 30.

    Inputs:
        X (Tensor): Backbone coordinates to score with shape
            `(num_batch, num_residues, 4, 3)`.
        X_target (Tensor): Reference coordinates to compare to with shape
            `(num_batch, num_residues, 4, 3)`.
        C (LongTensor): Chain map tensor with shape
            `(num_batch, num_residues)`.

    Outputs:
        recovery_local (Tensor): Local hydrogen bond recovery with shape
            `(num_batch)`.
        recovery_nonlocal (Tensor): Nonlocal hydrogen bond recovery with shape
            `(num_batch)`.
        error_co (Tensor): Absolute error in terms of contact order recovery
            with shape `(num_batch)`.
    """

    def __init__(
        self,
        cutoff_local: float = 8,
        cutoff_energy: float = -0.5,
        cutoff_distance: float = 3.6,
        cutoff_gap: float = 3,
        distance_eps: float = 1e-3,
        num_neighbors: int = 30,
    ) -> None:
        super().__init__()
        self.cutoff_local = cutoff_local
        self.cutoff_energy = cutoff_energy
        self.cutoff_distance = cutoff_distance
        self.cutoff_gap = cutoff_gap

        # Stabilizer for the recovery ratios (not a distance epsilon)
        self._eps = 1e-3

        self.graph_builder = protein_graph.ProteinGraph(num_neighbors=num_neighbors)
        self.hbonds = BackboneHBonds(
            cutoff_energy=cutoff_energy,
            cutoff_distance=cutoff_distance,
            cutoff_gap=cutoff_gap,
            distance_eps=distance_eps,
        )

    def forward(
        self,
        X: torch.Tensor,
        X_target: torch.Tensor,
        C: torch.LongTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Build Graph
        edge_idx, mask_ij = self.graph_builder(X_target, C)

        hb_target, _, _ = self.hbonds(X_target, C, edge_idx, mask_ij)
        hb_current, _, _ = self.hbonds(X, C, edge_idx, mask_ij)

        # Split into local and long range hbonds
        mask_local = _locality_mask(C, edge_idx, cutoff=self.cutoff_local)
        hb_target_local = mask_local * hb_target
        hb_target_nonlocal = (1 - mask_local) * hb_target

        # Compute per complex: fraction of target hbonds that are also
        # present in the scored structure
        recovery_local = (hb_current * hb_target_local).sum([1, 2]) / (
            hb_target_local.sum([1, 2]) + self._eps
        )
        recovery_nonlocal = (hb_current * hb_target_nonlocal).sum([1, 2]) / (
            hb_target_nonlocal.sum([1, 2]) + self._eps
        )

        # Compute contact order
        co_target = _contact_order(hb_target, C, edge_idx)
        co_current = _contact_order(hb_current, C, edge_idx)
        error_co = (co_target - co_current).abs()
        return recovery_local, recovery_nonlocal, error_co


def _ij_distance(
    C: torch.LongTensor,
    edge_idx: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute sequence separation and same-chain mask for every edge.

    Args:
        C (LongTensor): Chain map tensor with shape
            `(num_batch, num_residues)`.
        edge_idx (LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_residues, num_neighbors)`.

    Returns:
        dij (LongTensor): Absolute index difference `|j - i|` per edge with
            shape `(num_batch, num_residues, num_neighbors)`.
        mask_same_chain (Tensor): Float mask that is 1 where `C_i == C_j` with
            shape `(num_batch, num_residues, num_neighbors)`.
    """
    C_i = C[..., None]
    C_j = collect_neighbors(C_i, edge_idx)[..., 0]

    # Residue index of every node, gathered at the neighbor positions
    ix = torch.arange(C.shape[1], device=C.device)[None, :, None].expand(
        C.shape[0], -1, -1
    )
    jx = collect_neighbors(ix, edge_idx)[..., 0]

    dij = (jx - ix).abs()
    mask_same_chain = C_i.eq(C_j).float()
    return dij, mask_same_chain


def _contact_order(
    contacts: torch.Tensor,
    C: torch.LongTensor,
    edge_idx: torch.LongTensor,
    eps: float = 1e-3,
) -> torch.Tensor:
    """Compute contact order.

    The contact order is the mean sequence separation `|i - j|` over all
    same-chain contacts of each batch member.

    Args:
        contacts (Tensor): Contact indicator per edge with shape
            `(num_batch, num_residues, num_neighbors)`.
        C (LongTensor): Chain map tensor with shape
            `(num_batch, num_residues)`.
        edge_idx (LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_residues, num_neighbors)`.
        eps (float, optional): Stabilizer added once to the contact count so
            that structures without contacts yield 0 instead of NaN.

    Returns:
        CO (Tensor): Contact order with shape `(num_batch)`.
    """
    dij, mask_same_chain = _ij_distance(C, edge_idx)
    mask_ij = mask_same_chain * contacts
    # eps is added to the summed count, not to every edge, so the average is
    # not biased by the number of (mostly empty) edges in the graph.
    CO = (mask_ij * dij).sum([1, 2]) / (mask_ij.sum([1, 2]) + eps)
    return CO


def _locality_mask(
    C: torch.LongTensor,
    edge_idx: torch.LongTensor,
    cutoff: float,
) -> torch.Tensor:
    """Mark edges joining same-chain residues with `|i - j| < cutoff`.

    Args:
        C (LongTensor): Chain map tensor with shape
            `(num_batch, num_residues)`.
        edge_idx (LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_residues, num_neighbors)`.
        cutoff (float): Exclusive upper bound on the sequence separation.

    Returns:
        mask_ij_local (Tensor): Float mask with shape
            `(num_batch, num_residues, num_neighbors)`.
    """
    dij, mask_same_chain = _ij_distance(C, edge_idx)
    mask_ij_local = ((dij < cutoff) * mask_same_chain).float()
    return mask_ij_local