Spaces:
Sleeping
Sleeping
# Copyright Generate Biomedicines, Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from __future__ import print_function | |
import numpy as np | |
import torch | |
import torch.linalg | |
import torch.nn as nn | |
from chroma.layers import graph | |
from chroma.layers.linalg import eig_leading | |
from chroma.layers.structure import geometry, protein_graph | |
class CrossRMSD(nn.Module):
    """Compute optimal RMSDs between two sets of structures.

    This module uses the quaternion-based approach for calculating RMSDs as
    described in `Using Quaternions to Calculate RMSD`, 2004, by Coutsias,
    Seok, and Dill. The minimal RMSD and associated rotation are computed in
    terms of the most positive eigenvalue and associated eigenvector of a
    special 4x4 matrix.

    Args:
        method (str, optional): Method for calculating the most positive
            eigenvalue. Can be `power` or `symeig`. If `symeig`, this will use
            `torch.linalg.eigvalsh` / `torch.linalg.eigh`, which is the most
            accurate method but tends to be very slow on GPU for large batches
            of RMSDs. If `power`, then use power iteration to estimate leading
            eigenvalues. Default is `power`.
        method_iter (int, optional): When the method is `power`, this argument
            sets the number of power iterations used for approximation.
            The default is 50, which has tended to produce estimates of optimal
            RMSD with sub-angstrom accuracy on test problems. Note: Convergence
            rates of power iteration can be highly variable depending on the
            system. If accuracy is important, it is recommended to compare
            outputs with `symeig`-based RMSDs.
        dither (bool, optional): If True, `pairedRMSD` adds Gaussian noise
            with standard deviation 1e-5 to the 4x4 F matrix before extracting
            its leading eigenpair (presumably to break exact eigenvalue
            degeneracies). Default is True.

    Inputs:
        X_mobile (Tensor): Mobile coordinates, i.e. the "mobile" coordinates,
            with shape `(num_source, num_atoms, 3)`.
        X_target (Tensor): Target coordinates with shape
            `(num_target, num_atoms, 3)`.

    Outputs:
        RMSD (Tensor): RMSDs after optimal superposition for all pairs of
            source and target structures with shape `(num_source, num_target)`.
            While `forward` returns the Cartesian product of all possible
            alignments, i.e. (`num_source * num_target` alignments), the
            `pairedRMSD` will do the same calculation for zipped batches, i.e.
            `num_source` total alignments.
    """

    def __init__(self, method="power", method_iter=50, dither=True):
        super(CrossRMSD, self).__init__()
        # Eigen-solver used for the leading eigenvalue: "power" or "symeig".
        self.method = method
        # Number of power iterations (only used when method == "power").
        self.method_iter = method_iter
        # Small positive constant added to denominators to avoid division by 0.
        self._eps = 1e-5
        # Whether pairedRMSD perturbs F with small random noise.
        self.dither = dither

        # R_to_F converts xyz cross-covariance matrices (3x3) to the (4x4) F
        # matrix of Coutsias et al. This F matrix encodes the optimal RMSD in
        # its spectra; namely, the eigenvector associated with the most
        # positive eigenvalue of F is the quaternion encoding the optimal
        # 3D rotation for superposition.
        # fmt: off
        R_to_F = np.zeros((9, 16)).astype("f")
        F_nonzero = [
        [(0,0,1.),(1,1,1.),(2,2,1.)], [(1,2,1.),(2,1,-1.)], [(2,0,1.),(0,2,-1.)], [(0,1,1.),(1,0,-1.)],
        [(1,2,1.),(2,1,-1.)], [(0,0,1.),(1,1,-1.),(2,2,-1.)], [(0,1,1.),(1,0,1.)], [(0,2,1.),(2,0,1.)],
        [(2,0,1.),(0,2,-1.)], [(0,1,1.),(1,0,1.)], [(0,0,-1.),(1,1,1.),(2,2,-1.)], [(1,2,1.),(2,1,1.)],
        [(0,1,1.),(1,0,-1.)], [(0,2,1.),(2,0,1.)], [(1,2,1.),(2,1,1.)], [(0,0,-1.),(1,1,-1.),(2,2,1.)]
        ]
        # fmt: on
        for F_ij, nonzero in enumerate(F_nonzero):
            for R_i, R_j, sign in nonzero:
                R_to_F[R_i * 3 + R_j, F_ij] = sign
        self.register_buffer("R_to_F", torch.tensor(R_to_F))

    def forward(self, X_mobile, X_target):
        """Compute optimal RMSDs for every (mobile, target) pair.

        The coordinates are first centered. Then the cross-covariance matrix
        of each pair is flattened and multiplied by `R_to_F` to obtain the 4x4
        F matrix. Finally, the most positive eigenvalue of F is found with the
        configured `method` and converted into the RMSD.

        Args:
            X_mobile (Tensor): Shape `(num_source, num_atoms, 3)`.
            X_target (Tensor): Shape `(num_target, num_atoms, 3)`.

        Returns:
            RMSD (Tensor): Shape `(num_source, num_target)`.
        """
        num_source = X_mobile.size(0)
        num_target = X_target.size(0)
        num_atoms = X_mobile.size(1)

        # Center coordinates
        X_mobile = X_mobile - X_mobile.mean(dim=1, keepdim=True)
        X_target = X_target - X_target.mean(dim=1, keepdim=True)

        # CrossCov matrices contract over atoms
        R = torch.einsum("sai,taj->stij", [X_mobile, X_target])

        # F Matrix has leading eigenvector as optimal quaternion.
        # The buffer is float32; cast it to the input dtype (as pairedRMSD
        # does) because matmul does not promote mixed float dtypes.
        R_flat = R.reshape(num_source, num_target, 9)
        R_to_F = self.R_to_F.type(R_flat.dtype)
        F = torch.matmul(R_flat, R_to_F).reshape(num_source, num_target, 4, 4)

        # Compute the leading eigenvalue (eigvalsh returns ascending order)
        if self.method == "symeig":
            top_eig = torch.linalg.eigvalsh(F)[:, :, 3]
        elif self.method == "power":
            top_eig, _ = eig_leading(F, num_iterations=self.method_iter)
        else:
            raise NotImplementedError

        # Compute RMSD in terms of the top eigenvalue (Coutsias et al.)
        norms = (X_mobile ** 2).sum(dim=[-1, -2]).unsqueeze(1) + (X_target ** 2).sum(
            dim=[-1, -2]
        ).unsqueeze(0)
        sqRMSD = torch.relu((norms - 2 * top_eig) / (num_atoms + self._eps))
        RMSD = torch.sqrt(sqRMSD)
        return RMSD

    def pairedRMSD(
        self,
        X_mobile,
        X_target,
        mask=None,
        compute_alignment=False,
        align_unmasked=False,
    ):
        """Compute optimal RMSDs between each corresponding batch members.

        Args:
            X_mobile (Tensor): Mobile coordinates with shape
                `(..., num_atoms, 3)`.
            X_target (Tensor): Target coordinates with shape
                `(..., num_atoms, 3)`.
            mask (Tensor, optional): Binary mask tensor for missing atoms with
                shape `(..., num_atoms)`. Only atoms with nonzero mask take
                part in centering, superposition and the RMSD average.
            compute_alignment (boolean, optional): If True, also return the
                superposed coordinates.
            align_unmasked (boolean, optional): Only used when
                `compute_alignment=True`. If True, the optimal transform (fit
                on the unmasked atoms) is applied to every atom of `X_mobile`.
                If False, masked atoms are zeroed in the returned coordinates.

        Returns:
            RMSD (Tensor): Optimal RMSDs after superposition for all pairs of
                input structures with shape `(...)`. When
                `compute_alignment=True` it is recomputed directly from the
                superposed coordinates.
            X_mobile_transform (Tensor, optional): Superposed coordinates with
                shape `(..., num_atoms, 3)`. Requires
                `compute_alignment=True`.
        """
        # Collapse all leading batch dimensions
        num_atoms = X_mobile.size(-2)
        batch_dims = list(X_mobile.shape)[:-2]
        X_mobile = X_mobile.reshape([-1, num_atoms, 3])
        X_target = X_target.reshape([-1, num_atoms, 3])
        num_batch = X_mobile.size(0)
        if mask is not None:
            mask = mask.reshape([-1, num_atoms])

        # Center coordinates (on the masked centroid when a mask is given)
        if mask is None:
            X_mobile_mean = X_mobile.mean(dim=1, keepdim=True)
            X_target_mean = X_target.mean(dim=1, keepdim=True)
        else:
            mask_expand = mask.unsqueeze(-1)
            mask_count = torch.sum(mask_expand, 1, keepdim=True) + self._eps
            X_mobile_mean = (
                torch.sum(mask_expand * X_mobile, 1, keepdim=True) / mask_count
            )
            X_target_mean = (
                torch.sum(mask_expand * X_target, 1, keepdim=True) / mask_count
            )
        X_mobile_center = X_mobile - X_mobile_mean
        X_target_center = X_target - X_target_mean
        if mask is not None:
            # Zero masked atoms so they drop out of the covariance and norms
            X_mobile_center = mask_expand * X_mobile_center
            X_target_center = mask_expand * X_target_center

        # Cross-covariance matrices contract over atoms
        R = torch.einsum("sai,saj->sij", [X_mobile_center, X_target_center])

        # F Matrix has leading eigenvector as optimal quaternion
        R_flat = R.reshape(num_batch, 9)
        R_to_F = self.R_to_F.type(R_flat.dtype)
        F = torch.matmul(R_flat, R_to_F).reshape(num_batch, 4, 4)
        if self.dither:
            F = F + 1e-5 * torch.randn_like(F)

        # Compute optimal quaternion by extracting leading eigenvector
        if self.method == "symeig":
            L, V = torch.linalg.eigh(F)
            top_eig = L[:, 3]
            vec = V[:, :, 3]
        elif self.method == "power":
            top_eig, vec = eig_leading(F, num_iterations=self.method_iter)
        else:
            raise NotImplementedError

        # Compute RMSD using top eigenvalue. Average over the atoms that were
        # actually included; with a mask this is mask.sum, matching the
        # normalization rmsd_unaligned uses on the compute_alignment path.
        norms = (X_mobile_center ** 2).sum(dim=[-1, -2]) + (X_target_center ** 2).sum(
            dim=[-1, -2]
        )
        num_included = num_atoms if mask is None else mask.sum(-1)
        sqRMSD = torch.relu((norms - 2 * top_eig) / (num_included + self._eps))
        rmsd = torch.sqrt(sqRMSD)

        if not compute_alignment:
            # Unpack leading batch dimensions
            return rmsd.reshape(batch_dims)

        rotation = geometry.rotations_from_quaternions(vec, normalize=False)
        X_mobile_transform = torch.einsum("bxr,bir->bix", rotation, X_mobile_center)
        X_mobile_transform = X_mobile_transform + X_target_mean
        if mask is not None:
            X_mobile_transform = mask_expand * X_mobile_transform

        # Return the RMSD of the transformed coordinates
        rmsd_direct = rmsd_unaligned(X_mobile_transform, X_target, mask)

        if align_unmasked:
            # Apply the same rigid transform to all atoms, including masked ones
            X_mobile_transform = torch.einsum(
                "bxr,bir->bix", rotation, X_mobile - X_mobile_mean
            )
            X_mobile_transform = X_mobile_transform + X_target_mean

        # Unpack leading batch dimensions (for both branches above)
        rmsd_direct = rmsd_direct.reshape(batch_dims)
        X_mobile_transform = X_mobile_transform.reshape(batch_dims + [num_atoms, 3])
        return rmsd_direct, X_mobile_transform
