Spaces:
Sleeping
Sleeping
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools | |
from functools import reduce, wraps | |
from operator import add | |
import numpy as np | |
import torch | |
from dockformerpp.config import NUM_RES | |
from dockformerpp.utils import residue_constants as rc | |
from dockformerpp.utils.residue_constants import restypes | |
from dockformerpp.utils.rigid_utils import Rotation, Rigid | |
from dockformerpp.utils.geometry.rigid_matrix_vector import Rigid3Array | |
from dockformerpp.utils.geometry.rotation_matrix import Rot3Array | |
from dockformerpp.utils.geometry.vector import Vec3Array | |
from dockformerpp.utils.tensor_utils import ( | |
tree_map, | |
tensor_tree_map, | |
batched_gather, | |
) | |
def cast_to_64bit_ints(protein):
    """Promote every int32 entry of ``protein`` to int64, in place.

    Args:
        protein: Dict whose values expose a ``dtype`` attribute (tensors).

    Returns:
        The same dict object, with int32 values replaced by int64 copies.
    """
    # Only existing keys are reassigned, so the dict size never changes
    # while we walk over it.
    for name in protein:
        value = protein[name]
        if value.dtype == torch.int32:
            protein[name] = value.to(torch.int64)

    return protein
def make_one_hot(x, num_classes):
    """One-hot encode an integer index tensor along a new trailing axis.

    Args:
        x: Integer tensor of class indices.
        num_classes: Size of the appended one-hot dimension.

    Returns:
        A default-dtype tensor of shape ``(*x.shape, num_classes)`` on the
        same device as ``x``, with a 1 at each index given by ``x``.
    """
    encoded = torch.zeros((*x.shape, num_classes), device=x.device)
    # scatter_ writes in place and hands back the tensor it modified.
    return encoded.scatter_(-1, x[..., None], 1)
def curry1(f):
    """Supply all arguments but the first.

    ``curry1(f)(*args, **kwargs)`` yields a one-argument callable that,
    given ``x``, evaluates ``f(x, *args, **kwargs)``.
    """

    def bind_remaining(*args, **kwargs):
        def apply_to_first(x):
            return f(x, *args, **kwargs)

        return apply_to_first

    return bind_remaining
def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features.

    * ``aatype`` is replaced by the argmax over its last axis.
    * A fixed set of features has a trailing axis of length 1 squeezed away
      (tensors via ``torch.squeeze``, anything else via ``np.squeeze``).
    * ``seq_length`` is finally reduced to its first element.

    The dict is modified in place and returned.
    """
    protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)

    squeezable = (
        "domain_name",
        "seq_length",
        "sequence",
        "resolution",
        "residue_index",
    )
    for name in squeezable:
        if name not in protein:
            continue

        feature = protein[name]
        last_dim = feature.shape[-1]
        if not (isinstance(last_dim, int) and last_dim == 1):
            continue

        if torch.is_tensor(feature):
            protein[name] = torch.squeeze(feature, dim=-1)
        else:
            protein[name] = np.squeeze(feature, axis=-1)

    # The length is repeated along the leading axis; keep the first copy.
    if "seq_length" in protein:
        protein["seq_length"] = protein["seq_length"][0]

    return protein
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features.

    For every residue the CB entry is selected, except where ``aatype``
    equals the glycine index, in which case the CA entry is used instead.

    Args:
        aatype: Residue type indices.
        all_atom_positions: Atom coordinates indexed as ``[..., atom, xyz]``.
        all_atom_mask: Matching atom mask, or ``None`` to skip the mask.

    Returns:
        ``pseudo_beta`` alone when ``all_atom_mask`` is ``None``, otherwise
        the tuple ``(pseudo_beta, pseudo_beta_mask)``.
    """
    glycine = torch.eq(aatype, rc.restype_order["G"])
    ca = rc.atom_order["CA"]
    cb = rc.atom_order["CB"]

    # Repeat the per-residue flag across the three coordinates.
    glycine_xyz = torch.tile(glycine[..., None], [1] * glycine.dim() + [3])
    pseudo_beta = torch.where(
        glycine_xyz,
        all_atom_positions[..., ca, :],
        all_atom_positions[..., cb, :],
    )

    if all_atom_mask is None:
        return pseudo_beta

    pseudo_beta_mask = torch.where(
        glycine, all_atom_mask[..., ca], all_atom_mask[..., cb]
    )
    return pseudo_beta, pseudo_beta_mask
