File size: 9,227 Bytes
ce7bf5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Layers for annotating hydrogen bonds in protein structures.
"""

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from chroma.layers.graph import collect_neighbors
from chroma.layers.structure import protein_graph
from chroma.layers.structure.geometry import normed_vec


class BackboneHBonds(nn.Module):
    """Compute hydrogen bonds from protein backbones.

    We use the simple electrostatic model for calling hydrogen
    bonds of DSSP, which is described at
    https://en.wikipedia.org/wiki/DSSP_(algorithm). After
    placing virtual hydrogens on all backbone nitrogens,
    we consider potential hydrogen bonds with carbonyl groups
    on the backbone with residue distance |i-j| > 2. The
    picture is:

       -0.20e    +0.20e    -0.42e    +0.42e
        [N_i]-----[H_i] ::: [O_j]=====[C_j]

    Args:
        cutoff_energy (float, optional): Cutoff energy with
            default value -0.5 (DSSP).
        cutoff_distance (float, optional): Max distance
            between `N_i` and `O_j` with default value 3.6 angstroms.
        cutoff_gap (float, optional): Minimum tolerated residue
            distance for pairs on the same chain, i.e.
            `|i-j| >= cutoff_gap`. Default value of 3.
        distance_eps (float, optional): Small constant added to squared
            distances before taking the square root, for numerical
            stability. Default value of 1e-3.

    Inputs:
        X (Tensor): Backbone coordinates with shape
            `(num_batch, num_residues, num_atom_types, 3)`. Atom types
            0-3 are read as `N, CA, C, O`; any further atom types are
            carried through the neighbor gather but otherwise unused.
        C (LongTensor): Chain map tensor with shape `(num_batch, num_residues)`.
        edge_idx (LongTensor): Edge indices for neighbors with shape
            `(num_batch, num_residues, num_neighbors)`.
        mask_ij (Tensor): Edge mask with shape
             `(num_batch, num_nodes, num_neighbors)`.

    Outputs:
        hbonds (Tensor): Binary matrix annotating backbone hydrogen bonds
            with shape `(num_batch, num_nodes, num_neighbors)`.
        mask_hb_ij (Tensor): Hydrogen bond mask with shape
             `(num_batch, num_nodes, num_neighbors)`.
        H_i (Tensor): Virtual hydrogen coordinates with shape
            `(num_batch, num_nodes, 1, 3)` (a broadcasting dimension is
            inserted before it is returned).
    """

    def __init__(
        self,
        cutoff_energy: float = -0.5,
        cutoff_distance: float = 3.6,
        cutoff_gap: float = 3,
        distance_eps: float = 1e-3,
    ) -> None:
        super(BackboneHBonds, self).__init__()
        self.cutoff_energy = cutoff_energy
        self.cutoff_distance = cutoff_distance
        self.cutoff_gap = cutoff_gap
        # DSSP electrostatic prefactor q1 * q2 * f, with partial charges
        # 0.42e / 0.20e (see the picture in the class docstring) and f = 332.
        self._coefficient = 0.42 * 0.2 * 332
        self._eps = distance_eps

        # Lishan Yao et al. JACS 2008, NMR data
        self._length_NH = 1.015

    def forward(
        self,
        X: torch.Tensor,
        C: torch.LongTensor,
        edge_idx: torch.LongTensor,
        mask_ij: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        num_batch, num_residues, num_atom_types, _ = X.shape
        # Collect coordinates at i and j. Regroup the flattened gather with the
        # real atom count so inputs with more than four atoms per residue are
        # split back into per-atom coordinates correctly.
        X_flat = X.reshape([num_batch, num_residues, -1])
        X_j_flat = collect_neighbors(X_flat, edge_idx)
        X_j = X_j_flat.reshape([num_batch, num_residues, -1, num_atom_types, 3])

        # Get amide [N-H] atoms at i by
        # by placing virtual H from C_{i-1}-N-Ca neg bisector.
        # Replicate padding makes residue 0 use its own C atom as "C_{i-1}".
        X_prev = F.pad(X, [0, 0, 0, 0, 1, 0], mode="replicate")[:, :-1, :, :]
        C_prev_i = X_prev[:, :, 2, :]
        N_i = X[:, :, 0, :]
        Ca_i = X[:, :, 1, :]
        u_CprevN_i = normed_vec(N_i - C_prev_i)
        u_CaN_i = normed_vec(N_i - Ca_i)
        u_NH_i = normed_vec(u_CprevN_i + u_CaN_i)
        H_i = N_i + self._length_NH * u_NH_i
        # Add broadcasting dimensions
        N_i = N_i[:, :, None, :]
        H_i = H_i[:, :, None, :]

        # Get carbonyl [C=O] atoms at j
        O_j = X_j[:, :, :, 3, :]
        C_j = X_j[:, :, :, 2, :]

        # Inverse distance with eps added to the squared distance for stability.
        _invD = (
            lambda Xi, Xj: (Xi - Xj).square().sum(-1).add(self._eps).sqrt().reciprocal()
        )
        # DSSP energy: f*q1*q2 * (1/r_ON + 1/r_CH - 1/r_OH - 1/r_CN)
        U_ij = self._coefficient * (
            _invD(N_i, O_j) - _invD(N_i, C_j) + _invD(H_i, C_j) - _invD(H_i, O_j)
        )

        # Mask any bonds exceeding donor/acceptor cutoff distance
        D_nonhydrogen = (N_i - O_j).square().sum(-1).add(self._eps).sqrt()
        mask_ij_cutoff_D = (D_nonhydrogen < self.cutoff_distance).float()

        # Mask hbonds on same chain with |i-j| < gap_cutoff
        mask_ij_nonlocal = 1.0 - _locality_mask(C, edge_idx, cutoff=self.cutoff_gap)

        # Ignore N terminal hydrogen bonding because of ambiguous hydrogen
        # placement: the first residue of a chain has no preceding C atom on
        # the same chain. Shift C right by one so that C_prev[:, i] holds
        # C[:, i - 1] (and 0 at i == 0); slicing with [:, 1:] would instead
        # return C itself and never mask anything.
        C_prev = F.pad(C, [1, 0], "constant")[:, :-1]
        mask_i = ((C > 0) * (C == C_prev)).float()
        mask_j = collect_neighbors(C[..., None], edge_idx)[..., 0]
        mask_ij_internal = mask_i[..., None] * (mask_j > 0).float()

        mask_hb_ij = mask_ij * mask_ij_nonlocal * mask_ij_cutoff_D * mask_ij_internal

        # Call hydrogen bonds
        hbonds = mask_hb_ij * (U_ij < self.cutoff_energy).float()
        return hbonds, mask_hb_ij, H_i


class LossBackboneHBonds(nn.Module):
    """Score hydrogen bond recovery from protein backbones.

    Hydrogen bonds are called on both the scored structure `X` and the
    reference `X_target` over one shared neighbor graph built from
    `X_target`. The reference bonds are then split into local and nonlocal
    sets by same-chain residue separation.

    Args:
        cutoff_local (float, optional): Same-chain residue separation below
            which a reference hydrogen bond counts as local. Default 8.
        cutoff_energy (float, optional): Forwarded to `BackboneHBonds`.
        cutoff_distance (float, optional): Forwarded to `BackboneHBonds`.
        cutoff_gap (float, optional): Forwarded to `BackboneHBonds`.
        distance_eps (float, optional): Forwarded to `BackboneHBonds`.
        num_neighbors (int, optional): Passed to `ProteinGraph` when
            building the neighbor graph. Default 30.

    Inputs:
        X (Tensor): Backbone coordinates to score with shape
            `(num_batch, num_residues, 4, 3)`.
        X_target (Tensor): Reference coordinates to compare to with shape
            `(num_batch, num_residues, 4, 3)`.
        C (LongTensor): Chain map tensor with shape `(num_batch, num_residues)`.

    Outputs:
        recovery_local (Tensor): Local hydrogen bond recovery with shape
            `(num_batch)`.
        recovery_nonlocal (Tensor): Nonlocal hydrogen bond recovery with shape
            `(num_batch)`.
        error_co (Tensor): Absolute error in terms of contact order recovery
            with shape `(num_batch)`.
    """

    def __init__(
        self,
        cutoff_local: float = 8,
        cutoff_energy: float = -0.5,
        cutoff_distance: float = 3.6,
        cutoff_gap: float = 3,
        distance_eps: float = 1e-3,
        num_neighbors: int = 30,
    ) -> None:
        super(LossBackboneHBonds, self).__init__()
        self.cutoff_local = cutoff_local
        self.cutoff_energy = cutoff_energy
        self.cutoff_distance = cutoff_distance
        self.cutoff_gap = cutoff_gap
        self.distance_eps = distance_eps
        # Guards the recovery ratios against division by zero when a complex
        # has no reference hydrogen bonds in a category.
        self._eps = 1e-3

        self.graph_builder = protein_graph.ProteinGraph(num_neighbors=num_neighbors)
        self.hbonds = BackboneHBonds(
            cutoff_energy=cutoff_energy,
            cutoff_distance=cutoff_distance,
            cutoff_gap=cutoff_gap,
            distance_eps=distance_eps,
        )

    def forward(
        self, X: torch.Tensor, X_target: torch.Tensor, C: torch.LongTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Build one graph from the reference so both structures are compared
        # over identical (i, j) edges.
        edge_idx, mask_ij = self.graph_builder(X_target, C)
        hb_target, _, _ = self.hbonds(X_target, C, edge_idx, mask_ij)
        hb_current, _, _ = self.hbonds(X, C, edge_idx, mask_ij)

        # Split into local and long range hbonds
        mask_local = _locality_mask(C, edge_idx, cutoff=self.cutoff_local)
        hb_target_local = mask_local * hb_target
        hb_target_nonlocal = (1 - mask_local) * hb_target

        # Compute per complex: fraction of reference bonds also called in X
        recovery_local = (hb_current * hb_target_local).sum([1, 2]) / (
            hb_target_local.sum([1, 2]) + self._eps
        )
        recovery_nonlocal = (hb_current * hb_target_nonlocal).sum([1, 2]) / (
            hb_target_nonlocal.sum([1, 2]) + self._eps
        )

        # Compute contact order
        co_target = _contact_order(hb_target, C, edge_idx)
        co_current = _contact_order(hb_current, C, edge_idx)

        error_co = (co_target - co_current).abs()
        return recovery_local, recovery_nonlocal, error_co


def _ij_distance(
    C: torch.LongTensor, edge_idx: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return per-edge residue separation and a same-chain indicator.

    Args:
        C: Chain map with shape `(num_batch, num_residues)`.
        edge_idx: Neighbor indices with shape
            `(num_batch, num_residues, num_neighbors)`.

    Returns:
        dij: Absolute residue-index separation `|j - i|` for every edge.
        mask_same_chain: Float tensor, 1.0 where `C[i] == C[j]`, else 0.0.
    """
    num_batch, num_residues = C.shape[0], C.shape[1]

    chain_i = C.unsqueeze(-1)
    chain_j = collect_neighbors(chain_i, edge_idx).squeeze(-1)

    # Residue positions, shaped (num_batch, num_residues, 1) so they can be
    # gathered along edges exactly like the chain labels.
    positions_i = (
        torch.arange(num_residues, device=C.device)
        .view(1, num_residues, 1)
        .expand(num_batch, num_residues, 1)
    )
    positions_j = collect_neighbors(positions_i, edge_idx).squeeze(-1)

    separation = torch.abs(positions_j - positions_i)
    same_chain = (chain_i == chain_j).float()
    return separation, same_chain