class BackboneRMSD(nn.Module):
    """Compute optimal RMSDs between two sets of backbones.

    This wraps `CrossRMSD` for use with XCS-formatted protein data.

    Args:
        method (str, optional): Method for calculating the most positive
            eigenvalue. Can be `power` or `symeig`. Default is `symeig`.
        method_iter (int, optional): Number of power iterations for eigenvalue
            approximation. Requires `method=power`. Default is 50.

    Inputs (to `align`):
        X_mobile (Tensor): Mobile coordinates with shape
            `(num_batch, num_residues, num_atoms_per_residue, 3)`.
        X_target (Tensor): Target coordinates with the same shape as
            `X_mobile`.
        C (Tensor): Chain map with shape `(num_batch, num_residues)`.

    Outputs:
        X_aligned (Tensor): Superposed `X_mobile` with the same shape as
            `X_mobile`.
        rmsd (Tensor): Optimal RMSDs after superposition with shape
            `(num_batch,)`.
    """

    def __init__(self, method="symeig", method_iter=50):
        super(BackboneRMSD, self).__init__()
        self.rmsd = CrossRMSD(method=method, method_iter=method_iter)

    def align(self, X_mobile, X_target, C, align_unmasked=False):
        """Superimpose `X_mobile` onto `X_target`.

        A mask is built from the chain map: residues with `C > 0` are used to
        fit the alignment. The per-residue coordinates are flattened into one
        atom axis and passed to `CrossRMSD.pairedRMSD`. The aligned
        coordinates are then reshaped back to the input layout.

        Args:
            X_mobile (Tensor): Shape
                `(num_batch, num_residues, num_atoms_per_residue, 3)`.
            X_target (Tensor): Same shape as `X_mobile`.
            C (Tensor): Chain map with shape `(num_batch, num_residues)`.
            align_unmasked (bool, optional): If True, the fitted transform is
                also applied to residues with `C <= 0` instead of zeroing
                them in the output.

        Returns:
            X_aligned (Tensor): Superposed mobile coordinates, same shape as
                `X_mobile`.
            rmsd (Tensor): Shape `(num_batch,)`.
        """
        # Atoms per residue (4 for backbone-only input); used to repeat each
        # residue's mask entry once per atom so that it lines up with the
        # flattened (num_residues * atoms_per_residue) coordinate axis.
        atoms_per_residue = X_mobile.size(2)
        mask = (C > 0).type(torch.float32)
        mask_flat = (
            mask.unsqueeze(-1)
            .expand(-1, -1, atoms_per_residue)
            .reshape(mask.shape[0], -1)
        )
        X_mobile_flat = X_mobile.reshape(X_mobile.size(0), -1, 3)
        X_target_flat = X_target.reshape(X_target.size(0), -1, 3)
        rmsd, X_aligned = self.rmsd.pairedRMSD(
            X_mobile_flat,
            X_target_flat,
            mask=mask_flat,
            compute_alignment=True,
            align_unmasked=align_unmasked,
        )
        X_aligned = X_aligned.reshape(X_mobile.size()).contiguous()
        return X_aligned, rmsd
