# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import torch
import torch.linalg
import torch.nn as nn

from chroma.layers import graph
from chroma.layers.linalg import eig_leading
from chroma.layers.structure import geometry, protein_graph


class CrossRMSD(nn.Module):
    """Compute optimal RMSDs between two sets of structures.

    This module uses the quaternion-based approach for calculating RMSDs as
    described in `Using Quaternions to Calculate RMSD`, 2004, by Coutsias,
    Seok, and Dill. The minimal RMSD and associated rotation are computed in
    terms of the most positive eigenvalue and associated eigenvector of a
    special 4x4 matrix.

    Args:
        method (str, optional): Method for calculating the most positive
            eigenvalue. Can be `power` or `symeig`. If `symeig`, this will use
            a symmetric eigensolver (`torch.linalg.eigvalsh` in `forward`,
            `torch.linalg.eigh` in `pairedRMSD`), which is the most accurate
            method but tends to be very slow on GPU for large batches of
            RMSDs. If `power`, then use power iteration to estimate leading
            eigenvalues. Default is `power`.
        method_iter (int, optional): When the method is `power`, this argument
            sets the number of power iterations used for approximation.
            The default is 50, which has tended to produce estimates of optimal
            RMSD with sub-angstrom accuracy on test problems. Note: Convergence
            rates of power iteration can be highly variable depending on the
            system. If accuracy is important, it is recommended to compare
            outputs with `symeig`-based RMSDs.
        dither (bool, optional): If True, `pairedRMSD` adds Gaussian noise
            with standard deviation 1e-5 to the 4x4 matrix before the
            eigendecomposition. Default is True.

    Inputs:
        X_mobile (Tensor): Mobile coordinates, i.e. the structures to be
            superposed, with shape `(num_source, num_atoms, 3)`.
        X_target (Tensor): Target coordinates with shape
            `(num_target, num_atoms, 3)`.

    Outputs:
        RMSD (Tensor): RMSDs after optimal superposition for all pairs of
            source and target structures with shape `(num_source, num_target)`.
            While `forward` returns the Cartesian product of all possible
            alignments, i.e. `num_source * num_target` alignments,
            `pairedRMSD` does the same calculation for zipped batches, i.e.
            `num_source` total alignments.
    """
    """
    Attribute summary:
    method: method for computing the leading eigenvalue, either "power" or "symeig".
    method_iter: number of power iterations used when `method` is "power".
    _eps: small positive constant that guards against division by zero.
    dither: boolean that decides whether random noise is added during the computation.
    """

    def __init__(self, method="power", method_iter=50, dither=True):
        super(CrossRMSD, self).__init__()

        self.method = method
        self.method_iter = method_iter
        self._eps = 1e-5
        self.dither = dither

        # R_to_F converts xyz cross-covariance matrices (3x3) to the (4x4) F
        # matrix of Coutsias et al. This F matrix encodes the optimal RMSD in
        # its spectra; namely, the eigenvector associated with the most
        # positive eigenvalue of F is the quaternion encoding the optimal
        # 3D rotation for superposition.
        # fmt: off
        R_to_F = np.zeros((9, 16)).astype("f")
        F_nonzero = [
        [(0,0,1.),(1,1,1.),(2,2,1.)],            [(1,2,1.),(2,1,-1.)],            [(2,0,1.),(0,2,-1.)],            [(0,1,1.),(1,0,-1.)],
                [(1,2,1.),(2,1,-1.)],  [(0,0,1.),(1,1,-1.),(2,2,-1.)],             [(0,1,1.),(1,0,1.)],             [(0,2,1.),(2,0,1.)],
                [(2,0,1.),(0,2,-1.)],             [(0,1,1.),(1,0,1.)],  [(0,0,-1.),(1,1,1.),(2,2,-1.)],             [(1,2,1.),(2,1,1.)],
                [(0,1,1.),(1,0,-1.)],             [(0,2,1.),(2,0,1.)],             [(1,2,1.),(2,1,1.)],  [(0,0,-1.),(1,1,-1.),(2,2,1.)]
        ]
        # fmt: on

        for F_ij, nonzero in enumerate(F_nonzero):
            for R_i, R_j, sign in nonzero:
                R_to_F[R_i * 3 + R_j, F_ij] = sign
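        self.register_buffer("R_to_F", torch.tensor(R_to_F))
    """
    In `forward`, the coordinates are first centered and the cross-covariance
    matrix R is computed. R is then flattened and multiplied by the R_to_F
    matrix to obtain the F matrix. Finally, the leading eigenvalue of F is
    computed with the solver selected by `method` and used to evaluate the RMSD.
    """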

    def forward(self, X_mobile, X_target):
        num_source = X_mobile.size(0)
        num_target = X_target.size(0)
        num_atoms = X_mobile.size(1)

        # Center coordinates
        X_mobile = X_mobile - X_mobile.mean(dim=1, keepdim=True)
        X_target = X_target - X_target.mean(dim=1, keepdim=True)

        # CrossCov matrices contract over atoms
        R = torch.einsum("sai,taj->stij", [X_mobile, X_target])

        # F Matrix has leading eigenvector as optimal quaternion
        R_flat = R.reshape(num_source, num_target, 9)
        F = torch.matmul(R_flat, self.R_to_F).reshape(num_source, num_target, 4, 4)

        # Compute optimal quaternion by extracting leading eigenvector
        if self.method == "symeig":
            top_eig = torch.linalg.eigvalsh(F)[:, :, 3]
        elif self.method == "power":
            top_eig, vec = eig_leading(F, num_iterations=self.method_iter)
        else:
            raise NotImplementedError

        # Compute RMSD from the top eigenvalue using the scheme of Coutsias et al
        norms = (X_mobile ** 2).sum(dim=[-1, -2]).unsqueeze(1) + (X_target ** 2).sum(
            dim=[-1, -2]
        ).unsqueeze(0)
        sqRMSD = torch.relu((norms - 2 * top_eig) / (num_atoms + self._eps))
        RMSD = torch.sqrt(sqRMSD)
        return RMSD
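
    """
Illustrative note (not part of the Chroma API): the `power` option relies on
`eig_leading`, which lives in `chroma.layers.linalg` and is not shown in this
file. The NumPy sketch below shows one generic way that power iteration can
recover the most positive eigenvalue of a symmetric matrix. The helper name
`leading_eig_power` and the Frobenius-norm shift are assumptions made for this
example and may differ from what `eig_leading` does. The shift is used because
plain power iteration converges to the eigenvalue of largest magnitude, which
can be negative for a traceless matrix such as F.

```python
import numpy as np


def leading_eig_power(A, num_iterations=50, seed=0):
    # Hypothetical helper: estimate the most positive eigenvalue of a
    # symmetric matrix A together with its eigenvector.
    n = A.shape[-1]
    # The Frobenius norm bounds every |eigenvalue|, so A + shift * I has a
    # non-negative spectrum and its dominant eigenvector is the one we want.
    shift = np.linalg.norm(A)
    B = A + shift * np.eye(n)
    rng = np.random.default_rng(seed)
    v = rng.normal(size=n)
    v = v / np.linalg.norm(v)
    for _ in range(num_iterations):
        v = B @ v
        v = v / np.linalg.norm(v)
    # Rayleigh quotient with respect to the unshifted matrix.
    lam = v @ A @ v
    return lam, v


# The largest-magnitude eigenvalue is -5, but the most positive one is 3.
A = np.diag([3.0, 1.0, -5.0, 0.5])
lam, vec = leading_eig_power(A)
```

The convergence rate depends on the gap between the two largest shifted
eigenvalues, so a fixed iteration count can be inaccurate when that gap is
small.
    """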

    def pairedRMSD(
        self,
        X_mobile,
        X_target,
        mask=None,
        compute_alignment=False,
        align_unmasked=False,
    ):
        """Compute optimal RMSDs between corresponding batch members.

        Args:
            X_mobile (Tensor): Mobile coordinates with shape
                `(..., num_atoms, 3)`.
            X_target (Tensor): Target coordinates with shape
                `(..., num_atoms, 3)`.
            mask (Tensor, optional): Binary mask tensor for missing atoms with
                shape `(..., num_atoms)`.
            compute_alignment (bool, optional): If True, also return the
                superposed coordinates. Default is False.
            align_unmasked (bool, optional): If True (and `compute_alignment`
                is True), apply the fitted transformation to every atom of
                `X_mobile`, including masked ones. In this case the
                coordinates are returned with the leading batch dimensions
                flattened. Default is False.

        Returns:
            RMSD (Tensor): Optimal RMSDs after superposition for each pair of
                input structures with shape `(...)`.
            X_mobile_transform (Tensor, optional): Superposed coordinates with
                shape `(..., num_atoms, 3)`. Requires
                `compute_alignment=True`.
        """
        # Collapse all leading batch dimensions
        num_atoms = X_mobile.size(-2)
        batch_dims = list(X_mobile.shape)[:-2]
        X_mobile = X_mobile.reshape([-1, num_atoms, 3])
        X_target = X_target.reshape([-1, num_atoms, 3])
        num_batch = X_mobile.size(0)
        if mask is not None:
            mask = mask.reshape([-1, num_atoms])

        # Center coordinates
        if mask is None:
            X_mobile_mean = X_mobile.mean(dim=1, keepdim=True)
            X_target_mean = X_target.mean(dim=1, keepdim=True)
        else:
            mask_expand = mask.unsqueeze(-1)
            X_mobile_mean = torch.sum(mask_expand * X_mobile, 1, keepdim=True) / (
                torch.sum(mask_expand, 1, keepdim=True) + self._eps
            )
            X_target_mean = torch.sum(mask_expand * X_target, 1, keepdim=True) / (
                torch.sum(mask_expand, 1, keepdim=True) + self._eps
            )

        X_mobile_center = X_mobile - X_mobile_mean
        X_target_center = X_target - X_target_mean

        if mask is not None:
            X_mobile_center = mask_expand * X_mobile_center
            X_target_center = mask_expand * X_target_center

        # Cross-covariance matrices contract over atoms
        R = torch.einsum("sai,saj->sij", [X_mobile_center, X_target_center])

        # F Matrix has leading eigenvector as optimal quaternion
        R_flat = R.reshape(num_batch, 9)
        R_to_F = self.R_to_F.type(R_flat.dtype)
        F = torch.matmul(R_flat, R_to_F).reshape(num_batch, 4, 4)
        if self.dither:
            F = F + 1e-5 * torch.randn_like(F)

        # Compute optimal quaternion by extracting leading eigenvector
        if self.method == "symeig":
            L, V = torch.linalg.eigh(F)
            top_eig = L[:, 3]
            vec = V[:, :, 3]
        elif self.method == "power":
            top_eig, vec = eig_leading(F, num_iterations=self.method_iter)
        else:
            raise NotImplementedError

        # Compute RMSD using top eigenvalue
        norms = (X_mobile_center ** 2).sum(dim=[-1, -2]) + (X_target_center ** 2).sum(
            dim=[-1, -2]
        )
        sqRMSD = torch.relu((norms - 2 * top_eig) / (num_atoms + self._eps))
        rmsd = torch.sqrt(sqRMSD)

        if not compute_alignment:
            # Unpack leading batch dimensions
            rmsd = rmsd.reshape(batch_dims)
            return rmsd
        else:
            R = geometry.rotations_from_quaternions(vec, normalize=False)

            X_mobile_transform = torch.einsum("bxr,bir->bix", R, X_mobile_center)
            X_mobile_transform = X_mobile_transform + X_target_mean

            if mask is not None:
                X_mobile_transform = mask_expand * X_mobile_transform

            # Return the RMSD of the transformed coordinates
            rmsd_direct = rmsd_unaligned(X_mobile_transform, X_target, mask)

            # Unpack leading batch dimensions
            rmsd_direct = rmsd_direct.reshape(batch_dims)
            X_mobile_transform = X_mobile_transform.reshape(batch_dims + [num_atoms, 3])
            if align_unmasked:
                X_mobile_transform = X_mobile - X_mobile_mean
                X_mobile_transform = torch.einsum(
                    "bxr, bir -> bix",
                    R,
                    X_mobile_transform.view(X_mobile.size(0), -1, 3),
                )
                X_mobile_transform = X_mobile_transform + X_target_mean

            return rmsd_direct, X_mobile_transform
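
"""
Illustrative note (not part of the original module): the quaternion scheme
used by `CrossRMSD` can be sketched for a single pair of structures in NumPy.
The function name `quaternion_superpose` is hypothetical, and the
quaternion-to-rotation formula is the standard one for a unit quaternion
`(q0, q1, q2, q3)`; it is assumed, not verified, to match the convention of
`geometry.rotations_from_quaternions`. The entries of F follow the `F_nonzero`
table in `CrossRMSD.__init__`.

```python
import numpy as np


def quaternion_superpose(X_mobile, X_target):
    # Superpose X_mobile (N, 3) onto X_target (N, 3).
    # Returns the optimal RMSD and the transformed mobile coordinates.
    mu_m = X_mobile.mean(axis=0)
    mu_t = X_target.mean(axis=0)
    Xm = X_mobile - mu_m
    Xt = X_target - mu_t

    # Cross-covariance R[i, j] = sum_a Xm[a, i] * Xt[a, j]
    R = Xm.T @ Xt

    # Symmetric, traceless 4x4 matrix F of Coutsias et al.
    F = np.array([
        [R[0, 0] + R[1, 1] + R[2, 2], R[1, 2] - R[2, 1], R[2, 0] - R[0, 2], R[0, 1] - R[1, 0]],
        [R[1, 2] - R[2, 1], R[0, 0] - R[1, 1] - R[2, 2], R[0, 1] + R[1, 0], R[0, 2] + R[2, 0]],
        [R[2, 0] - R[0, 2], R[0, 1] + R[1, 0], -R[0, 0] + R[1, 1] - R[2, 2], R[1, 2] + R[2, 1]],
        [R[0, 1] - R[1, 0], R[0, 2] + R[2, 0], R[1, 2] + R[2, 1], -R[0, 0] - R[1, 1] + R[2, 2]],
    ])

    # eigh returns eigenvalues in ascending order; take the last pair.
    evals, evecs = np.linalg.eigh(F)
    lam = evals[-1]
    q0, q1, q2, q3 = evecs[:, -1]

    # Rotation matrix of the unit quaternion (its sign does not matter).
    U = np.array([
        [q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3, 2 * (q1 * q2 - q0 * q3), 2 * (q1 * q3 + q0 * q2)],
        [2 * (q1 * q2 + q0 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3, 2 * (q2 * q3 - q0 * q1)],
        [2 * (q1 * q3 - q0 * q2), 2 * (q2 * q3 + q0 * q1), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3],
    ])
    X_aligned = Xm @ U.T + mu_t

    # RMSD from the top eigenvalue, clamped at zero against round-off.
    sq = max((Xm ** 2).sum() + (Xt ** 2).sum() - 2.0 * lam, 0.0) / len(Xm)
    return np.sqrt(sq), X_aligned
```

The RMSD obtained from the eigenvalue should agree with the RMSD measured
directly between `X_aligned` and `X_target`, which mirrors the
`rmsd_unaligned` call made by `pairedRMSD` when `compute_alignment=True`.
"""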


class BackboneRMSD(nn.Module):
    """Compute optimal RMSDs between two sets of backbones.

    This wraps `CrossRMSD` for use with XCS-formatted protein data.

    Args:
        method (str, optional): Method for calculating the most positive
            eigenvalue. Can be `power` or `symeig`. Default is `symeig`.

    Inputs (to `align`):
        X_mobile (Tensor): Mobile coordinates with shape
            `(num_batch, num_residues, 4, 3)`.
        X_target (Tensor): Target coordinates with shape
            `(num_batch, num_residues, 4, 3)`.
        C (Tensor): Chain map with shape `(num_batch, num_residues)`.
            Only residues with `C > 0` contribute to the alignment.
        align_unmasked (bool, optional): If True, apply the fitted
            transformation to all atoms of `X_mobile`, including masked ones.
            Default is False.

    Outputs:
        X_aligned (Tensor): Superposed `X_mobile` with shape
            `(num_batch, num_residues, 4, 3)`.
        rmsd (Tensor): Optimal RMSDs after superposition with shape
            `(num_batch)`.
    """

    def __init__(self, method="symeig"):
        super(BackboneRMSD, self).__init__()
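        self.rmsd = CrossRMSD(method=method)
    """
    In `align`, a mask is first built from the chain map C. The mask determines
    which parts of the protein take part in the alignment.
    Next, the input coordinates X_mobile and X_target are reshaped into the
    layout expected by the RMSD computation.
    The `pairedRMSD` method of the CrossRMSD instance then computes the RMSD and
    the aligned coordinates.
    Finally, the aligned coordinates are reshaped back to the original protein
    coordinate layout and returned.
    """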
    def align(self, X_mobile, X_target, C, align_unmasked=False):
        mask = (C > 0).type(torch.float32)
        mask_flat = mask.unsqueeze(-1).expand(-1, -1, 4).reshape(mask.shape[0], -1)

        X_mobile_flat = X_mobile.reshape(X_mobile.size(0), -1, 3)
        X_target_flat = X_target.reshape(X_target.size(0), -1, 3)
        rmsd, X_aligned = self.rmsd.pairedRMSD(
            X_mobile_flat,
            X_target_flat,
            mask=mask_flat,
            compute_alignment=True,
            align_unmasked=align_unmasked,
        )
        X_aligned = X_aligned.reshape(X_mobile.size()).contiguous()
        return X_aligned, rmsd
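
"""
Illustrative note (not part of the original module): the reshaping done in
`BackboneRMSD.align` can be sketched in NumPy. The helper name
`flatten_backbone` is hypothetical. It expands a per-residue mask to one entry
per backbone atom and flattens `(residue, atom)` into a single atom axis, so
that coordinates and mask stay aligned.

```python
import numpy as np


def flatten_backbone(X, C):
    # X: (B, N, 4, 3) backbone coordinates.
    # C: (B, N) chain map; residues with C > 0 are used for alignment.
    mask = (C > 0).astype(np.float32)
    # Repeat each residue flag once per backbone atom: (B, N) -> (B, 4N).
    mask_flat = np.repeat(mask, 4, axis=1)
    # Merge the residue and atom axes: (B, N, 4, 3) -> (B, 4N, 3).
    X_flat = X.reshape(X.shape[0], -1, 3)
    return X_flat, mask_flat


X = np.arange(1 * 3 * 4 * 3, dtype=np.float32).reshape(1, 3, 4, 3)
C = np.array([[1, 2, 0]])
X_flat, mask_flat = flatten_backbone(X, C)
```

Atom `a` of residue `i` ends up at flat index `4 * i + a`, and the last
residue (with `C == 0`) contributes four zeros to the mask.
"""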


class LossFragmentRMSD(nn.Module):
    """Compute optimal fragment RMSDs between two sets of backbones.

    Args:
        k (int, optional): Fragment length in residues. Default is 7.
        method (str, optional): Method for calculating the most positive
            eigenvalue. Can be `power` or `symeig`. Default is `symeig`.
        method_iter (int, optional): Number of power iterations for eigenvalue
            approximation. Requires `method=power`. Default is 50.

    Inputs:
        X_mobile (Tensor): Mobile coordinates with shape
            `(num_batch, num_residues, num_atoms, 3)`. Only the first four
            atoms of each residue are used.
        X_target (Tensor): Target coordinates with shape
            `(num_batch, num_residues, num_atoms, 3)`.
        C (Tensor): Chain map with shape `(num_batch, num_residues)`.
        return_coords (bool, optional): If True, also return the target,
            mobile, and aligned mobile fragment coordinates. Default is False.

    Outputs:
        rmsd (Tensor): Per-site fragment RMSDs with shape
            `(num_batch, num_residues)`.
    """

    def __init__(self, k=7, method="symeig", method_iter=50):
        super(LossFragmentRMSD, self).__init__()
        self.k = k
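        self.rmsd = CrossRMSD(method=method, method_iter=method_iter)

    """
    X_mobile and X_target: coordinates of the protein to be aligned and of the
    target protein, respectively.
    C: chain map used to decide which residues are considered in the alignment.
    return_coords: boolean indicating whether the aligned coordinates are returned.

    In `forward`, the input coordinates X_mobile and X_target are first
    restricted to the backbone atoms.
    Then `_collect_X_fragments` (defined at the bottom of this module) collects
    fragments from each protein, and a mask is built from the chain map of
    each fragment.
    Finally, the `pairedRMSD` method of the CrossRMSD instance computes the RMSD
    of each fragment pair, and `return_coords` decides whether the aligned
    coordinates are returned as well.
    """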

    def forward(self, X_mobile, X_target, C, return_coords=False):
        # Discard potential sidechain coordinates
        X_mobile = X_mobile[:, :, :4, :]
        X_target = X_target[:, :, :4, :]

        # Collect local fragments of length k around each residue
        X_fragment_mobile, C_fragment_mobile = _collect_X_fragments(X_mobile, C, self.k)
        X_fragment_target, C_fragment_target = _collect_X_fragments(X_target, C, self.k)
        shape = list(C.shape) + [-1, 3]
        X_fragment_mobile = X_fragment_mobile.reshape(shape)
        X_fragment_target = X_fragment_target.reshape(shape)

        mask = (C_fragment_mobile > 0).float()
        rmsd, X_fragment_mobile_align = self.rmsd.pairedRMSD(
            X_fragment_mobile, X_fragment_target, mask, compute_alignment=True
        )
        if return_coords:
            return rmsd, X_fragment_target, X_fragment_mobile, X_fragment_mobile_align
        else:
            return rmsd


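class LossFragmentPairRMSD(nn.Module):
    """Compute optimal fragment-pair RMSDs between two sets of backbones.

    Args:
        k (int, optional): Fragment length in residues. Default is 7.
        method (str, optional): Method for calculating the most positive
            eigenvalue. Can be `power` or `symeig`. Default is `symeig`.
        method_iter (int, optional): Number of power iterations for eigenvalue
            approximation. Requires `method=power`. Default is 50.
        graph_num_neighbors (int, optional): Number of neighbors passed to
            `ProteinGraph`. Default is 30.

    Inputs:
        X_mobile (Tensor): Mobile coordinates with shape
            `(num_batch, num_residues, num_atoms, 3)`. Only the first four
            atoms of each residue are used.
        X_target (Tensor): Target coordinates with shape
            `(num_batch, num_residues, num_atoms, 3)`. The neighbor graph is
            built from these coordinates.
        C (Tensor): Chain map with shape `(num_batch, num_residues)`.
        return_coords (bool, optional): If True, also return the target,
            mobile, and aligned mobile fragment-pair coordinates. Default is
            False.

    Outputs:
        rmsd (Tensor): Per-edge fragment-pair RMSDs with shape
            `(num_batch, num_residues, num_neighbors)`.
        mask_ij (Tensor): Edge mask returned by the graph builder.
    """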

    def __init__(self, k=7, method="symeig", method_iter=50, graph_num_neighbors=30):
        super(LossFragmentPairRMSD, self).__init__()
        self.k = k
        self.rmsd = CrossRMSD(method=method, method_iter=method_iter)
        self.graph_builder = protein_graph.ProteinGraph(
            num_neighbors=graph_num_neighbors
        )

    def _stack_neighbor(self, node_h, edge_idx):
        neighbor_h = graph.collect_neighbors(node_h, edge_idx)
        node_h = node_h[:, :, None, :].expand(neighbor_h.shape)
        edge_h = torch.cat([neighbor_h, node_h], dim=-1)
        return edge_h

    def _collect_X_fragment_pairs(self, X, C, edge_idx):
        X_kmer, C_kmer = _collect_X_fragments(X, C, self.k)
        X_pair = self._stack_neighbor(X_kmer, edge_idx)
        C_pair = self._stack_neighbor(C_kmer, edge_idx)
        X_pair = X_pair.reshape(list(X_pair.shape)[:-1] + [-1, 3])
        return X_pair, C_pair

    def forward(self, X_mobile, X_target, C, return_coords=False):
        # Discard potential sidechain coordinates
        X_mobile = X_mobile[:, :, :4, :]
        X_target = X_target[:, :, :4, :]

        # Build graph and pair fragments
        edge_idx, mask_ij = self.graph_builder(X_target, C)
        X_pair_mobile, C_pair_mobile = self._collect_X_fragment_pairs(
            X_mobile, C, edge_idx
        )
        X_pair_target, C_pair_target = self._collect_X_fragment_pairs(
            X_target, C, edge_idx
        )

        mask = (C_pair_mobile > 0).float()

        rmsd, X_pair_mobile_align = self.rmsd.pairedRMSD(
            X_pair_mobile, X_pair_target, mask, compute_alignment=True
        )
        if return_coords:
            return rmsd, mask_ij, X_pair_target, X_pair_mobile, X_pair_mobile_align
        else:
            return rmsd, mask_ij
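
"""
Illustrative note (not part of the original module): the pairing step in
`_stack_neighbor` can be sketched in NumPy. The sketch assumes that
`graph.collect_neighbors(node_h, edge_idx)` gathers `node_h[b, edge_idx[b, i, k]]`
for every batch `b`, node `i`, and neighbor slot `k`; that helper is defined
elsewhere in Chroma and is not shown here. The function name `stack_neighbor`
is hypothetical.

```python
import numpy as np


def stack_neighbor(node_h, edge_idx):
    # node_h: (B, N, H) per-node features (e.g. flattened fragment coordinates).
    # edge_idx: (B, N, K) integer indices of the K neighbors of each node.
    batch_idx = np.arange(node_h.shape[0])[:, None, None]
    # Gather neighbor features: (B, N, K, H).
    neighbor_h = node_h[batch_idx, edge_idx]
    # Broadcast each node's own features across its K neighbor slots.
    center_h = np.broadcast_to(node_h[:, :, None, :], neighbor_h.shape)
    # Neighbor features first, then center features: (B, N, K, 2H).
    return np.concatenate([neighbor_h, center_h], axis=-1)


node_h = np.arange(6, dtype=np.float32).reshape(1, 3, 2)
edge_idx = np.array([[[1, 2], [0, 2], [0, 1]]])
edge_h = stack_neighbor(node_h, edge_idx)
```

Each edge feature therefore holds the neighbor fragment followed by the center
fragment, so one superposition fits both fragments jointly.
"""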


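class LossNeighborhoodRMSD(nn.Module):
    """Compute optimal neighborhood RMSDs between two sets of backbones.

    Args:
        method (str, optional): Method for calculating the most positive
            eigenvalue. Can be `power` or `symeig`. Default is `symeig`.
        method_iter (int, optional): Number of power iterations for eigenvalue
            approximation. Requires `method=power`. Default is 50.
        graph_num_neighbors (int, optional): Number of neighbors passed to
            `ProteinGraph`. Default is 30.

    Inputs:
        X_mobile (Tensor): Mobile coordinates with shape
            `(num_batch, num_residues, num_atoms, 3)`. Only the first four
            atoms of each residue are used.
        X_target (Tensor): Target coordinates with shape
            `(num_batch, num_residues, num_atoms, 3)`. The neighbor graph is
            built from these coordinates.
        C (Tensor): Chain map with shape `(num_batch, num_residues)`.
        return_coords (bool, optional): If True, also return the target,
            mobile, and aligned mobile neighborhood coordinates. Default is
            False.

    Outputs:
        rmsd (Tensor): Per-site neighborhood RMSDs with shape
            `(num_batch, num_residues)`.
        mask (Tensor): Per-site mask with shape `(num_batch, num_residues)`,
            equal to 1 where the neighborhood has at least one unmasked atom.
    """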

    def __init__(self, method="symeig", method_iter=50, graph_num_neighbors=30):
        super(LossNeighborhoodRMSD, self).__init__()
        self.rmsd = CrossRMSD(method=method, method_iter=method_iter)
        self.graph_builder = protein_graph.ProteinGraph(
            num_neighbors=graph_num_neighbors
        )

    def _collect_X_neighborhood(self, X, C, edge_idx):
        num_batch, num_nodes, num_atoms, _ = X.shape
        shape_flat = [num_batch, num_nodes, -1]
        X_flat = X.reshape(shape_flat)
        C_flat = C[..., None].expand([-1, -1, num_atoms])
        X_neighborhood = graph.collect_neighbors(X_flat, edge_idx).reshape(
            [num_batch, num_nodes, -1, 3]
        )
        C_neighborhood = graph.collect_neighbors(C_flat, edge_idx).reshape(
            [num_batch, num_nodes, -1]
        )
        return X_neighborhood, C_neighborhood

    def forward(self, X_mobile, X_target, C, return_coords=False):
        # Discard potential sidechain coordinates
        X_mobile = X_mobile[:, :, :4, :]
        X_target = X_target[:, :, :4, :]

        # Build graph and collect the neighborhood of each residue
        edge_idx, mask_ij = self.graph_builder(X_target, C)
        X_neighborhood_mobile, C_neighborhood_mobile = self._collect_X_neighborhood(
            X_mobile, C, edge_idx
        )
        X_neighborhood_target, C_neighborhood_target = self._collect_X_neighborhood(
            X_target, C, edge_idx
        )
        mask = (C_neighborhood_mobile > 0).float()

        rmsd, X_neighborhood_mobile_align = self.rmsd.pairedRMSD(
            X_neighborhood_mobile, X_neighborhood_target, mask, compute_alignment=True
        )
        mask = (mask.sum(-1) > 0).float()
        if return_coords:
            return (
                rmsd,
                mask,
                X_neighborhood_target,
                X_neighborhood_mobile,
                X_neighborhood_mobile_align,
            )
        else:
            return rmsd, mask


def rmsd_unaligned(X_a, X_b, mask=None, eps=1e-5, _min_rmsd=1e-8):
    """Compute RMSD between two coordinate sets without alignment.

    Args:
        X_a (Tensor): Coordinate set 1 with shape `(..., num_points, 3)`.
        X_b (Tensor): Coordinate set 2 with shape `(..., num_points, 3)`.
        mask (Tensor, optional): Mask with shape `(..., num_points)`.
        eps (float, optional): Small number to prevent division by zero.
            Default is 1e-5.
        _min_rmsd (float, optional): Lower bound applied to the mean (or
            masked sum of) squared deviations before the square root is
            taken. Default is 1e-8.

    Returns:
        rmsd (Tensor): Root mean squared deviations (raw) with shape `(...)`.
    """
    squared_dev = ((X_a - X_b) ** 2).sum(-1)
    if mask is None:
        rmsd = torch.sqrt(squared_dev.mean(-1).clamp(min=_min_rmsd))
    else:
        rmsd = torch.sqrt(
            (mask * squared_dev).sum(-1).clamp(min=_min_rmsd) / (mask.sum(-1) + eps)
        )
    return rmsd
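
"""
Illustrative note (not part of the original module): a small worked example of
the masked RMSD above, re-expressed in NumPy so that the arithmetic is easy to
follow. The name `rmsd_unaligned_np` is hypothetical; the logic mirrors
`rmsd_unaligned` line by line.

```python
import numpy as np


def rmsd_unaligned_np(X_a, X_b, mask=None, eps=1e-5, min_sq=1e-8):
    # Squared distance per point: (..., num_points).
    squared_dev = ((X_a - X_b) ** 2).sum(-1)
    if mask is None:
        return np.sqrt(np.maximum(squared_dev.mean(-1), min_sq))
    # Masked points contribute nothing to the sum or to the point count.
    total = np.maximum((mask * squared_dev).sum(-1), min_sq)
    return np.sqrt(total / (mask.sum(-1) + eps))


X_a = np.zeros((3, 3))
X_b = np.array([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [100.0, 0.0, 0.0]])
mask = np.array([1.0, 1.0, 0.0])
# Squared deviations are [25, 0, 10000]; the mask removes the outlier,
# leaving sqrt(25 / (2 + eps)), which is close to sqrt(12.5).
masked_rmsd = rmsd_unaligned_np(X_a, X_b, mask)
```

For identical inputs the clamp makes the unmasked result `sqrt(min_sq)` rather
than exactly zero.
"""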

"""
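These two functions are the core of the fragment handling, which extracts and
analyzes fixed-length segments of a protein structure.
`_collect_X_fragments` processes the coordinates and the chain map to collect
fragments of a given length, while `_collect_kmers` is a more general tool
that collects k-mers from any node feature matrix.

_collect_X_fragments:
The function first flattens X and C over the atom axis.
It then uses `_collect_kmers` to gather k-mers from X_flat and C_flat; these
k-mers are local fragments of length k.
Finally, `torch.where` marks atoms whose chain label differs from that of the
central residue as missing (their label is made negative), and the processed
X_kmer and C_kmer are returned.

_collect_kmers:
The main steps are:

Build indices that locate the k-mers. First, create an offset array k_idx of
length k.
Then combine these offsets with the node indices node_idx to produce the k-mer
indices kmer_idx, clamped to the valid range.
Use kmer_idx to gather the features of adjacent nodes from node_h, forming the
new k-mer feature tensor kmer_h.

The key point is that the function builds, from the original node feature
matrix, a new tensor that carries local neighbor information, which is
particularly useful for graph-based structures such as proteins.
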
"""

def _collect_X_fragments(X, C, k):
    num_batch, num_nodes, num_atoms, _ = X.shape
    shape_flat = [num_batch, num_nodes, -1]
    X_flat = X.reshape(shape_flat)
    C_flat = C[..., None].expand([-1, -1, num_atoms])

    # Grab local kmers
    X_kmer = _collect_kmers(X_flat, k).reshape(shape_flat)
    C_kmer = _collect_kmers(C_flat, k).reshape(shape_flat)

    # Treat noncontiguous atoms as missing
    C_kmer = torch.where(C[..., None].eq(C_kmer), C_kmer, -C_kmer.abs())
    return X_kmer, C_kmer
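
"""
Illustrative note (not part of the original module): the chain-boundary rule
applied by `torch.where` above can be sketched in NumPy. For brevity the
sketch uses one entry per residue instead of one per atom, and it takes the
window indices as an input; the name `mask_cross_chain` is hypothetical.

```python
import numpy as np


def mask_cross_chain(C, kmer_idx):
    # C: (B, N) chain map; kmer_idx: (N, K) residue indices in each window.
    C_kmer = C[:, kmer_idx]  # (B, N, K) chain label of every window member
    same_chain = C[..., None] == C_kmer
    # Members on a different chain than the central residue get a negative
    # label, so a later `C > 0` test treats them as missing.
    return np.where(same_chain, C_kmer, -np.abs(C_kmer))


C = np.array([[1, 1, 2, 2, 2]])
kmer_idx = np.array([[1, 0, 0], [2, 1, 0], [3, 2, 1], [4, 3, 2], [4, 4, 3]])
C_window = mask_cross_chain(C, kmer_idx)
```

In this example the windows centered on residues 1 and 2 straddle the chain
break, so each of them has exactly one negative entry.
"""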


def _collect_kmers(node_h, k):
    """Gather `(B,I,H) => (B,I,K,H)`"""
    device = node_h.device
    num_batch, num_nodes, _ = node_h.shape

    # Build indices
    k_idx = torch.arange(k, device=device) - (k - 1) // 2
    node_idx = torch.arange(node_h.shape[1], device=device)
    kmer_idx = node_idx[None, :, None] - k_idx[None, None, :]
    kmer_idx = kmer_idx.clamp(min=0, max=num_nodes - 1).long()
    kmer_idx = kmer_idx.expand([num_batch, -1, k])

    # Collect neighbors
    kmer_h = graph.collect_neighbors(node_h, kmer_idx)
    return kmer_h
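
"""
Illustrative note (not part of the original module): the index arithmetic of
`_collect_kmers` can be checked with a NumPy sketch. The gather is written
with NumPy fancy indexing, under the assumption that
`graph.collect_neighbors` performs the equivalent per-batch lookup; the name
`collect_kmers_np` is hypothetical.

```python
import numpy as np


def collect_kmers_np(node_h, k):
    # node_h: (B, I, H) node features. Returns (B, I, K, H) and the indices.
    num_nodes = node_h.shape[1]
    # Offsets centered on zero, e.g. [-1, 0, 1] for k = 3.
    k_idx = np.arange(k) - (k - 1) // 2
    node_idx = np.arange(num_nodes)
    # Window member j of node i is i - k_idx[j], clamped to valid indices.
    kmer_idx = np.clip(node_idx[:, None] - k_idx[None, :], 0, num_nodes - 1)
    return node_h[:, kmer_idx], kmer_idx


node_h = np.arange(5, dtype=np.float32).reshape(1, 5, 1)
kmer_h, kmer_idx = collect_kmers_np(node_h, 3)
```

Two details follow from the arithmetic: each window lists its members from
the highest index to the lowest, and windows at the sequence ends repeat the
boundary residue because of the clamp.
"""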