def _contact_order(
    contacts: torch.Tensor,
    C: torch.LongTensor,
    edge_idx: torch.LongTensor,
    eps: float = 1e-3,
) -> torch.Tensor:
    """Compute contact order as the mean same-chain residue separation.

    Args:
        contacts: Per-edge contact indicator (the callers in this module pass
            hydrogen-bond matrices), aligned with `edge_idx`.
        C: Chain map with shape `(num_batch, num_residues)`.
        edge_idx: Neighbor indices with shape
            `(num_batch, num_residues, num_neighbors)`.
        eps: Small constant added once to each complex's contact count to
            avoid division by zero when it has no same-chain contacts.

    Returns:
        Contact order per complex, reduced over residues and neighbors.
    """
    dij, mask_same_chain = _ij_distance(C, edge_idx)
    mask_ij = mask_same_chain * contacts
    # Add eps to the summed count, not to every edge: adding it per edge
    # inflates the denominator by eps * num_residues * num_neighbors and
    # biases the contact order toward zero on large graphs.
    CO = (mask_ij * dij).sum([1, 2]) / (mask_ij.sum([1, 2]) + eps)
    return CO


def _locality_mask(
    C: torch.LongTensor, edge_idx: torch.LongTensor, cutoff: float,
) -> torch.Tensor:
    """Flag edges that join nearby residues on the same chain.

    Args:
        C: Chain map with shape `(num_batch, num_residues)`.
        edge_idx: Neighbor indices with shape
            `(num_batch, num_residues, num_neighbors)`.
        cutoff: Residue-index separation threshold (exclusive).

    Returns:
        Float per-edge mask: 1.0 where `|i - j| < cutoff` and `C[i] == C[j]`,
        otherwise 0.0.
    """
    separation, same_chain = _ij_distance(C, edge_idx)
    is_near = (separation < cutoff).float()
    return (is_near * same_chain).float()