class LossFragmentRMSD(nn.Module):
    """Compute per-site optimal RMSDs of local backbone fragments.

    For every residue, the window of `k` sequence-neighboring residues
    around it is gathered from both structures with `_collect_X_fragments`.
    The mobile and target windows are optimally superposed with
    `CrossRMSD.pairedRMSD`, and the resulting RMSD is reported for that site.

    Args:
        k (int, optional): Fragment length in residues. Default is 7.
        method (str, optional): Method for calculating the most positive
            eigenvalue. Can be `power` or `symeig`. Default is `symeig`.
        method_iter (int, optional): Number of power iterations for eigenvalue
            approximation. Requires `method=power`. Default is 50.

    Inputs:
        X_mobile (Tensor): Mobile coordinates with shape
            `(num_batch, num_residues, num_atoms, 3)`; only the first four
            atoms of each residue are used.
        X_target (Tensor): Target coordinates with the same shape.
        C (Tensor): Chain map with shape `(num_batch, num_residues)`.
        return_coords (bool, optional): If True, also return fragment
            coordinates.

    Outputs:
        rmsd (Tensor): Per-site fragment RMSDs with shape
            `(num_batch, num_residues)`.
        X_fragment_target, X_fragment_mobile, X_fragment_mobile_align
            (Tensor, optional): Target, mobile and superposed mobile fragment
            coordinates; returned only when `return_coords=True`.
    """

    def __init__(self, k=7, method="symeig", method_iter=50):
        super(LossFragmentRMSD, self).__init__()
        self.k = k
        self.rmsd = CrossRMSD(method=method, method_iter=method_iter)

    def forward(self, X_mobile, X_target, C, return_coords=False):
        """Superpose each site's k-residue window and return its RMSD.

        Atoms of a window that lie on a different chain than the center
        residue get non-positive chain labels from `_collect_X_fragments`.
        Those atoms are therefore excluded from the superposition mask.
        """
        # Keep only the first four atoms per residue (the backbone)
        bb_mobile = X_mobile[:, :, :4, :]
        bb_target = X_target[:, :, :4, :]

        # Gather the k-residue window around every site
        frag_mobile, chain_frag = _collect_X_fragments(bb_mobile, C, self.k)
        frag_target, _ = _collect_X_fragments(bb_target, C, self.k)

        # One point cloud per site: (num_batch, num_residues, atoms_in_window, 3)
        per_site_shape = [*C.shape, -1, 3]
        frag_mobile = frag_mobile.reshape(per_site_shape)
        frag_target = frag_target.reshape(per_site_shape)

        frag_mask = chain_frag.gt(0).float()
        rmsd, frag_mobile_aligned = self.rmsd.pairedRMSD(
            frag_mobile, frag_target, frag_mask, compute_alignment=True
        )

        if not return_coords:
            return rmsd
        return rmsd, frag_target, frag_mobile, frag_mobile_aligned