def make_pseudo_beta(protein):
    """Create pseudo-beta (alpha for glycine) position and mask.

    Reads ``aatype``, ``all_atom_positions`` and ``all_atom_mask`` from
    ``protein`` and stores the result under ``pseudo_beta`` and
    ``pseudo_beta_mask``. Returns the same dict.
    """
    positions, mask = pseudo_beta_fn(
        protein["aatype"],
        protein["all_atom_positions"],
        protein["all_atom_mask"],
    )
    protein["pseudo_beta"] = positions
    protein["pseudo_beta_mask"] = mask
    return protein
def make_target_feat(protein):
    """Create the per-residue target feature.

    Stores a 20-class one-hot encoding of ``protein["aatype"]`` under
    ``protein["protein_target_feat"]`` and returns the same dict.
    """
    protein["protein_target_feat"] = make_one_hot(protein["aatype"], 20)
    return protein
def select_feat(protein, feature_list):
    """Return a new dict holding only the entries whose key is in ``feature_list``.

    The input dict is left untouched; key order follows ``protein``.
    """
    selected = {}
    for name, value in protein.items():
        if name in feature_list:
            selected[name] = value
    return selected
def get_restypes(device):
    """Build the per-residue-type atom14/atom37 lookup tables.

    One row is produced for every entry of ``rc.restypes`` plus a trailing
    all-zero row for the unknown residue type.

    Args:
        device: Device on which the returned tensors are created.

    Returns:
        A tuple ``(restype_atom14_to_atom37, restype_atom37_to_atom14,
        restype_atom14_mask)`` of tensors with dtypes int32, int32 and
        float32 and row widths 14, 37 and 14 respectively.
    """
    atom14_to_atom37_rows = []
    atom37_to_atom14_rows = []
    atom14_mask_rows = []

    for letter in rc.restypes:
        names14 = rc.restype_name_to_atom14_names[rc.restype_1to3[letter]]
        slot_of = {name: slot for slot, name in enumerate(names14)}

        # Empty names are padding slots; they map to atom37 index 0.
        atom14_to_atom37_rows.append(
            [rc.atom_order[name] if name else 0 for name in names14]
        )
        # Atoms absent from this residue fall back to atom14 slot 0.
        atom37_to_atom14_rows.append(
            [slot_of.get(name, 0) for name in rc.atom_types]
        )
        atom14_mask_rows.append([1.0 if name else 0.0 for name in names14])

    # Add dummy mapping for restype 'UNK'
    atom14_to_atom37_rows.append([0] * 14)
    atom37_to_atom14_rows.append([0] * 37)
    atom14_mask_rows.append([0.0] * 14)

    def _as_tensor(rows, dtype):
        return torch.tensor(rows, dtype=dtype, device=device)

    return (
        _as_tensor(atom14_to_atom37_rows, torch.int32),
        _as_tensor(atom37_to_atom14_rows, torch.int32),
        _as_tensor(atom14_mask_rows, torch.float32),
    )
def get_restype_atom37_mask(device):
    """Build the per-residue-type atom37 existence mask.

    Args:
        device: Device on which the mask is created.

    Returns:
        A float32 tensor of shape ``[len(restypes) + 1, 37]`` with a 1 for
        each atom listed in ``rc.residue_atoms`` for that residue type. The
        final row is never written and stays zero.
    """
    mask = torch.zeros(
        [len(restypes) + 1, 37], dtype=torch.float32, device=device
    )
    for row, letter in enumerate(rc.restypes):
        for atom_name in rc.residue_atoms[rc.restype_1to3[letter]]:
            mask[row, rc.atom_order[atom_name]] = 1
    return mask
