# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from functools import reduce, wraps
from operator import add

import numpy as np
import torch

from dockformerpp.config import NUM_RES
from dockformerpp.utils import residue_constants as rc
from dockformerpp.utils.residue_constants import restypes
from dockformerpp.utils.rigid_utils import Rotation, Rigid
from dockformerpp.utils.geometry.rigid_matrix_vector import Rigid3Array
from dockformerpp.utils.geometry.rotation_matrix import Rot3Array
from dockformerpp.utils.geometry.vector import Vec3Array
from dockformerpp.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
    batched_gather,
)


def cast_to_64bit_ints(protein):
    """Cast every int32 tensor in `protein` to int64 (in place).

    Args:
        protein: Feature dict whose values expose a `.dtype` attribute.

    Returns:
        The same dict object, with int32 values replaced by int64 copies.
    """
    # We keep all ints as int64
    for k, v in protein.items():
        if v.dtype == torch.int32:
            protein[k] = v.type(torch.int64)

    return protein


def make_one_hot(x, num_classes):
    """Return a one-hot encoding of `x` along a new trailing dimension.

    Args:
        x: Integer index tensor. `scatter_` requires int64 indices, and every
            value must lie in `[0, num_classes)`.
        num_classes: Size of the new trailing dimension.

    Returns:
        Tensor of shape `(*x.shape, num_classes)` in torch's default dtype,
        on the same device as `x`.
    """
    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
    x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
    return x_one_hot


def curry1(f):
    """Supply all arguments but the first.

    The decorated function is called with every argument except the first
    and returns a one-argument callable that receives the first argument.
    """
    @wraps(f)
    def fc(*args, **kwargs):
        return lambda x: f(x, *args, **kwargs)

    return fc