class LossFragmentPairRMSD(nn.Module):
    """Compute optimal RMSDs between pairs of local backbone fragments.

    A neighbor graph is built on the target backbone. For every edge (i, j),
    the k-residue window around i is concatenated with the window around its
    neighbor j. That combined point cloud is then superposed between the
    mobile and target structures.

    Args:
        k (int, optional): Fragment length in residues. Default is 7.
        method (str, optional): Method for calculating the most positive
            eigenvalue. Can be `power` or `symeig`. Default is `symeig`.
        method_iter (int, optional): Number of power iterations for eigenvalue
            approximation. Requires `method=power`. Default is 50.
        graph_num_neighbors (int, optional): Neighbor count passed to
            `ProteinGraph`. Default is 30.

    Inputs:
        X_mobile (Tensor): Mobile coordinates with shape
            `(num_batch, num_residues, num_atoms, 3)`; only the first four
            atoms of each residue are used.
        X_target (Tensor): Target coordinates with the same shape.
        C (Tensor): Chain map with shape `(num_batch, num_residues)`.
        return_coords (bool, optional): If True, also return pair coordinates.

    Outputs:
        rmsd (Tensor): RMSD for each (site, neighbor) pair; presumably shaped
            `(num_batch, num_residues, num_neighbors)` — TODO confirm against
            `ProteinGraph`'s `edge_idx` layout.
        mask_ij (Tensor): Edge mask returned by the graph builder.
        X_pair_target, X_pair_mobile, X_pair_mobile_align (Tensor, optional):
            Returned only when `return_coords=True`.
    """

    def __init__(self, k=7, method="symeig", method_iter=50, graph_num_neighbors=30):
        super(LossFragmentPairRMSD, self).__init__()
        self.k = k
        self.rmsd = CrossRMSD(method=method, method_iter=method_iter)
        self.graph_builder = protein_graph.ProteinGraph(
            num_neighbors=graph_num_neighbors
        )

    def _stack_neighbor(self, node_h, edge_idx):
        """Concatenate each neighbor's features with its center node's features."""
        neighbors = graph.collect_neighbors(node_h, edge_idx)
        centers = node_h.unsqueeze(2).expand(neighbors.shape)
        return torch.cat((neighbors, centers), dim=-1)

    def _collect_X_fragment_pairs(self, X, C, edge_idx):
        """Build (neighbor window, center window) coordinate and chain pairs."""
        X_kmer, C_kmer = _collect_X_fragments(X, C, self.k)
        X_pair = self._stack_neighbor(X_kmer, edge_idx)
        C_pair = self._stack_neighbor(C_kmer, edge_idx)
        # Unflatten the trailing feature axis back into xyz points
        return X_pair.reshape([*X_pair.shape[:-1], -1, 3]), C_pair

    def forward(self, X_mobile, X_target, C, return_coords=False):
        """Superpose every fragment pair and return the per-pair RMSDs."""
        # Keep only the first four atoms per residue (the backbone)
        bb_mobile = X_mobile[:, :, :4, :]
        bb_target = X_target[:, :, :4, :]

        # The neighbor graph is always built from the target structure
        edge_idx, mask_ij = self.graph_builder(bb_target, C)
        pair_mobile, chain_pair = self._collect_X_fragment_pairs(
            bb_mobile, C, edge_idx
        )
        pair_target, _ = self._collect_X_fragment_pairs(bb_target, C, edge_idx)

        pair_mask = chain_pair.gt(0).float()
        rmsd, pair_mobile_aligned = self.rmsd.pairedRMSD(
            pair_mobile, pair_target, pair_mask, compute_alignment=True
        )

        if not return_coords:
            return rmsd, mask_ij
        return rmsd, mask_ij, pair_target, pair_mobile, pair_mobile_aligned