def make_atom14_masks(protein):
    """Construct denser atom positions (14 dimensions instead of 37).

    Looks up the per-residue-type tables by ``protein["aatype"]`` and stores:

    * ``atom14_atom_exists``: atom14 mask for each residue.
    * ``residx_atom14_to_atom37``: int64 atom37 index for each atom14 slot.
    * ``residx_atom37_to_atom14``: int64 atom14 slot for each atom37 index.
    * ``atom37_atom_exists``: atom37 mask for each residue.

    The dict is modified in place and returned.
    """
    device = protein["aatype"].device
    atom14_to_atom37, atom37_to_atom14, atom14_mask = get_restypes(device)

    # Index tensors must be int64 to be usable for table lookups.
    aatype = protein["aatype"].to(torch.long)

    protein["atom14_atom_exists"] = atom14_mask[aatype]
    protein["residx_atom14_to_atom37"] = atom14_to_atom37[aatype].long()

    # Gather indices for mapping back from atom37 to atom14.
    protein["residx_atom37_to_atom14"] = atom37_to_atom14[aatype].long()

    atom37_mask = get_restype_atom37_mask(device)
    protein["atom37_atom_exists"] = atom37_mask[aatype]

    return protein
def make_atom14_positions(protein):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Reads ``atom14_atom_exists``, ``residx_atom14_to_atom37``,
    ``all_atom_mask``, ``all_atom_positions`` and ``aatype`` from
    ``protein`` and writes the following keys back into the same dict:

    * ``atom14_atom_exists``: re-stored unchanged.
    * ``atom14_gt_exists``: atom14 mask of atoms with a known position.
    * ``atom14_gt_positions``: atom14 positions, zeroed where unknown.
    * ``atom14_alt_gt_positions`` / ``atom14_alt_gt_exists``: the same
      data with the atoms listed in ``rc.residue_atom_renaming_swaps``
      exchanged.
    * ``atom14_atom_is_ambiguous``: 1 for atoms taking part in a swap.

    NOTE(review): ``aatype`` is used directly as an index below, so it is
    presumably expected to be an integer (long) tensor — confirm at the
    call site.

    Returns:
        The updated ``protein`` dict.
    """
    residx_atom14_mask = protein["atom14_atom_exists"]
    residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        protein["all_atom_mask"],
        residx_atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
    )

    # Gather the ground truth positions.
    # Multiplying by the mask zeroes the coordinates of unknown atoms.
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            protein["all_atom_positions"],
            residx_atom14_to_atom37,
            dim=-2,
            no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
        )
    )

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["atom14_gt_exists"] = residx_atom14_gt_mask
    protein["atom14_gt_positions"] = residx_atom14_gt_positions

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms.
    # Every residue type starts from the identity (no renaming).
    all_matrices = {
        res: torch.eye(
            14,
            dtype=protein["all_atom_mask"].dtype,
            device=protein["all_atom_mask"].device,
        )
        for res in restype_3
    }
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        # correspondences[i] is the atom14 slot that slot i is renamed to.
        correspondences = torch.arange(
            14, device=protein["all_atom_mask"].device
        )
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = rc.restype_name_to_atom14_names[resname].index(
                source_atom_swap
            )
            target_index = rc.restype_name_to_atom14_names[resname].index(
                target_atom_swap
            )
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        # Turn the index mapping into a 14x14 permutation matrix.
        renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix

    # Stack in restype_3 order so the result can be indexed by aatype.
    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[protein["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = torch.einsum(
        "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
    )
    protein["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = torch.einsum(
        "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
    )
    protein["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask. shape: (21, 14).
    restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        for atom_name1, atom_name2 in swap.items():
            restype = rc.restype_order[rc.restype_3to1[resname]]
            atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
                atom_name1
            )
            atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
                atom_name2
            )
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
        protein["aatype"]
    ]

    return protein
def atom37_to_frames(protein, eps=1e-8):
    """Compute the 8 rigid-group frames per residue from atom37 coordinates.

    Reads ``aatype``, ``all_atom_positions`` and ``all_atom_mask`` from
    ``protein`` and writes these keys back into the same dict:

    * ``rigidgroups_gt_frames``: frames as returned by ``to_tensor_4x4()``.
    * ``rigidgroups_gt_exists``: group mask times the minimum of the three
      base-atom masks.
    * ``rigidgroups_group_exists``: group mask looked up by residue type.
    * ``rigidgroups_group_is_ambiguous``: 1 for groups flagged ambiguous.
    * ``rigidgroups_alt_gt_frames``: frames composed with the ambiguity
      rotation, as returned by ``to_tensor_4x4()``.

    Args:
        protein: Feature dict; assumes ``aatype`` is ``[*, N_res]`` with
            matching atom37 positions and mask — TODO confirm against
            callers.
        eps: Forwarded to ``Rigid.from_3_points``.

    Returns:
        The updated ``protein`` dict.
    """
    aatype = protein["aatype"]
    all_atom_positions = protein["all_atom_positions"]
    all_atom_mask = protein["all_atom_mask"]

    batch_dims = len(aatype.shape[:-1])

    # Table of the three atom names defining each of the 8 groups for each
    # of the 21 residue types; "" marks an undefined entry.
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
    # Groups 0 and 3 use the same atoms for every residue type
    # (presumably the backbone and psi groups — confirm against the
    # AlphaFold group convention).
    restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
    restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]

    # Groups 4..7 correspond to chi angles 0..3 where the residue has them.
    for restype, restype_letter in enumerate(rc.restypes):
        resname = rc.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if rc.chi_angles_mask[restype][chi_idx]:
                names = rc.chi_angles_atoms[resname][chi_idx]
                # Drop the first of the four chi atoms to get three base atoms.
                restype_rigidgroup_base_atom_names[
                    restype, chi_idx + 4, :
                ] = names[1:]

    # Which groups exist for which residue type. Groups 1 and 2 are never
    # switched on here; the last (21st) row only gets groups 0 and 3.
    restype_rigidgroup_mask = all_atom_mask.new_zeros(
        (*aatype.shape[:-1], 21, 8),
    )
    restype_rigidgroup_mask[..., 0] = 1
    restype_rigidgroup_mask[..., 3] = 1
    restype_rigidgroup_mask[..., :len(restypes), 4:] = all_atom_mask.new_tensor(
        rc.chi_angles_mask
    )

    # Translate atom names to atom37 indices; "" maps to index 0.
    lookuptable = rc.atom_order.copy()
    lookuptable[""] = 0
    lookup = np.vectorize(lambda x: lookuptable[x])
    restype_rigidgroup_base_atom37_idx = lookup(
        restype_rigidgroup_base_atom_names,
    )
    restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
        restype_rigidgroup_base_atom37_idx,
    )
    # Prepend singleton batch dimensions so the table can be gathered by
    # aatype.
    restype_rigidgroup_base_atom37_idx = (
        restype_rigidgroup_base_atom37_idx.view(
            *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
        )
    )

    # Per-residue base-atom indices, selected by residue type.
    residx_rigidgroup_base_atom37_idx = batched_gather(
        restype_rigidgroup_base_atom37_idx,
        aatype,
        dim=-3,
        no_batch_dims=batch_dims,
    )

    # Coordinates of the three base atoms of every group.
    base_atom_pos = batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        dim=-2,
        no_batch_dims=len(all_atom_positions.shape[:-2]),
    )

    gt_frames = Rigid.from_3_points(
        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
        p_xy_plane=base_atom_pos[..., 2, :],
        eps=eps,
    )

    group_exists = batched_gather(
        restype_rigidgroup_mask,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    gt_atoms_exist = batched_gather(
        all_atom_mask,
        residx_rigidgroup_base_atom37_idx,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )
    # A frame counts as existing only if its group exists and all three of
    # its base atoms are present.
    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

    # Per-group fix-up rotation: identity everywhere except group 0, whose
    # x and z axes are negated.
    rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1
    rots = Rotation(rot_mats=rots)

    gt_frames = gt_frames.compose(Rigid(rots, None))

    # Tables marking ambiguous groups and the rotation producing the
    # alternative frame.
    restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
        *((1,) * batch_dims), 21, 8
    )
    restype_rigidgroup_rots = torch.eye(
        3, dtype=all_atom_mask.dtype, device=aatype.device
    )
    restype_rigidgroup_rots = torch.tile(
        restype_rigidgroup_rots,
        (*((1,) * batch_dims), 21, 8, 1, 1),
    )

    for resname, _ in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[rc.restype_3to1[resname]]
        # Index of the last chi angle defined for this residue type.
        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
        # diag(1, -1, -1): a 180-degree rotation about the x axis.
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1

    residx_rigidgroup_is_ambiguous = batched_gather(
        restype_rigidgroup_is_ambiguous,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    residx_rigidgroup_ambiguity_rot = batched_gather(
        restype_rigidgroup_rots,
        aatype,
        dim=-4,
        no_batch_dims=batch_dims,
    )

    residx_rigidgroup_ambiguity_rot = Rotation(
        rot_mats=residx_rigidgroup_ambiguity_rot
    )
    alt_gt_frames = gt_frames.compose(
        Rigid(residx_rigidgroup_ambiguity_rot, None)
    )

    gt_frames_tensor = gt_frames.to_tensor_4x4()
    alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()

    protein["rigidgroups_gt_frames"] = gt_frames_tensor
    protein["rigidgroups_gt_exists"] = gt_exists
    protein["rigidgroups_group_exists"] = group_exists
    protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
    protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor

    return protein
def get_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
        A nested list indexed as [residue_types=21][chis=4][atoms=4]. The
        residue types are in the order specified in rc.restypes + unknown
        residue type at the end. For chi angles which are not defined on the
        residue, the position indices are by default set to 0.
    """
    per_residue = []
    for letter in rc.restypes:
        chi_definitions = rc.chi_angles_atoms[rc.restype_1to3[letter]]
        rows = [
            [rc.atom_order[atom] for atom in chi_atoms]
            for chi_atoms in chi_definitions
        ]
        # Pad up to four chi angles for residues that define fewer.
        rows.extend([0, 0, 0, 0] for _ in range(4 - len(rows)))
        per_residue.append(rows)

    per_residue.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.

    return per_residue
