FoldMark / protenix /openfold_local /data /data_transforms.py
Zaixi's picture
Add large file
89c0b51
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from protenix.openfold_local.np import residue_constants as rc
from protenix.openfold_local.utils.tensor_utils import batched_gather, tensor_tree_map, tree_map
# Names of the MSA-related features. Presumably these are the feature-dict
# keys that carry an MSA (sequence) dimension -- confirm against the callers,
# which are not visible in this part of the file.
MSA_FEATURE_NAMES = [
    "msa", "deletion_matrix",
    "msa_mask", "msa_row_mask",
    "bert_mask", "true_msa",
]
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Build pseudo-beta coordinates and, optionally, their mask.

    The pseudo-beta atom is the CA atom for glycine residues and the CB atom
    for every other residue type.

    Args:
        aatype: Residue-type tensor; entries equal to
            ``rc.restype_order["G"]`` are treated as glycine.
        all_atom_positions: Coordinates indexed as ``[..., atom, xyz]`` where
            the atom axis follows ``rc.atom_order`` (the xyz axis is assumed
            to have size 3, matching the tiling below).
        all_atom_mask: Per-atom mask indexed as ``[..., atom]``, or ``None``
            when no mask should be produced.

    Returns:
        ``(pseudo_beta, pseudo_beta_mask)`` if ``all_atom_mask`` is not
        ``None``; otherwise just ``pseudo_beta``.
    """
    glycine = torch.eq(aatype, rc.restype_order["G"])
    ca_slot = rc.atom_order["CA"]
    cb_slot = rc.atom_order["CB"]

    # Repeat the glycine flag three times along a new trailing axis so that
    # it lines up element-wise with the selected coordinates.
    glycine_xyz = torch.tile(glycine[..., None], [1] * glycine.dim() + [3])
    ca_xyz = all_atom_positions[..., ca_slot, :]
    cb_xyz = all_atom_positions[..., cb_slot, :]
    pseudo_beta = torch.where(glycine_xyz, ca_xyz, cb_xyz)

    if all_atom_mask is None:
        return pseudo_beta

    ca_present = all_atom_mask[..., ca_slot]
    cb_present = all_atom_mask[..., cb_slot]
    pseudo_beta_mask = torch.where(glycine, ca_present, cb_present)
    return pseudo_beta, pseudo_beta_mask
def make_atom14_masks(protein):
    """Add atom14 <-> atom37 index maps and existence masks to ``protein``.

    Per-residue-type lookup tables are built from ``rc`` (one row per entry
    of ``rc.restypes`` plus a trailing all-zero row for 'UNK') and then
    indexed with ``protein["aatype"]``.

    Args:
        protein: Mapping holding an ``"aatype"`` tensor. It is modified in
            place.

    Returns:
        The same ``protein`` mapping with these keys added:
        ``"atom14_atom_exists"``, ``"residx_atom14_to_atom37"``,
        ``"residx_atom37_to_atom14"`` and ``"atom37_atom_exists"``.
    """
    device = protein["aatype"].device

    atom14_to_atom37_rows = []
    atom37_to_atom14_rows = []
    atom14_mask_rows = []
    for letter in rc.restypes:
        names14 = rc.restype_name_to_atom14_names[rc.restype_1to3[letter]]
        slot_of = {name: slot for slot, name in enumerate(names14)}

        # Empty names mark unused atom14 slots: they map to index 0 and are
        # switched off in the mask.
        atom14_to_atom37_rows.append(
            [rc.atom_order[name] if name else 0 for name in names14]
        )
        # Atom types this residue does not have fall back to slot 0.
        atom37_to_atom14_rows.append(
            [slot_of.get(name, 0) for name in rc.atom_types]
        )
        atom14_mask_rows.append([1.0 if name else 0.0 for name in names14])

    # Trailing dummy row for restype 'UNK'.
    atom14_to_atom37_rows.append([0] * 14)
    atom37_to_atom14_rows.append([0] * 37)
    atom14_mask_rows.append([0.0] * 14)

    def _table(rows, dtype):
        # Materialise a lookup table on the same device as the input.
        return torch.tensor(rows, dtype=dtype, device=device)

    restype_atom14_to_atom37 = _table(atom14_to_atom37_rows, torch.int32)
    restype_atom37_to_atom14 = _table(atom37_to_atom14_rows, torch.int32)
    restype_atom14_mask = _table(atom14_mask_rows, torch.float32)

    # Long indices are used for the table lookups below.
    residue_types = protein["aatype"].to(torch.long)

    # (residx, atom14) --> atom37 mapping and the matching existence mask,
    # i.e. per-residue rows picked from the per-restype tables.
    per_res_atom14_to_atom37 = restype_atom14_to_atom37[residue_types]
    protein["atom14_atom_exists"] = restype_atom14_mask[residue_types]
    protein["residx_atom14_to_atom37"] = per_res_atom14_to_atom37.long()

    # Gather indices for mapping back from atom37 to atom14.
    per_res_atom37_to_atom14 = restype_atom37_to_atom14[residue_types]
    protein["residx_atom37_to_atom14"] = per_res_atom37_to_atom14.long()

    # Existence mask in the atom37 layout; the last row (UNK) stays zero.
    restype_atom37_mask = torch.zeros(
        [21, 37], dtype=torch.float32, device=device
    )
    for row, letter in enumerate(rc.restypes):
        for atom_name in rc.residue_atoms[rc.restype_1to3[letter]]:
            restype_atom37_mask[row, rc.atom_order[atom_name]] = 1
    protein["atom37_atom_exists"] = restype_atom37_mask[residue_types]

    return protein
def make_atom14_masks_np(batch):
    """NumPy front-end for ``make_atom14_masks``.

    ``np.ndarray`` values in ``batch`` are turned into CPU torch tensors via
    ``tree_map``, ``make_atom14_masks`` is applied, and the result is passed
    through ``tensor_tree_map`` with ``np.array`` to convert tensors back.
    (The exact traversal rules are those of the imported tree helpers, which
    are not visible here.)

    Args:
        batch: Feature structure containing NumPy arrays, including
            ``"aatype"`` as required by ``make_atom14_masks``.

    Returns:
        The output of ``make_atom14_masks`` with tensors converted by
        ``np.array``.
    """

    def _to_cpu_tensor(array):
        return torch.tensor(array, device="cpu")

    def _to_numpy(tensor):
        return np.array(tensor)

    torch_batch = tree_map(_to_cpu_tensor, batch, np.ndarray)
    torch_out = make_atom14_masks(torch_batch)
    return tensor_tree_map(_to_numpy, torch_out)
def make_atom14_positions(protein):
    """Construct denser atom positions (14 dimensions instead of 37).

    Reads ``"aatype"``, ``"all_atom_mask"``, ``"all_atom_positions"``,
    ``"atom14_atom_exists"`` and ``"residx_atom14_to_atom37"`` from
    ``protein`` (the last two are the keys written by ``make_atom14_masks``).

    Args:
        protein: Mapping of feature tensors. It is modified in place.

    Returns:
        The same ``protein`` mapping with these keys written:
        ``"atom14_atom_exists"`` (unchanged value), ``"atom14_gt_exists"``,
        ``"atom14_gt_positions"``, ``"atom14_alt_gt_positions"``,
        ``"atom14_alt_gt_exists"`` and ``"atom14_atom_is_ambiguous"``.
    """
    all_atom_mask = protein["all_atom_mask"]
    all_atom_positions = protein["all_atom_positions"]
    residx_atom14_mask = protein["atom14_atom_exists"]
    residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]

    # The residue types are used as an advanced index into lookup tables
    # below. Cast to long first (as make_atom14_masks does) so that e.g.
    # int32 aatype tensors are accepted; this is a no-op for long input.
    aatype = protein["aatype"].to(torch.long)

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        all_atom_mask,
        residx_atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )

    # Gather the ground truth positions, zeroing those without ground truth.
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            all_atom_positions,
            residx_atom14_to_atom37,
            dim=-2,
            no_batch_dims=len(all_atom_positions.shape[:-2]),
        )
    )

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["atom14_gt_exists"] = residx_atom14_gt_mask
    protein["atom14_gt_positions"] = residx_atom14_gt_positions

    # The atom naming is ambiguous for the residue types listed in
    # rc.residue_atom_renaming_swaps; provide alternative ground truth
    # coordinates where the naming is swapped.
    restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms; identity unless overridden.
    all_matrices = {
        res: torch.eye(
            14,
            dtype=all_atom_mask.dtype,
            device=all_atom_mask.device,
        )
        for res in restype_3
    }
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        atom14_names = rc.restype_name_to_atom14_names[resname]
        # Permutation of the 14 slots, kept as plain Python ints so that
        # building the matrix needs no per-element tensor indexing.
        correspondences = list(range(14))
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = atom14_names.index(source_atom_swap)
            target_index = atom14_names.index(target_atom_swap)
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        renaming_matrix = all_atom_mask.new_zeros((14, 14))
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix

    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[aatype]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = torch.einsum(
        "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
    )
    protein["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = torch.einsum(
        "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
    )
    protein["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask. shape: (21, 14).
    restype_atom14_is_ambiguous = all_atom_mask.new_zeros((21, 14))
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[rc.restype_3to1[resname]]
        atom14_names = rc.restype_name_to_atom14_names[resname]
        for atom_name1, atom_name2 in swap.items():
            atom_idx1 = atom14_names.index(atom_name1)
            atom_idx2 = atom14_names.index(atom_name2)
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[aatype]

    return protein