class LossNeighborhoodRMSD(nn.Module):
    """Compute optimal RMSDs between the graph neighborhoods of each residue.

    A neighbor graph is built on the target backbone. For each residue, the
    backbone atoms of all its graph neighbors form one point cloud. That
    point cloud is superposed between the mobile and target structures.

    Args:
        method (str, optional): Method for calculating the most positive
            eigenvalue. Can be `power` or `symeig`. Default is `symeig`.
        method_iter (int, optional): Number of power iterations for eigenvalue
            approximation. Requires `method=power`. Default is 50.
        graph_num_neighbors (int, optional): Neighbor count passed to
            `ProteinGraph`. Default is 30.

    Inputs:
        X_mobile (Tensor): Mobile coordinates with shape
            `(num_batch, num_residues, num_atoms, 3)`; only the first four
            atoms of each residue are used.
        X_target (Tensor): Target coordinates with the same shape.
        C (Tensor): Chain map with shape `(num_batch, num_residues)`.
        return_coords (bool, optional): If True, also return neighborhood
            coordinates.

    Outputs:
        rmsd (Tensor): Per-site neighborhood RMSDs with shape
            `(num_batch, num_residues)`.
        mask (Tensor): Float mask with shape `(num_batch, num_residues)`; 1
            where the neighborhood contains at least one atom with `C > 0`.
        X_neighborhood_target, X_neighborhood_mobile,
        X_neighborhood_mobile_align (Tensor, optional): Returned only when
            `return_coords=True`.
    """

    def __init__(self, method="symeig", method_iter=50, graph_num_neighbors=30):
        super(LossNeighborhoodRMSD, self).__init__()
        self.rmsd = CrossRMSD(method=method, method_iter=method_iter)
        self.graph_builder = protein_graph.ProteinGraph(
            num_neighbors=graph_num_neighbors
        )

    def _collect_X_neighborhood(self, X, C, edge_idx):
        """Gather neighbor coordinates and per-atom chain labels for each node."""
        num_batch, num_nodes, num_atoms, _ = X.shape
        X_flat = X.reshape(num_batch, num_nodes, -1)
        C_flat = C.unsqueeze(-1).expand(-1, -1, num_atoms)
        X_nbr = graph.collect_neighbors(X_flat, edge_idx)
        C_nbr = graph.collect_neighbors(C_flat, edge_idx)
        return (
            X_nbr.reshape(num_batch, num_nodes, -1, 3),
            C_nbr.reshape(num_batch, num_nodes, -1),
        )

    def forward(self, X_mobile, X_target, C, return_coords=False):
        """Superpose every residue neighborhood and return per-site RMSDs."""
        # Keep only the first four atoms per residue (the backbone)
        bb_mobile = X_mobile[:, :, :4, :]
        bb_target = X_target[:, :, :4, :]

        # The neighbor graph is always built from the target structure
        edge_idx, _ = self.graph_builder(bb_target, C)
        nbr_mobile, chain_nbr = self._collect_X_neighborhood(bb_mobile, C, edge_idx)
        nbr_target, _ = self._collect_X_neighborhood(bb_target, C, edge_idx)

        atom_mask = chain_nbr.gt(0).float()
        rmsd, nbr_mobile_aligned = self.rmsd.pairedRMSD(
            nbr_mobile, nbr_target, atom_mask, compute_alignment=True
        )
        # A site is valid if any atom in its neighborhood was included
        site_mask = atom_mask.sum(-1).gt(0).float()

        if not return_coords:
            return rmsd, site_mask
        return rmsd, site_mask, nbr_target, nbr_mobile, nbr_mobile_aligned