def atom37_to_torsion_angles(
    protein,
    prefix="",
):
    """
    Convert coordinates to torsion angles.

    This function is extremely sensitive to floating point imprecisions
    and should be run with double precision whenever possible.

    Args:
        protein:
            Dict containing:
                * (prefix)aatype:
                    [*, N_res] residue indices
                * (prefix)all_atom_positions:
                    [*, N_res, 37, 3] atom positions (in atom37
                    format)
                * (prefix)all_atom_mask:
                    [*, N_res, 37] atom position mask
        prefix:
            String prepended to every key read from and written to
            ``protein``.
    Returns:
        The same dictionary updated with the following features:

        "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Torsion angles
        "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Alternate torsion angles (accounting for 180-degree symmetry)
        "(prefix)torsion_angles_mask" ([*, N_res, 7])
            Torsion angles mask
    """
    aatype = protein[prefix + "aatype"]
    all_atom_positions = protein[prefix + "all_atom_positions"]
    all_atom_mask = protein[prefix + "all_atom_mask"]

    # Cap residue indices at 20 so they stay inside the 21-row tables below.
    aatype = torch.clamp(aatype, max=20)

    # Shift positions and mask by one residue so that entry i holds the
    # atoms of residue i - 1; the first residue gets an all-zero entry.
    pad = all_atom_positions.new_zeros(
        [*all_atom_positions.shape[:-3], 1, 37, 3]
    )
    prev_all_atom_positions = torch.cat(
        [pad, all_atom_positions[..., :-1, :, :]], dim=-3
    )

    pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
    prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)

    # Four atoms per backbone torsion, taken by atom37 index.
    # NOTE(review): assumes atom37 indices 0..4 are N, CA, C, CB, O —
    # confirm against rc.atom_types.
    pre_omega_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
        dim=-2,
    )
    phi_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
        dim=-2,
    )
    psi_atom_pos = torch.cat(
        [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
        dim=-2,
    )

    # A backbone torsion is valid only if all four of its atoms are present.
    pre_omega_mask = torch.prod(
        prev_all_atom_mask[..., 1:3], dim=-1
    ) * torch.prod(all_atom_mask[..., :2], dim=-1)
    phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
        all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
    )
    psi_mask = (
        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
        * all_atom_mask[..., 4]
    )

    # Atom37 indices of the four atoms of each chi angle, per residue.
    chi_atom_indices = torch.as_tensor(
        get_chi_atom_indices(), device=aatype.device
    )

    atom_indices = chi_atom_indices[..., aatype, :, :]
    chis_atom_pos = batched_gather(
        all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
    )

    # Extend the chi mask table with an all-zero row for index 20.
    chi_angles_mask = list(rc.chi_angles_mask)
    chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
    chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)

    chis_mask = chi_angles_mask[aatype, :]

    chi_angle_atoms_mask = batched_gather(
        all_atom_mask,
        atom_indices,
        dim=-1,
        no_batch_dims=len(atom_indices.shape[:-2]),
    )
    # A chi angle is valid only if it is defined for the residue type and
    # all four of its atoms are present.
    chi_angle_atoms_mask = torch.prod(
        chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
    )
    chis_mask = chis_mask * chi_angle_atoms_mask

    # Concatenate in the order pre-omega, phi, psi, then the four chis.
    torsions_atom_pos = torch.cat(
        [
            pre_omega_atom_pos[..., None, :, :],
            phi_atom_pos[..., None, :, :],
            psi_atom_pos[..., None, :, :],
            chis_atom_pos,
        ],
        dim=-3,
    )

    torsion_angles_mask = torch.cat(
        [
            pre_omega_mask[..., None],
            phi_mask[..., None],
            psi_mask[..., None],
            chis_mask,
        ],
        dim=-1,
    )

    # Build a frame from the first three atoms of each torsion (passed in
    # the order atom 1, atom 2, atom 0), then express the fourth atom in it.
    torsion_frames = Rigid.from_3_points(
        torsions_atom_pos[..., 1, :],
        torsions_atom_pos[..., 2, :],
        torsions_atom_pos[..., 0, :],
        eps=1e-8,
    )

    fourth_atom_rel_pos = torsion_frames.invert().apply(
        torsions_atom_pos[..., 3, :]
    )

    # The (sin, cos) pair is taken from components 2 and 1 of the fourth
    # atom's local coordinates, in that order.
    torsion_angles_sin_cos = torch.stack(
        [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
    )

    # Normalize each pair to unit length; the 1e-8 inside the square root
    # avoids a division by zero.
    denom = torch.sqrt(
        torch.sum(
            torch.square(torsion_angles_sin_cos),
            dim=-1,
            dtype=torsion_angles_sin_cos.dtype,
            keepdims=True,
        )
        + 1e-8
    )
    torsion_angles_sin_cos = torsion_angles_sin_cos / denom

    # Negate the third torsion (psi); the index expression broadcasts the
    # 7 per-torsion signs over the batch dimensions and the sin/cos axis.
    torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
        [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
    )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]

    chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
        rc.chi_pi_periodic,
    )[aatype, ...]

    # +1 for the three backbone torsions and for unambiguous chis, -1 for
    # chis flagged in rc.chi_pi_periodic.
    mirror_torsion_angles = torch.cat(
        [
            all_atom_mask.new_ones(*aatype.shape, 3),
            1.0 - 2.0 * chi_is_ambiguous,
        ],
        dim=-1,
    )

    alt_torsion_angles_sin_cos = (
        torsion_angles_sin_cos * mirror_torsion_angles[..., None]
    )

    protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
    protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
    protein[prefix + "torsion_angles_mask"] = torsion_angles_mask

    return protein
