# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
from protenix.openfold_local.np import residue_constants as rc
from protenix.openfold_local.utils.tensor_utils import (
    batched_gather,
    tensor_tree_map,
    tree_map,
)

MSA_FEATURE_NAMES = [
    "msa",
    "deletion_matrix",
    "msa_mask",
    "msa_row_mask",
    "bert_mask",
    "true_msa",
]


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features.

    For every residue the CB atom is selected, except where ``aatype`` equals
    ``rc.restype_order["G"]`` (glycine), where the CA atom is used instead.

    Args:
        aatype: Residue type indices, compared against the glycine index.
        all_atom_positions: Coordinates indexed as ``[..., atom, :]`` with
            the ``rc.atom_order`` indices (presumably ``[..., 37, 3]`` --
            confirm against the caller).
        all_atom_mask: Mask indexed as ``[..., atom]``, or ``None``.

    Returns:
        ``pseudo_beta`` alone when ``all_atom_mask`` is ``None``; otherwise
        the tuple ``(pseudo_beta, pseudo_beta_mask)``.
    """
    is_gly = torch.eq(aatype, rc.restype_order["G"])
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]

    # The trailing singleton axis lets torch.where broadcast the per-residue
    # condition over the coordinate axis without materialising a tiled copy.
    pseudo_beta = torch.where(
        is_gly[..., None],
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_mask is None:
        return pseudo_beta

    pseudo_beta_mask = torch.where(
        is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
    )
    return pseudo_beta, pseudo_beta_mask


def make_atom14_masks(protein):
    """Construct denser atom positions (14 dimensions instead of 37).

    Builds per-residue-type lookup tables between the atom14 and atom37
    layouts and indexes them with ``protein["aatype"]`` (cast to
    ``torch.long``).

    Args:
        protein: Dict containing at least the tensor ``"aatype"``. It is
            modified in place.

    Returns:
        The same ``protein`` dict, with these keys added or overwritten:
        ``"atom14_atom_exists"`` (float32), ``"residx_atom14_to_atom37"``
        (long), ``"residx_atom37_to_atom14"`` (long) and
        ``"atom37_atom_exists"`` (float32).
    """
    device = protein["aatype"].device

    restype_atom14_to_atom37 = []
    restype_atom37_to_atom14 = []
    restype_atom14_mask = []
    # 21 rows: the entries of rc.restypes plus a trailing all-zero 'UNK' row.
    # Filled on the host so the device receives one transfer instead of one
    # tiny write per atom.
    restype_atom37_mask = [[0.0] * 37 for _ in range(21)]

    for restype, rt in enumerate(rc.restypes):
        resname = rc.restype_1to3[rt]
        atom_names = rc.restype_name_to_atom14_names[resname]

        # Empty atom14 slots (falsy names) map to atom37 index 0.
        restype_atom14_to_atom37.append(
            [(rc.atom_order[name] if name else 0) for name in atom_names]
        )

        # atom37 entries this residue type does not have map to slot 0.
        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14.append(
            [atom_name_to_idx14.get(name, 0) for name in rc.atom_types]
        )

        restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names])

        for atom_name in rc.residue_atoms[resname]:
            restype_atom37_mask[restype][rc.atom_order[atom_name]] = 1.0

    # Add dummy mapping for restype 'UNK'
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * 37)
    restype_atom14_mask.append([0.0] * 14)

    restype_atom14_to_atom37 = torch.tensor(
        restype_atom14_to_atom37,
        dtype=torch.int32,
        device=device,
    )
    restype_atom37_to_atom14 = torch.tensor(
        restype_atom37_to_atom14,
        dtype=torch.int32,
        device=device,
    )
    restype_atom14_mask = torch.tensor(
        restype_atom14_mask,
        dtype=torch.float32,
        device=device,
    )
    restype_atom37_mask = torch.tensor(
        restype_atom37_mask,
        dtype=torch.float32,
        device=device,
    )

    protein_aatype = protein["aatype"].to(torch.long)

    # create the mapping for (residx, atom14) --> atom37, i.e. an array
    # with shape (num_res, 14) containing the atom37 indices for this protein
    residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
    residx_atom14_mask = restype_atom14_mask[protein_aatype]

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()

    # create the gather indices for mapping back
    residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
    protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()

    # create the corresponding mask
    residx_atom37_mask = restype_atom37_mask[protein_aatype]
    protein["atom37_atom_exists"] = residx_atom37_mask

    return protein


def make_atom14_masks_np(batch):
    """NumPy front end for ``make_atom14_masks``.

    ``np.ndarray`` entries of ``batch`` are turned into CPU tensors via
    ``tree_map``, ``make_atom14_masks`` is applied, and the tensors in the
    result are turned back into NumPy arrays via ``tensor_tree_map``.

    Args:
        batch: Feature structure whose arrays are ``np.ndarray`` objects
            (presumably a dict containing ``"aatype"`` -- confirm against
            the caller).

    Returns:
        The output of ``make_atom14_masks`` with tensors converted to
        ``np.ndarray``.
    """
    batch = tree_map(lambda n: torch.tensor(n, device="cpu"), batch, np.ndarray)
    out = make_atom14_masks(batch)
    out = tensor_tree_map(lambda t: np.array(t), out)
    return out


def make_atom14_positions(protein):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Gathers the ground-truth mask and coordinates into the atom14 layout and
    derives the alternative ground truth obtained by swapping the atoms
    listed in ``rc.residue_atom_renaming_swaps``.

    Args:
        protein: Dict providing ``"atom14_atom_exists"``,
            ``"residx_atom14_to_atom37"``, ``"all_atom_mask"``,
            ``"all_atom_positions"`` and ``"aatype"``. It is modified in
            place. ``"aatype"`` is cast to ``torch.long`` before it is used
            as an index, matching ``make_atom14_masks``.

    Returns:
        The same ``protein`` dict with ``"atom14_atom_exists"``,
        ``"atom14_gt_exists"``, ``"atom14_gt_positions"``,
        ``"atom14_alt_gt_positions"``, ``"atom14_alt_gt_exists"`` and
        ``"atom14_atom_is_ambiguous"`` set.
    """
    residx_atom14_mask = protein["atom14_atom_exists"]
    residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]
    all_atom_mask = protein["all_atom_mask"]
    all_atom_positions = protein["all_atom_positions"]
    table_dtype = all_atom_mask.dtype
    table_device = all_atom_mask.device

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        all_atom_mask,
        residx_atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )

    # Gather the ground truth positions.
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            all_atom_positions,
            residx_atom14_to_atom37,
            dim=-2,
            no_batch_dims=len(all_atom_positions.shape[:-2]),
        )
    )

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["atom14_gt_exists"] = residx_atom14_gt_mask
    protein["atom14_gt_positions"] = residx_atom14_gt_positions

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
    restype_3 += ["UNK"]

    # Both constant tables are assembled on the host as plain Python lists
    # and moved to the target device once, instead of issuing one tensor
    # write per entry.
    # atom14 permutation per residue name; identity unless swaps are listed.
    atom14_perms = {res: list(range(14)) for res in restype_3}
    # Ambiguous atoms mask. shape: (21, 14).
    is_ambiguous = [[0] * 14 for _ in range(21)]

    for resname, swap in rc.residue_atom_renaming_swaps.items():
        atom14_names = rc.restype_name_to_atom14_names[resname]
        perm = list(range(14))
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = atom14_names.index(source_atom_swap)
            target_index = atom14_names.index(target_atom_swap)
            perm[source_index] = target_index
            perm[target_index] = source_index

            restype = rc.restype_order[rc.restype_3to1[resname]]
            is_ambiguous[restype][source_index] = 1
            is_ambiguous[restype][target_index] = 1
        atom14_perms[resname] = perm

    perm_index = torch.tensor(
        [atom14_perms[res] for res in restype_3],
        dtype=torch.long,
        device=table_device,
    )
    # Matrices for renaming ambiguous atoms. Row i of eye[perm] is the
    # one-hot vector of perm[i], i.e. matrix[i, perm[i]] == 1.
    renaming_matrices = torch.eye(14, dtype=table_dtype, device=table_device)[
        perm_index
    ]
    restype_atom14_is_ambiguous = torch.tensor(
        is_ambiguous, dtype=table_dtype, device=table_device
    )

    # Cast once so the lookups below accept any integer dtype.
    aatype = protein["aatype"].to(torch.long)

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[aatype]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = torch.einsum(
        "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
    )
    protein["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = torch.einsum(
        "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
    )
    protein["atom14_alt_gt_exists"] = alternative_gt_mask

    # From this create an ambiguous_mask for the given sequence.
    protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[aatype]

    return protein