def rmsd_unaligned(X_a, X_b, mask=None, eps=1e-5, _min_rmsd=1e-8):
    """Return the raw (no superposition) RMSD between two coordinate sets.

    Args:
        X_a (Tensor): Coordinate set 1 with shape `(..., num_points, 3)`.
        X_b (Tensor): Coordinate set 2 with shape `(..., num_points, 3)`.
        mask (Tensor, optional): Per-point weights with shape
            `(..., num_points)`. When given, the squared deviations are
            summed over masked points and divided by `mask.sum + eps`.
        eps (float, optional): Added to the mask count to prevent division by
            zero. Default is 1E-5.
        _min_rmsd (float, optional): Lower clamp applied before the square
            root. Without a mask it bounds the mean squared deviation; with a
            mask it bounds the summed squared deviation. Default is 1E-8.

    Returns:
        rmsd (Tensor): Root mean squared deviations with shape `(...)`.
    """
    sq_dist = (X_a - X_b).pow(2).sum(dim=-1)
    if mask is not None:
        numerator = (mask * sq_dist).sum(dim=-1).clamp(min=_min_rmsd)
        return (numerator / (mask.sum(dim=-1) + eps)).sqrt()
    return sq_dist.mean(dim=-1).clamp(min=_min_rmsd).sqrt()