def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features.

    "aatype" is replaced by its argmax over the last dimension. A trailing
    size-1 dimension is squeezed from a fixed set of per-chain features, and
    "seq_length" is then reduced to its first element.
    """
    protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
    for k in [
        "domain_name",
        "seq_length",
        "sequence",
        "resolution",
        "residue_index",
    ]:
        if k in protein:
            final_dim = protein[k].shape[-1]
            if isinstance(final_dim, int) and final_dim == 1:
                # Values may be torch tensors or numpy arrays.
                if torch.is_tensor(protein[k]):
                    protein[k] = torch.squeeze(protein[k], dim=-1)
                else:
                    protein[k] = np.squeeze(protein[k], axis=-1)

    for k in ["seq_length"]:
        if k in protein:
            protein[k] = protein[k][0]

    return protein


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features.

    Selects the CA atom for glycine residues and the CB atom for all others.

    Args:
        aatype: Residue type indices.
        all_atom_positions: Atom positions indexed by atom type on dim -2,
            with a trailing coordinate dimension of size 3.
        all_atom_mask: Atom mask indexed by atom type on the last dimension,
            or None to skip the mask.

    Returns:
        `(pseudo_beta, pseudo_beta_mask)` when a mask is given, otherwise
        `pseudo_beta` alone.
    """
    is_gly = torch.eq(aatype, rc.restype_order["G"])
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    pseudo_beta = torch.where(
        # Broadcast the glycine flag over the 3 coordinates.
        torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_mask is not None:
        pseudo_beta_mask = torch.where(
            is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
        )
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta


@curry1
def make_pseudo_beta(protein):
    """Create pseudo-beta (alpha for glycine) position and mask."""
    (protein["pseudo_beta"], protein["pseudo_beta_mask"]) = pseudo_beta_fn(
        protein["aatype"],
        protein["all_atom_positions"],
        protein["all_atom_mask"],
    )
    return protein


@curry1
def make_target_feat(protein):
    """Create and concatenate protein features.

    Stores a 20-class one-hot encoding of "aatype" under
    "protein_target_feat".
    """
    # NOTE(review): only 20 classes are encoded, so an "aatype" value of 20
    # or above would be out of range for the scatter in make_one_hot.
    # Confirm upstream guarantees this cannot happen.
    aatype_1hot = make_one_hot(protein["aatype"], 20)

    protein["protein_target_feat"] = aatype_1hot

    return protein


@curry1
def select_feat(protein, feature_list):
    """Return a new dict holding only the features named in `feature_list`."""
    return {k: v for k, v in protein.items() if k in feature_list}


def get_restypes(device):
    """Build the per-residue-type atom14 <-> atom37 lookup tables.

    Each table has one row per entry of `rc.restypes` plus a trailing
    all-zero row for 'UNK'.

    Args:
        device: Device on which the tables are created.

    Returns:
        Tuple `(restype_atom14_to_atom37, restype_atom37_to_atom14,
        restype_atom14_mask)`: two int32 index tables with 14 and 37 columns
        respectively, and a float32 mask with 14 columns.
    """
    restype_atom14_to_atom37 = []
    restype_atom37_to_atom14 = []
    restype_atom14_mask = []

    for rt in rc.restypes:
        atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
        # Empty names mark unused atom14 slots; they map to index 0.
        restype_atom14_to_atom37.append(
            [(rc.atom_order[name] if name else 0) for name in atom_names]
        )
        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14.append(
            [
                (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
                for name in rc.atom_types
            ]
        )

        restype_atom14_mask.append(
            [(1.0 if name else 0.0) for name in atom_names]
        )

    # Add dummy mapping for restype 'UNK'
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * 37)
    restype_atom14_mask.append([0.0] * 14)

    restype_atom14_to_atom37 = torch.tensor(
        restype_atom14_to_atom37,
        dtype=torch.int32,
        device=device,
    )
    restype_atom37_to_atom14 = torch.tensor(
        restype_atom37_to_atom14,
        dtype=torch.int32,
        device=device,
    )
    restype_atom14_mask = torch.tensor(
        restype_atom14_mask,
        dtype=torch.float32,
        device=device,
    )

    return restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask


def get_restype_atom37_mask(device):
    """Build the per-residue-type atom37 existence mask.

    Args:
        device: Device on which the mask is created.

    Returns:
        float32 tensor of shape `(len(restypes) + 1, 37)`. An entry is 1
        where the atom type belongs to the residue type; the last row (the
        extra residue type) stays all zero.
    """
    # create the corresponding mask
    restype_atom37_mask = torch.zeros(
        [len(restypes) + 1, 37], dtype=torch.float32, device=device
    )
    for restype, restype_letter in enumerate(rc.restypes):
        restype_name = rc.restype_1to3[restype_letter]
        atom_names = rc.residue_atoms[restype_name]
        for atom_name in atom_names:
            atom_type = rc.atom_order[atom_name]
            restype_atom37_mask[restype, atom_type] = 1
    return restype_atom37_mask


def make_atom14_masks(protein):
    """Construct denser atom positions (14 dimensions instead of 37).

    Reads "aatype" and writes "atom14_atom_exists",
    "residx_atom14_to_atom37", "residx_atom37_to_atom14" and
    "atom37_atom_exists" into `protein`, which is also returned.
    """
    restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask = get_restypes(protein["aatype"].device)

    protein_aatype = protein['aatype'].to(torch.long)

    # create the mapping for (residx, atom14) --> atom37, i.e. an array
    # with shape (num_res, 14) containing the atom37 indices for this protein
    residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
    residx_atom14_mask = restype_atom14_mask[protein_aatype]

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()

    # create the gather indices for mapping back
    residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
    protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()

    restype_atom37_mask = get_restype_atom37_mask(protein["aatype"].device)

    residx_atom37_mask = restype_atom37_mask[protein_aatype]
    protein["atom37_atom_exists"] = residx_atom37_mask

    return protein


def make_atom14_positions(protein):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Requires "atom14_atom_exists" and "residx_atom14_to_atom37" (both written
    by `make_atom14_masks`) together with "aatype", "all_atom_mask" and
    "all_atom_positions". Writes "atom14_gt_exists", "atom14_gt_positions",
    "atom14_alt_gt_positions", "atom14_alt_gt_exists" and
    "atom14_atom_is_ambiguous" into `protein`, which is also returned.
    """
    residx_atom14_mask = protein["atom14_atom_exists"]
    residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        protein["all_atom_mask"],
        residx_atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
    )

    # Gather the ground truth positions.
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            protein["all_atom_positions"],
            residx_atom14_to_atom37,
            dim=-2,
            no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
        )
    )

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["atom14_gt_exists"] = residx_atom14_gt_mask
    protein["atom14_gt_positions"] = residx_atom14_gt_positions

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms.
    all_matrices = {
        res: torch.eye(
            14,
            dtype=protein["all_atom_mask"].dtype,
            device=protein["all_atom_mask"].device,
        )
        for res in restype_3
    }
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        # Start from the identity permutation and swap the ambiguous pairs.
        correspondences = torch.arange(
            14, device=protein["all_atom_mask"].device
        )
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = rc.restype_name_to_atom14_names[resname].index(
                source_atom_swap
            )
            target_index = rc.restype_name_to_atom14_names[resname].index(
                target_atom_swap
            )
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        # Rebuilt from scratch for every residue type so that matrices are
        # never shared between entries of all_matrices.
        renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix

    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[protein["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = torch.einsum(
        "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
    )
    protein["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = torch.einsum(
        "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
    )
    protein["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask.  shape: (21, 14).
    restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        for atom_name1, atom_name2 in swap.items():
            restype = rc.restype_order[rc.restype_3to1[resname]]
            atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
                atom_name1
            )
            atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
                atom_name2
            )
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
        protein["aatype"]
    ]

    return protein


def atom37_to_frames(protein, eps=1e-8):
    """Compute ground-truth rigid-group frames from atom37 coordinates.

    Eight rigid groups are considered per residue. Group 0 is built from the
    atoms (C, CA, N), group 3 from (CA, C, O), and groups 4-7 from the last
    three atoms of each chi angle defined for the residue type. Groups 1 and
    2 have no base atoms and are never marked as existing.

    Args:
        protein: Feature dict containing "aatype", "all_atom_positions" and
            "all_atom_mask".
        eps: Numerical-stability constant forwarded to
            `Rigid.from_3_points`.

    Returns:
        The same dict, updated with "rigidgroups_gt_frames",
        "rigidgroups_gt_exists", "rigidgroups_group_exists",
        "rigidgroups_group_is_ambiguous" and "rigidgroups_alt_gt_frames".
    """
    aatype = protein["aatype"]
    all_atom_positions = protein["all_atom_positions"]
    all_atom_mask = protein["all_atom_mask"]

    batch_dims = len(aatype.shape[:-1])

    # Names of the three atoms that define each rigid group, per residue type.
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
    restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
    restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]

    for restype, restype_letter in enumerate(rc.restypes):
        resname = rc.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if rc.chi_angles_mask[restype][chi_idx]:
                names = rc.chi_angles_atoms[resname][chi_idx]
                restype_rigidgroup_base_atom_names[
                    restype, chi_idx + 4, :
                ] = names[1:]

    restype_rigidgroup_mask = all_atom_mask.new_zeros(
        (*aatype.shape[:-1], 21, 8),
    )
    restype_rigidgroup_mask[..., 0] = 1
    restype_rigidgroup_mask[..., 3] = 1
    restype_rigidgroup_mask[..., :len(restypes), 4:] = all_atom_mask.new_tensor(
        rc.chi_angles_mask
    )

    # Translate atom names to atom37 indices; the empty name maps to index 0.
    lookuptable = rc.atom_order.copy()
    lookuptable[""] = 0
    lookup = np.vectorize(lambda x: lookuptable[x])
    restype_rigidgroup_base_atom37_idx = lookup(
        restype_rigidgroup_base_atom_names,
    )
    restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
        restype_rigidgroup_base_atom37_idx,
    )
    restype_rigidgroup_base_atom37_idx = (
        restype_rigidgroup_base_atom37_idx.view(
            *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
        )
    )

    residx_rigidgroup_base_atom37_idx = batched_gather(
        restype_rigidgroup_base_atom37_idx,
        aatype,
        dim=-3,
        no_batch_dims=batch_dims,
    )

    base_atom_pos = batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        dim=-2,
        no_batch_dims=len(all_atom_positions.shape[:-2]),
    )

    gt_frames = Rigid.from_3_points(
        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
        p_xy_plane=base_atom_pos[..., 2, :],
        eps=eps,
    )

    group_exists = batched_gather(
        restype_rigidgroup_mask,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    gt_atoms_exist = batched_gather(
        all_atom_mask,
        residx_rigidgroup_base_atom37_idx,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )
    # A frame exists only if all three of its base atoms are present.
    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

    # Group 0 gets its x and z axes flipped; all other groups are unchanged.
    rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1
    rots = Rotation(rot_mats=rots)

    gt_frames = gt_frames.compose(Rigid(rots, None))

    restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
        *((1,) * batch_dims), 21, 8
    )
    restype_rigidgroup_rots = torch.eye(
        3, dtype=all_atom_mask.dtype, device=aatype.device
    )
    restype_rigidgroup_rots = torch.tile(
        restype_rigidgroup_rots,
        (*((1,) * batch_dims), 21, 8, 1, 1),
    )

    for resname in rc.residue_atom_renaming_swaps:
        restype = rc.restype_order[rc.restype_3to1[resname]]
        # The last defined chi group of the residue is the ambiguous one.
        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1

    residx_rigidgroup_is_ambiguous = batched_gather(
        restype_rigidgroup_is_ambiguous,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    residx_rigidgroup_ambiguity_rot = batched_gather(
        restype_rigidgroup_rots,
        aatype,
        dim=-4,
        no_batch_dims=batch_dims,
    )

    residx_rigidgroup_ambiguity_rot = Rotation(
        rot_mats=residx_rigidgroup_ambiguity_rot
    )
    alt_gt_frames = gt_frames.compose(
        Rigid(residx_rigidgroup_ambiguity_rot, None)
    )

    gt_frames_tensor = gt_frames.to_tensor_4x4()
    alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()

    protein["rigidgroups_gt_frames"] = gt_frames_tensor
    protein["rigidgroups_gt_exists"] = gt_exists
    protein["rigidgroups_group_exists"] = group_exists
    protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
    protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor

    return protein


def get_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
      A nested list of shape [residue_types=21, chis=4, atoms=4]. The residue
      types are in the order specified in rc.restypes + unknown residue type
      at the end. For chi angles which are not defined on the residue, the
      positions indices are by default set to 0.
    """
    chi_atom_indices = []
    for residue_name in rc.restypes:
        residue_name = rc.restype_1to3[residue_name]
        residue_chi_angles = rc.chi_angles_atoms[residue_name]
        atom_indices = []
        for chi_angle in residue_chi_angles:
            atom_indices.append([rc.atom_order[atom] for atom in chi_angle])
        for _ in range(4 - len(atom_indices)):
            atom_indices.append(
                [0, 0, 0, 0]
            )  # For chi angles not defined on the AA.
        chi_atom_indices.append(atom_indices)

    chi_atom_indices.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.

    return chi_atom_indices


@curry1
def atom37_to_torsion_angles(
    protein,
    prefix="",
):
    """
    Convert coordinates to torsion angles.

    This function is extremely sensitive to floating point imprecisions
    and should be run with double precision whenever possible.

    Args:
        Dict containing:
            * (prefix)aatype:
                [*, N_res] residue indices
            * (prefix)all_atom_positions:
                [*, N_res, 37, 3] atom positions (in atom37
                format)
            * (prefix)all_atom_mask:
                [*, N_res, 37] atom position mask
    Returns:
        The same dictionary updated with the following features:

        "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Torsion angles
        "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Alternate torsion angles (accounting for 180-degree symmetry)
        "(prefix)torsion_angles_mask" ([*, N_res, 7])
            Torsion angles mask
    """
    aatype = protein[prefix + "aatype"]
    all_atom_positions = protein[prefix + "all_atom_positions"]
    all_atom_mask = protein[prefix + "all_atom_mask"]

    # Indices above 20 are folded onto the unknown residue type.
    aatype = torch.clamp(aatype, max=20)

    # Shift positions and mask by one residue (zero-padded at the start) to
    # get each residue's predecessor.
    pad = all_atom_positions.new_zeros(
        [*all_atom_positions.shape[:-3], 1, 37, 3]
    )
    prev_all_atom_positions = torch.cat(
        [pad, all_atom_positions[..., :-1, :, :]], dim=-3
    )

    pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
    prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)

    pre_omega_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
        dim=-2,
    )
    phi_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
        dim=-2,
    )
    psi_atom_pos = torch.cat(
        [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
        dim=-2,
    )

    pre_omega_mask = torch.prod(
        prev_all_atom_mask[..., 1:3], dim=-1
    ) * torch.prod(all_atom_mask[..., :2], dim=-1)
    phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
        all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
    )
    psi_mask = (
        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
        * all_atom_mask[..., 4]
    )

    chi_atom_indices = torch.as_tensor(
        get_chi_atom_indices(), device=aatype.device
    )

    atom_indices = chi_atom_indices[..., aatype, :, :]
    chis_atom_pos = batched_gather(
        all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
    )

    # Append an all-zero row for the unknown residue type.
    chi_angles_mask = list(rc.chi_angles_mask)
    chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
    chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)

    chis_mask = chi_angles_mask[aatype, :]

    chi_angle_atoms_mask = batched_gather(
        all_atom_mask,
        atom_indices,
        dim=-1,
        no_batch_dims=len(atom_indices.shape[:-2]),
    )
    # A chi angle is valid only if all four of its atoms are present.
    chi_angle_atoms_mask = torch.prod(
        chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
    )
    chis_mask = chis_mask * chi_angle_atoms_mask

    torsions_atom_pos = torch.cat(
        [
            pre_omega_atom_pos[..., None, :, :],
            phi_atom_pos[..., None, :, :],
            psi_atom_pos[..., None, :, :],
            chis_atom_pos,
        ],
        dim=-3,
    )

    torsion_angles_mask = torch.cat(
        [
            pre_omega_mask[..., None],
            phi_mask[..., None],
            psi_mask[..., None],
            chis_mask,
        ],
        dim=-1,
    )

    torsion_frames = Rigid.from_3_points(
        torsions_atom_pos[..., 1, :],
        torsions_atom_pos[..., 2, :],
        torsions_atom_pos[..., 0, :],
        eps=1e-8,
    )

    fourth_atom_rel_pos = torsion_frames.invert().apply(
        torsions_atom_pos[..., 3, :]
    )

    torsion_angles_sin_cos = torch.stack(
        [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
    )

    # Normalise each (sin, cos) pair; the epsilon guards against a zero norm.
    denom = torch.sqrt(
        torch.sum(
            torch.square(torsion_angles_sin_cos),
            dim=-1,
            dtype=torsion_angles_sin_cos.dtype,
            keepdims=True,
        )
        + 1e-8
    )
    torsion_angles_sin_cos = torsion_angles_sin_cos / denom

    # Flip the sign of the third torsion (psi) only.
    torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
        [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
    )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]

    chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
        rc.chi_pi_periodic,
    )[aatype, ...]

    # +1 for the three backbone torsions, -1 for pi-periodic chi angles.
    mirror_torsion_angles = torch.cat(
        [
            all_atom_mask.new_ones(*aatype.shape, 3),
            1.0 - 2.0 * chi_is_ambiguous,
        ],
        dim=-1,
    )

    alt_torsion_angles_sin_cos = (
        torsion_angles_sin_cos * mirror_torsion_angles[..., None]
    )

    protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
    protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
    protein[prefix + "torsion_angles_mask"] = torsion_angles_mask

    return protein