def get_backbone_frames(protein):
    """Expose rigid group 0 as the backbone frame and mask.

    Copies group index 0 of ``rigidgroups_gt_frames`` into
    ``backbone_rigid_tensor`` and of ``rigidgroups_gt_exists`` into
    ``backbone_rigid_mask``. Returns the same dict.
    """
    # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
    frames = protein["rigidgroups_gt_frames"]
    protein["backbone_rigid_tensor"] = frames.select(-3, 0)

    exists = protein["rigidgroups_gt_exists"]
    protein["backbone_rigid_mask"] = exists.select(-1, 0)

    return protein
def get_chi_angles(protein):
    """Split the chi-angle part out of the torsion angle features.

    Takes entries 3 onward of ``torsion_angles_sin_cos`` and
    ``torsion_angles_mask``, casts them to the dtype of ``all_atom_mask``
    and stores them as ``chi_angles_sin_cos`` and ``chi_mask``. Returns the
    same dict.
    """
    target_dtype = protein["all_atom_mask"].dtype

    sin_cos = protein["torsion_angles_sin_cos"]
    protein["chi_angles_sin_cos"] = sin_cos[..., 3:, :].to(target_dtype)

    mask = protein["torsion_angles_mask"]
    protein["chi_mask"] = mask[..., 3:].to(target_dtype)

    return protein