""" | |
这两个函数是处理蛋白质结构数据的关键部分,特别是在需要从蛋白质结构中提取和分析特定长度片段的情况下。 | |
_collect_X_fragments 函数处理蛋白质的坐标和链映射信息,以收集和处理特定长度的片段, | |
而 _collect_kmers 函数则是一个更通用的工具,用于从任何给定的节点特征矩阵中收集 k-mers. | |
_collect_X_fragments: | |
函数首先将 X 和 C 转换为扁平形状。 | |
然后,使用 _collect_kmers 函数从 X_flat 和 C_flat 中收集 k-mers,这些 k-mers 本质上是局部的、长度为 k 的片段。 | |
最后,函数使用 torch.where 来处理非连续原子,将它们视为缺失,并返回处理后的 X_kmer 和 C_kmer。 | |
_collect_kmers: | |
函数的主要步骤包括: | |
构建索引以定位 k-mers。首先,创建一个长度为 k 的索引数组 k_idx。 | |
然后,使用这个索引和节点的索引 node_idx 生成 k-mers 的索引 kmer_idx。 | |
使用 kmer_idx 从 node_h 中收集相邻节点的特征,形成新的 k-mer 特征矩阵 kmer_h。 | |
这个函数的关键在于它能够从原始的节点特征矩阵中构建出包含局部邻居信息的新矩阵,这对于处理基于图的结构(如蛋白质结构)特别有用。 | |
""" | |
def _collect_X_fragments(X, C, k):
    """Gather the length-`k` residue window around every site.

    Args:
        X (Tensor): Coordinates with shape `(num_batch, num_nodes, num_atoms, 3)`.
        C (Tensor): Chain map with shape `(num_batch, num_nodes)`.
        k (int): Window length in residues.

    Returns:
        X_kmer (Tensor): Window coordinates flattened to
            `(num_batch, num_nodes, k * num_atoms * 3)`.
        C_kmer (Tensor): Per-atom chain labels with shape
            `(num_batch, num_nodes, k * num_atoms)`. Labels that differ from
            the center residue's label are replaced by `-abs(label)`, so those
            atoms read as missing (`<= 0`).
    """
    num_batch, num_nodes, num_atoms, _ = X.shape
    flat = [num_batch, num_nodes, -1]
    # Repeat every residue's chain label once per atom
    C_atoms = C.unsqueeze(-1).expand(-1, -1, num_atoms)

    X_kmer = _collect_kmers(X.reshape(flat), k).reshape(flat)
    C_kmer = _collect_kmers(C_atoms, k).reshape(flat)

    # Treat atoms from a different chain than the center as missing
    same_chain = C.unsqueeze(-1) == C_kmer
    C_kmer = torch.where(same_chain, C_kmer, -C_kmer.abs())
    return X_kmer, C_kmer
def _collect_kmers(node_h, k):
    """Gather a k-wide sequence window per node: `(B,I,H) => (B,I,K,H)`.

    For node i, the window indices are `i - offset` for offsets
    `-(k-1)//2 .. k-1-(k-1)//2`, so they run in descending index order. The
    indices are clamped to `[0, num_nodes - 1]`, which means windows at the
    sequence ends repeat the boundary node.
    """
    device = node_h.device
    num_batch, num_nodes, _ = node_h.shape

    offsets = torch.arange(k, device=device) - (k - 1) // 2
    centers = torch.arange(num_nodes, device=device)
    window_idx = centers.view(1, -1, 1) - offsets.view(1, 1, -1)
    window_idx = window_idx.clamp(0, num_nodes - 1).long()
    window_idx = window_idx.expand(num_batch, -1, k)

    return graph.collect_neighbors(node_h, window_idx)