|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
from protenix.openfold_local.np import residue_constants as rc |
|
from protenix.openfold_local.utils.tensor_utils import batched_gather, tensor_tree_map, tree_map |
|
|
|
# Keys of the MSA-derived entries in a feature dict. Presumably consulted
# elsewhere to pick out the per-sequence (MSA-shaped) features — usage is
# not visible in this module section.
MSA_FEATURE_NAMES = [
    "msa",
    "deletion_matrix",
    "msa_mask",
    "msa_row_mask",
    "bert_mask",
    "true_msa",
]
|
|
|
|
|
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features.

    Each residue's pseudo-beta position is its CB atom position. Residues whose
    type equals ``rc.restype_order["G"]`` (glycine) use the CA position instead.

    Args:
        aatype: residue-type tensor, compared element-wise against the glycine
            index.
        all_atom_positions: atom coordinates indexed as ``[..., atom, coord]``,
            where ``atom`` uses ``rc.atom_order`` indices.
        all_atom_mask: per-atom mask indexed as ``[..., atom]``, or ``None``.

    Returns:
        ``(pseudo_beta, pseudo_beta_mask)`` when ``all_atom_mask`` is given,
        otherwise only ``pseudo_beta``.
    """
    ca_idx, cb_idx = rc.atom_order["CA"], rc.atom_order["CB"]
    gly_mask = aatype == rc.restype_order["G"]

    # A trailing singleton axis lets the per-residue flag broadcast over the
    # coordinate axis of the selected atom positions.
    positions = torch.where(
        gly_mask.unsqueeze(-1),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_mask is None:
        return positions

    mask = torch.where(
        gly_mask,
        all_atom_mask[..., ca_idx],
        all_atom_mask[..., cb_idx],
    )
    return positions, mask
|
|
|
|
|
def make_atom14_masks(protein):
    """Construct denser atom positions (14 dimensions instead of 37).

    Builds per-residue-type lookup tables between the dense 14-atom layout and
    the 37-atom layout (``rc.atom_types``), then gathers one row per residue
    using ``protein["aatype"]``.

    Each table has one row per entry of ``rc.restypes`` plus a trailing
    all-zero row. Index ``len(rc.restypes)`` therefore maps to "no atoms".

    Args:
        protein: dict of tensors that must contain ``"aatype"`` (residue-type
            indices). Its device determines the device of every output.

    Returns:
        The same dict, updated in place with:
            ``"atom14_atom_exists"`` (float32): 1.0 where the atom14 slot holds
                a real atom for that residue type.
            ``"residx_atom14_to_atom37"`` (int64): atom37 index of each atom14
                slot (0 for empty slots).
            ``"residx_atom37_to_atom14"`` (int64): atom14 index of each atom37
                slot (0 for atoms absent from the residue).
            ``"atom37_atom_exists"`` (float32): 1.0 where the atom37 slot is
                listed in ``rc.residue_atoms`` for that residue type.
    """
    device = protein["aatype"].device

    restype_atom14_to_atom37 = []
    restype_atom37_to_atom14 = []
    restype_atom14_mask = []
    restype_atom37_mask = []

    for rt in rc.restypes:
        resname = rc.restype_1to3[rt]
        atom_names = rc.restype_name_to_atom14_names[resname]

        # Empty names pad the 14 slots: they point at atom37 index 0 and are
        # masked out by restype_atom14_mask.
        restype_atom14_to_atom37.append(
            [(rc.atom_order[name] if name else 0) for name in atom_names]
        )
        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14.append(
            [atom_name_to_idx14.get(name, 0) for name in rc.atom_types]
        )
        restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names])

        present37 = {rc.atom_order[name] for name in rc.residue_atoms[resname]}
        restype_atom37_mask.append(
            [(1.0 if idx in present37 else 0.0) for idx in range(37)]
        )

    # Trailing all-zero row for the unknown residue type.
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * 37)
    restype_atom14_mask.append([0.0] * 14)
    restype_atom37_mask.append([0.0] * 37)

    # Materialise each table with a single host->device copy rather than
    # element-wise writes into a tensor that already lives on `device`.
    # Index tables are built as int64 directly so gathered results need no cast.
    restype_atom14_to_atom37 = torch.tensor(
        restype_atom14_to_atom37, dtype=torch.long, device=device
    )
    restype_atom37_to_atom14 = torch.tensor(
        restype_atom37_to_atom14, dtype=torch.long, device=device
    )
    restype_atom14_mask = torch.tensor(
        restype_atom14_mask, dtype=torch.float32, device=device
    )
    restype_atom37_mask = torch.tensor(
        restype_atom37_mask, dtype=torch.float32, device=device
    )

    protein_aatype = protein["aatype"].to(torch.long)

    protein["atom14_atom_exists"] = restype_atom14_mask[protein_aatype]
    protein["residx_atom14_to_atom37"] = restype_atom14_to_atom37[protein_aatype]
    protein["residx_atom37_to_atom14"] = restype_atom37_to_atom14[protein_aatype]
    protein["atom37_atom_exists"] = restype_atom37_mask[protein_aatype]

    return protein
|
|
|
|
|
def make_atom14_masks_np(batch):
    """NumPy front-end for ``make_atom14_masks``.

    ``np.ndarray`` leaves of ``batch`` are converted to CPU tensors via
    ``tree_map``. ``make_atom14_masks`` is then applied, and the tensor leaves of
    its result are converted back to NumPy arrays via ``tensor_tree_map``.
    (Handling of non-array leaves depends on those project helpers — not
    visible here.)
    """

    def _to_cpu_tensor(array):
        return torch.tensor(array, device="cpu")

    def _to_numpy(tensor):
        return np.array(tensor)

    torch_batch = tree_map(_to_cpu_tensor, batch, np.ndarray)
    return tensor_tree_map(_to_numpy, make_atom14_masks(torch_batch))
|
|
|
|
|
def _atom14_renaming_matrices(dtype, device):
    """Return a (len(rc.restypes) + 1, 14, 14) stack of atom14 swap matrices.

    Entry ``[r, i, j]`` is 1 when atom14 slot ``i`` of residue type ``r`` maps
    to slot ``j`` under ``rc.residue_atom_renaming_swaps``. Residue types with
    no swaps, and the trailing "UNK" entry, get the identity.
    """
    restype_3 = [rc.restype_1to3[res] for res in rc.restypes] + ["UNK"]
    identity = torch.eye(14, dtype=dtype, device=device)
    matrices = []
    for resname in restype_3:
        correspondences = list(range(14))
        swaps = rc.residue_atom_renaming_swaps.get(resname, {})
        for source_atom_swap, target_atom_swap in swaps.items():
            atom14_names = rc.restype_name_to_atom14_names[resname]
            source_index = atom14_names.index(source_atom_swap)
            target_index = atom14_names.index(target_atom_swap)
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        # Gathering identity rows by the permutation yields row i one-hot at
        # correspondences[i].
        matrices.append(identity[correspondences])
    return torch.stack(matrices)


def _atom14_ambiguity_table(dtype, device):
    """Return a (21, 14) table with 1 at atom14 slots that take part in a swap."""
    rows = [[0.0] * 14 for _ in range(21)]
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[rc.restype_3to1[resname]]
        atom14_names = rc.restype_name_to_atom14_names[resname]
        for atom_name1, atom_name2 in swap.items():
            rows[restype][atom14_names.index(atom_name1)] = 1.0
            rows[restype][atom14_names.index(atom_name2)] = 1.0
    # One host->device copy instead of element-wise writes on `device`.
    return torch.tensor(rows, dtype=dtype, device=device)


def make_atom14_positions(protein):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Reads ``"atom14_atom_exists"`` and ``"residx_atom14_to_atom37"`` (as written
    by ``make_atom14_masks``), plus ``"all_atom_mask"``, ``"all_atom_positions"``
    and ``"aatype"``.

    Updates ``protein`` in place and returns it with:
        ``"atom14_atom_exists"``: the input atom14 existence mask (re-stored).
        ``"atom14_gt_exists"``: atom14 mask restricted to atoms present in
            ``all_atom_mask``.
        ``"atom14_gt_positions"``: atom14 coordinates gathered from atom37,
            zeroed where ``atom14_gt_exists`` is 0.
        ``"atom14_alt_gt_positions"`` / ``"atom14_alt_gt_exists"``: the same
            data with the atom pairs in ``rc.residue_atom_renaming_swaps``
            exchanged.
        ``"atom14_atom_is_ambiguous"``: 1 for atom14 slots that take part in
            such a swap.
    """
    residx_atom14_mask = protein["atom14_atom_exists"]
    residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]
    all_atom_mask = protein["all_atom_mask"]
    all_atom_positions = protein["all_atom_positions"]

    # Cast once, as make_atom14_masks does: some PyTorch versions reject
    # int32 index tensors in advanced indexing.
    aatype = protein["aatype"].to(torch.long)

    # Atom14 mask restricted to atoms actually present in the atom37 data.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        all_atom_mask,
        residx_atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )

    # Gather atom37 coordinates into atom14 order and zero out missing atoms.
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            all_atom_positions,
            residx_atom14_to_atom37,
            dim=-2,
            no_batch_dims=len(all_atom_positions.shape[:-2]),
        )
    )

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["atom14_gt_exists"] = residx_atom14_gt_mask
    protein["atom14_gt_positions"] = residx_atom14_gt_positions

    # Per-residue swap matrix, selected by residue type.
    renaming_transform = _atom14_renaming_matrices(
        all_atom_mask.dtype, all_atom_mask.device
    )[aatype]

    # out[..., r, b, :] = sum_a pos[..., r, a, :] * T[..., r, a, b]: moves each
    # atom a to its swap partner b.
    protein["atom14_alt_gt_positions"] = torch.einsum(
        "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
    )
    protein["atom14_alt_gt_exists"] = torch.einsum(
        "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
    )

    protein["atom14_atom_is_ambiguous"] = _atom14_ambiguity_table(
        all_atom_mask.dtype, all_atom_mask.device
    )[aatype]

    return protein
|
|