|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
from protenix.metrics.rmsd import rmsd |
|
from protenix.model.utils import expand_at_dim |
|
from protenix.utils.distributed import traverse_and_aggregate |
|
from protenix.utils.logger import get_logger |
|
from protenix.utils.permutation.chain_permutation.utils import ( |
|
apply_transform, |
|
get_optimal_transform, |
|
num_unique_matches, |
|
) |
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
def permute_pred_to_optimize_pocket_aligned_rmsd( |
|
pred_coord: torch.Tensor, |
|
true_coord: torch.Tensor, |
|
true_coord_mask: torch.Tensor, |
|
true_pocket_mask: torch.Tensor, |
|
true_ligand_mask: torch.Tensor, |
|
atom_entity_id: torch.Tensor, |
|
atom_asym_id: torch.Tensor, |
|
mol_atom_index: torch.Tensor, |
|
use_center_rmsd: bool = False, |
|
): |
|
""" |
|
|
|
Returns: |
|
permute_pred_indices (list[torch.Tensor]): A list of LongTensor. |
|
The list contains N_sample elements. |
|
Each elements is a LongTensor of shape = [N_atom]. |
|
permuted_aligned_pred_coord (torch.Tensor): permuted and aligned coordinates of pred_coord. |
|
[N_sample, N_atom, 3] |
|
""" |
|
|
|
log_dict = {} |
|
atom_entity_id = atom_entity_id.long() |
|
atom_asym_id = atom_asym_id.long() |
|
mol_atom_index = mol_atom_index.long() |
|
true_coord_mask = true_coord_mask.bool() |
|
true_pocket_mask = true_pocket_mask.bool() |
|
true_ligand_mask = true_ligand_mask.bool() |
|
assert pred_coord.size(-2) == true_coord.size(-2), "Atom numbers are difference." |
|
assert pred_coord.dim() == 3 |
|
|
|
|
|
def _get_entity_and_asym_id(atom_mask): |
|
|
|
masked_asym_id = atom_asym_id[atom_mask] |
|
masked_entity_id = atom_entity_id[atom_mask] |
|
assert (masked_asym_id[0] == masked_asym_id).all() |
|
assert (masked_entity_id[0] == masked_entity_id).all() |
|
return masked_asym_id[0].item(), masked_entity_id[0].item() |
|
|
|
pocket_asym_id, pocket_entity_id = _get_entity_and_asym_id(true_pocket_mask) |
|
ligand_asym_id, ligand_entity_id = _get_entity_and_asym_id(true_ligand_mask) |
|
|
|
candidate_pockets = {} |
|
for i in torch.unique(atom_asym_id[atom_entity_id == pocket_entity_id]): |
|
i = i.item() |
|
pocket_mask = atom_asym_id == i |
|
pocket_mask = pocket_mask * torch.isin( |
|
mol_atom_index, mol_atom_index[true_pocket_mask] |
|
) |
|
assert pocket_mask.sum() == true_pocket_mask.sum() |
|
candidate_pockets[i] = pocket_mask.clone() |
|
candidate_ligands = {} |
|
for j in torch.unique(atom_asym_id[atom_entity_id == ligand_entity_id]): |
|
j = j.item() |
|
lig_mask_j = atom_asym_id == j |
|
if lig_mask_j.sum() != true_ligand_mask.sum(): |
|
logger.warning( |
|
f"The ligand selected by 'mol_id' has {lig_mask_j.sum().item()} atoms." |
|
+ f"The true ligand selected by 'asym_id' has {true_ligand_mask.sum().item()} atoms." |
|
) |
|
lig_mask_j = lig_mask_j * torch.isin( |
|
mol_atom_index, mol_atom_index[true_ligand_mask] |
|
) |
|
assert lig_mask_j.sum() == true_ligand_mask.sum() |
|
candidate_ligands[j] = lig_mask_j |
|
|
|
log_dict["num_sym_pocket"] = len(candidate_pockets) |
|
log_dict["num_sym_ligand"] = len(candidate_ligands) |
|
log_dict["has_sym_chain"] = len(candidate_ligands) + len(candidate_pockets) > 2 |
|
|
|
|
|
|
|
|
|
def _find_protein_ligand_chains_for_one_sample( |
|
coord: torch.Tensor, |
|
): |
|
best_results = {} |
|
unpermuted_results = {} |
|
for poc_asym_id, pocket_mask in candidate_pockets.items(): |
|
|
|
rot, trans = get_optimal_transform( |
|
src_atoms=coord[pocket_mask].clone(), |
|
tgt_atoms=true_coord[true_pocket_mask], |
|
mask=true_coord_mask[true_pocket_mask], |
|
) |
|
|
|
aligned_pred_coord = apply_transform(coord.clone(), rot=rot, trans=trans) |
|
|
|
|
|
ordered_lig_asym_ids = [i for i in candidate_ligands] |
|
orderd_lig_masks = [candidate_ligands[i] for i in ordered_lig_asym_ids] |
|
aligned_lig_coords = torch.stack( |
|
[aligned_pred_coord[m] for m in orderd_lig_masks], dim=0 |
|
) |
|
|
|
if use_center_rmsd: |
|
mask = true_coord_mask[true_ligand_mask].bool() |
|
aligned_lig_center = aligned_lig_coords[:, mask, :].mean( |
|
dim=-2, keepdim=True |
|
) |
|
true_coord_center = true_coord[true_ligand_mask][mask, :].mean( |
|
dim=-2, keepdim=True |
|
) |
|
per_lig_rmsd = rmsd( |
|
aligned_lig_center, |
|
expand_at_dim( |
|
true_coord_center, |
|
dim=0, |
|
n=aligned_lig_coords.size(0), |
|
), |
|
reduce=False, |
|
) |
|
else: |
|
per_lig_rmsd = rmsd( |
|
aligned_lig_coords, |
|
expand_at_dim( |
|
true_coord[true_ligand_mask], |
|
dim=0, |
|
n=aligned_lig_coords.size(0), |
|
), |
|
mask=true_coord_mask[true_ligand_mask], |
|
reduce=False, |
|
) |
|
lig_rmsd, idx = per_lig_rmsd.min(dim=0) |
|
lig_asym_id = ordered_lig_asym_ids[idx] |
|
|
|
if lig_rmsd < best_results.get("rmsd", torch.inf): |
|
best_results = { |
|
"rmsd": lig_rmsd, |
|
"pocket_asym_id": poc_asym_id, |
|
"ligand_asym_id": lig_asym_id, |
|
"aligned_pred_coord": aligned_pred_coord, |
|
} |
|
if poc_asym_id == pocket_asym_id: |
|
|
|
i = ordered_lig_asym_ids.index(ligand_asym_id) |
|
unpermuted_lig_rmsd = per_lig_rmsd[i].item() |
|
unpermuted_results = { |
|
"rmsd": unpermuted_lig_rmsd, |
|
"aligned_pred_coord": aligned_pred_coord, |
|
} |
|
|
|
|
|
per_sample_log_dict = { |
|
"is_permuted": best_results["pocket_asym_id"] != pocket_asym_id |
|
or best_results["ligand_asym_id"] != ligand_asym_id, |
|
"is_permuted_pocket": best_results["pocket_asym_id"] != pocket_asym_id, |
|
"is_permuted_ligand": best_results["ligand_asym_id"] != ligand_asym_id, |
|
"algo:no_permute": best_results["pocket_asym_id"] == pocket_asym_id |
|
and best_results["ligand_asym_id"] == ligand_asym_id, |
|
} |
|
improved_rmsd = (unpermuted_results["rmsd"] - best_results["rmsd"]).item() |
|
if improved_rmsd >= 1e-12: |
|
|
|
per_sample_log_dict.update( |
|
{ |
|
"algo:equivalent_permute": False, |
|
"algo:worse_permute": False, |
|
"algo:better_permute": True, |
|
"algo:better_rmsd": improved_rmsd, |
|
} |
|
) |
|
elif improved_rmsd < 0: |
|
|
|
per_sample_log_dict.update( |
|
{ |
|
"algo:equivalent_permute": False, |
|
"algo:worse_permute": True, |
|
"algo:better_permute": False, |
|
"algo:worse_rmsd": -improved_rmsd, |
|
} |
|
) |
|
elif per_sample_log_dict["is_permuted"]: |
|
|
|
per_sample_log_dict.update( |
|
{ |
|
"algo:equivalent_permute": True, |
|
"algo:worse_permute": False, |
|
"algo:better_permute": False, |
|
} |
|
) |
|
|
|
N_atom = aligned_pred_coord.size(-2) |
|
device = aligned_pred_coord.device |
|
atom_indices = torch.arange(N_atom, device=device) |
|
|
|
permute_asym_pair = [ |
|
(best_results["pocket_asym_id"], pocket_asym_id), |
|
(best_results["ligand_asym_id"], ligand_asym_id), |
|
] |
|
for asym_new, asym_old in permute_asym_pair: |
|
if asym_new == asym_old: |
|
continue |
|
|
|
ori_indices = atom_indices[atom_asym_id == asym_old] |
|
new_indices = atom_indices[atom_asym_id == asym_new] |
|
atom_indices[ori_indices.tolist()] = new_indices.clone() |
|
atom_indices[new_indices.tolist()] = ori_indices.clone() |
|
|
|
aligned_pred_coord = best_results.pop("aligned_pred_coord")[atom_indices, :] |
|
per_sample_log_dict["rmsd"] = best_results["rmsd"].item() |
|
|
|
return atom_indices, aligned_pred_coord, per_sample_log_dict |
|
|
|
N_sample = pred_coord.size(0) |
|
permute_pred_indices = [] |
|
permuted_aligned_pred_coord = [] |
|
sample_log_dicts = [] |
|
for i in range(N_sample): |
|
atom_indices, aligned_pred_coord, per_sample_log_dict = ( |
|
_find_protein_ligand_chains_for_one_sample(pred_coord[i]) |
|
) |
|
permute_pred_indices.append(atom_indices) |
|
permuted_aligned_pred_coord.append(aligned_pred_coord) |
|
sample_log_dicts.append(per_sample_log_dict) |
|
|
|
permuted_aligned_pred_coord = torch.stack(permuted_aligned_pred_coord, dim=0) |
|
|
|
log_dict.update( |
|
traverse_and_aggregate( |
|
sample_log_dicts, aggregation_func=lambda x_list: sum(x_list) / N_sample |
|
) |
|
) |
|
|
|
|
|
all_sample_rmsd = torch.tensor([x["rmsd"] for x in sample_log_dicts]).float() |
|
log_dict.update( |
|
{ |
|
"rmsd_sample_std": all_sample_rmsd.std().item(), |
|
"rmsd_sample_gap": (all_sample_rmsd.max() - all_sample_rmsd.min()).item(), |
|
} |
|
) |
|
|
|
return permute_pred_indices, permuted_aligned_pred_coord, log_dict |
|
|