def get_backbone_frames(protein):
    """Expose rigid group 0 as the backbone frame.

    Copies group 0 of "rigidgroups_gt_frames" and "rigidgroups_gt_exists"
    into "backbone_rigid_tensor" and "backbone_rigid_mask".
    """
    # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
    protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
        ..., 0, :, :
    ]
    protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]

    return protein


def get_chi_angles(protein):
    """Extract the chi-angle entries from the torsion angle features.

    Takes index 3 onwards of "torsion_angles_sin_cos" and
    "torsion_angles_mask", casts them to the dtype of "all_atom_mask", and
    stores them as "chi_angles_sin_cos" and "chi_mask".
    """
    dtype = protein["all_atom_mask"].dtype
    protein["chi_angles_sin_cos"] = (
        protein["torsion_angles_sin_cos"][..., 3:, :]
    ).to(dtype)
    protein["chi_mask"] = protein["torsion_angles_mask"][..., 3:].to(dtype)

    return protein


@curry1
def random_crop_to_size(
    protein,
    crop_size,
    shape_schema,
    seed=None,
):
    """Crop randomly to `crop_size`, or keep as is if shorter than that.

    Args:
        protein: Feature dict; must contain "seq_length".
        crop_size: Maximum number of residues to keep.
        shape_schema: Maps feature names to a sequence of per-dimension
            markers. Dimensions marked with `NUM_RES` are cropped; features
            that are absent from the schema or have no `NUM_RES` dimension
            are left untouched.
        seed: Optional seed for the crop-start draw, so that repeated calls
            crop identically.

    Returns:
        The same dict with the residue dimensions cropped and "seq_length"
        set to the cropped length.
    """
    # We want each ensemble to be cropped the same way
    g = None
    if seed is not None:
        g = torch.Generator(device=protein["seq_length"].device)
        g.manual_seed(seed)

    # Convert once to a Python int so that all offsets below are plain ints.
    seq_length = int(protein["seq_length"])
    num_res_crop_size = min(seq_length, crop_size)

    def _randint(lower, upper):
        # Draw an integer from the inclusive range [lower, upper].
        return int(torch.randint(
            lower,
            upper + 1,
            (1,),
            device=protein["seq_length"].device,
            generator=g,
        )[0])

    n = seq_length - num_res_crop_size
    if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
        right_anchor = n
    else:
        x = _randint(0, n)
        right_anchor = n - x

    num_res_crop_start = _randint(0, right_anchor)

    for k, v in protein.items():
        if k not in shape_schema or (NUM_RES not in shape_schema[k]):
            continue

        slices = []
        for dim_size, dim in zip(shape_schema[k], v.shape):
            is_num_res = dim_size == NUM_RES
            dim_start = num_res_crop_start if is_num_res else 0
            # A separate name keeps the `crop_size` argument intact.
            dim_crop_size = num_res_crop_size if is_num_res else dim
            slices.append(slice(dim_start, dim_start + dim_crop_size))
        # Multi-dimensional indexing needs a tuple; a list of slices is
        # deprecated in PyTorch and an error in NumPy.
        protein[k] = v[tuple(slices)]

    protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)

    return protein