def random_crop_to_size(
    protein,
    crop_size,
    shape_schema,
    seed=None,
):
    """Crop randomly to `crop_size`, or keep as is if shorter than that.

    Every feature whose schema in ``shape_schema`` contains ``NUM_RES`` is
    sliced along each ``NUM_RES`` dimension to the same contiguous window.
    Features missing from ``shape_schema``, or without a ``NUM_RES``
    dimension, are left untouched.

    Args:
        protein: Feature dict. Must contain a single-element ``seq_length``
            tensor. If ``use_clamped_fape`` is present and equal to 1, the
            window start is drawn from the full valid range; otherwise the
            upper bound of that range is itself drawn at random first.
        crop_size: Maximum number of residues to keep.
        shape_schema: Maps feature names to a sequence of per-dimension
            placeholders, compared against ``NUM_RES``.
        seed: Optional seed for a dedicated generator, so that repeated
            calls with the same seed crop the same way.

    Returns:
        The same ``protein`` dict, with cropped features and ``seq_length``
        replaced by the number of residues kept.
    """
    device = protein["seq_length"].device

    # We want each ensemble to be cropped the same way
    g = None
    if seed is not None:
        g = torch.Generator(device=device)
        g.manual_seed(seed)

    # Work with plain Python ints so that the bounds handed to
    # torch.randint are never 0-dim tensors.
    seq_length = int(protein["seq_length"])
    num_res_crop_size = min(seq_length, crop_size)

    def _randint(lower, upper):
        """Draw one integer from the closed interval [lower, upper]."""
        return int(torch.randint(
            lower,
            upper + 1,
            (1,),
            device=device,
            generator=g,
        )[0])

    n = seq_length - num_res_crop_size
    if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
        right_anchor = n
    else:
        x = _randint(0, n)
        right_anchor = n - x

    num_res_crop_start = _randint(0, right_anchor)
    num_res_window = slice(
        num_res_crop_start, num_res_crop_start + num_res_crop_size
    )

    # Only existing keys are reassigned, so iterating over items() is safe.
    for k, v in protein.items():
        if k not in shape_schema or (NUM_RES not in shape_schema[k]):
            continue

        # Index with a tuple: indexing with a list of slices is deprecated
        # non-tuple sequence indexing.
        slices = tuple(
            num_res_window if dim_size == NUM_RES else slice(0, dim)
            for dim_size, dim in zip(shape_schema[k], v.shape)
        )
        protein[k] = v[slices]

    protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)

    return protein