Last commit not found
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from protenix.metrics.rmsd import rmsd
from protenix.model.utils import expand_at_dim
from protenix.utils.distributed import traverse_and_aggregate
from protenix.utils.logger import get_logger
from protenix.utils.permutation.chain_permutation.utils import (
apply_transform,
get_optimal_transform,
num_unique_matches,
)
# Module-level logger obtained from the project's `get_logger` helper, keyed by this module's name.
logger = get_logger(__name__)
def permute_pred_to_optimize_pocket_aligned_rmsd(
    pred_coord: torch.Tensor,  # [N_sample, N_atom, 3]
    true_coord: torch.Tensor,  # [N_atom, 3]
    true_coord_mask: torch.Tensor,
    true_pocket_mask: torch.Tensor,
    true_ligand_mask: torch.Tensor,
    atom_entity_id: torch.Tensor,  # [N_atom]
    atom_asym_id: torch.Tensor,  # [N_atom]
    mol_atom_index: torch.Tensor,  # [N_atom]
    use_center_rmsd: bool = False,
):
    """
    Choose, for every sample, the pocket/ligand chain pair that minimises the
    pocket-aligned ligand RMSD, and permute the prediction accordingly.

    Every chain that shares the entity id of the reference pocket is a
    candidate pocket; every chain that shares the entity id of the reference
    ligand is a candidate ligand. For each candidate pocket the whole
    prediction is aligned onto the reference pocket (`get_optimal_transform`
    + `apply_transform`), then each candidate ligand is scored against the
    reference ligand with `rmsd`. The pair with the lowest ligand RMSD wins,
    and the winning chains are swapped with the reference chains in the
    returned atom permutation.

    Args:
        pred_coord: predicted coordinates, [N_sample, N_atom, 3].
        true_coord: reference coordinates, [N_atom, 3].
        true_coord_mask: per-atom mask (cast to bool) forwarded to the
            alignment and RMSD helpers; indexed along N_atom.
        true_pocket_mask: per-atom mask (cast to bool) selecting the reference
            pocket atoms. All selected atoms must share one asym id and one
            entity id.
        true_ligand_mask: per-atom mask (cast to bool) selecting the reference
            ligand atoms. All selected atoms must share one asym id and one
            entity id.
        atom_entity_id: entity id of each atom, [N_atom].
        atom_asym_id: asym (chain) id of each atom, [N_atom].
        mol_atom_index: per-atom index used to match atoms between chains of
            the same entity, [N_atom].
        use_center_rmsd: if True, candidate ligands are scored by the distance
            between the mean position of their masked atoms and that of the
            reference ligand instead of by per-atom RMSD.

    Returns:
        permute_pred_indices (list[torch.Tensor]): A list of LongTensor.
            The list contains N_sample elements.
            Each element is a LongTensor of shape = [N_atom].
        permuted_aligned_pred_coord (torch.Tensor): permuted and aligned coordinates of pred_coord.
            [N_sample, N_atom, 3]
        log_dict (dict): statistics of the permutation. Contains
            "num_sym_pocket", "num_sym_ligand", "has_sym_chain", the
            per-sample statistics combined by `traverse_and_aggregate`
            (sum divided by N_sample), and "rmsd_sample_std" /
            "rmsd_sample_gap" computed over the per-sample best RMSDs.

    Raises:
        AssertionError: if the atom counts of `pred_coord` and `true_coord`
            differ, if `pred_coord` is not 3-dimensional, if a reference mask
            spans more than one asym/entity id, or if a candidate chain does
            not select as many atoms as the reference mask.
    """
    log_dict = {}

    atom_entity_id = atom_entity_id.long()
    atom_asym_id = atom_asym_id.long()
    mol_atom_index = mol_atom_index.long()
    true_coord_mask = true_coord_mask.bool()
    true_pocket_mask = true_pocket_mask.bool()
    true_ligand_mask = true_ligand_mask.bool()

    assert pred_coord.size(-2) == true_coord.size(-2), "Atom numbers are difference."
    assert pred_coord.dim() == 3

    # find entity_id/asym_id of pocket and ligand chains
    def _get_entity_and_asym_id(atom_mask):
        # NOTE: despite the name, the return order is (asym_id, entity_id).
        masked_asym_id = atom_asym_id[atom_mask]
        masked_entity_id = atom_entity_id[atom_mask]
        assert (masked_asym_id[0] == masked_asym_id).all()
        assert (masked_entity_id[0] == masked_entity_id).all()
        return masked_asym_id[0].item(), masked_entity_id[0].item()

    pocket_asym_id, pocket_entity_id = _get_entity_and_asym_id(true_pocket_mask)
    ligand_asym_id, ligand_entity_id = _get_entity_and_asym_id(true_ligand_mask)

    # Candidate pockets: in every chain of the pocket entity, keep the atoms
    # whose mol_atom_index also occurs in the reference pocket.
    num_true_pocket_atoms = true_pocket_mask.sum()
    true_pocket_atom_index = mol_atom_index[true_pocket_mask]
    candidate_pockets = {}
    for asym in torch.unique(atom_asym_id[atom_entity_id == pocket_entity_id]):
        asym = asym.item()
        pocket_mask = (atom_asym_id == asym) & torch.isin(
            mol_atom_index, true_pocket_atom_index
        )
        assert pocket_mask.sum() == num_true_pocket_atoms
        candidate_pockets[asym] = pocket_mask

    # Candidate ligands: every chain of the ligand entity. When a chain holds
    # a different number of atoms than the reference ligand mask, restrict it
    # to the atoms whose mol_atom_index occurs in the reference ligand.
    num_true_ligand_atoms = true_ligand_mask.sum()
    candidate_ligands = {}
    for asym in torch.unique(atom_asym_id[atom_entity_id == ligand_entity_id]):
        asym = asym.item()
        lig_mask = atom_asym_id == asym
        if lig_mask.sum() != num_true_ligand_atoms:
            logger.warning(
                f"The ligand selected by 'mol_id' has {lig_mask.sum().item()} atoms. "
                f"The true ligand selected by 'asym_id' has {num_true_ligand_atoms.item()} atoms."
            )
            lig_mask = lig_mask & torch.isin(
                mol_atom_index, mol_atom_index[true_ligand_mask]
            )
        assert lig_mask.sum() == num_true_ligand_atoms
        candidate_ligands[asym] = lig_mask

    log_dict["num_sym_pocket"] = len(candidate_pockets)
    log_dict["num_sym_ligand"] = len(candidate_ligands)
    log_dict["has_sym_chain"] = len(candidate_ligands) + len(candidate_pockets) > 2

    # The ligand ordering does not depend on the sample or on the pocket, so
    # it is built once here instead of inside the nested loops.
    ordered_lig_asym_ids = list(candidate_ligands)
    ordered_lig_masks = [candidate_ligands[k] for k in ordered_lig_asym_ids]
    unpermuted_lig_pos = ordered_lig_asym_ids.index(ligand_asym_id)

    # Enumerate over the batch dimension of pred_coord
    # to find the optimal chain assignment for each sample.
    def _find_protein_ligand_chains_for_one_sample(
        coord: torch.Tensor,
    ):
        best_results = {}
        unpermuted_results = {}
        for poc_asym_id, pocket_mask in candidate_pockets.items():
            # Align this candidate pocket to the true pocket
            rot, trans = get_optimal_transform(
                src_atoms=coord[pocket_mask].clone(),
                tgt_atoms=true_coord[true_pocket_mask],
                mask=true_coord_mask[true_pocket_mask],
            )

            # Transform predicted coordinates according to the alignment results
            aligned_pred_coord = apply_transform(coord.clone(), rot=rot, trans=trans)

            # Score every candidate ligand under this alignment
            aligned_lig_coords = torch.stack(
                [aligned_pred_coord[m] for m in ordered_lig_masks], dim=0
            )  # [N_lig, N_lig_atom, 3]

            if use_center_rmsd:
                mask = true_coord_mask[true_ligand_mask].bool()  # [N_lig_atom]
                aligned_lig_center = aligned_lig_coords[:, mask, :].mean(
                    dim=-2, keepdim=True
                )  # [N_lig, 1, 3]
                true_coord_center = true_coord[true_ligand_mask][mask, :].mean(
                    dim=-2, keepdim=True
                )  # [1, 3]
                per_lig_rmsd = rmsd(
                    aligned_lig_center,  # [N_lig, 1, 3]
                    expand_at_dim(
                        true_coord_center,
                        dim=0,
                        n=aligned_lig_coords.size(0),
                    ),
                    reduce=False,
                )  # [N_lig]
            else:
                per_lig_rmsd = rmsd(
                    aligned_lig_coords,
                    expand_at_dim(
                        true_coord[true_ligand_mask],
                        dim=0,
                        n=aligned_lig_coords.size(0),
                    ),
                    mask=true_coord_mask[true_ligand_mask],
                    reduce=False,
                )  # [N_lig]

            lig_rmsd, idx = per_lig_rmsd.min(dim=0)
            lig_asym_id = ordered_lig_asym_ids[idx]

            if lig_rmsd < best_results.get("rmsd", torch.inf):
                best_results = {
                    "rmsd": lig_rmsd,
                    "pocket_asym_id": poc_asym_id,
                    "ligand_asym_id": lig_asym_id,
                    "aligned_pred_coord": aligned_pred_coord,
                }

            if poc_asym_id == pocket_asym_id:
                # record the unpermuted result (reference pocket + reference ligand)
                unpermuted_results = {
                    "rmsd": per_lig_rmsd[unpermuted_lig_pos].item(),
                    "aligned_pred_coord": aligned_pred_coord,
                }

        # record stats
        pocket_permuted = best_results["pocket_asym_id"] != pocket_asym_id
        ligand_permuted = best_results["ligand_asym_id"] != ligand_asym_id
        per_sample_log_dict = {
            "is_permuted": pocket_permuted or ligand_permuted,
            "is_permuted_pocket": pocket_permuted,
            "is_permuted_ligand": ligand_permuted,
            "algo:no_permute": not pocket_permuted and not ligand_permuted,
        }

        improved_rmsd = (unpermuted_results["rmsd"] - best_results["rmsd"]).item()
        if improved_rmsd >= 1e-12:
            # better
            per_sample_log_dict.update(
                {
                    "algo:equivalent_permute": False,
                    "algo:worse_permute": False,
                    "algo:better_permute": True,
                    "algo:better_rmsd": improved_rmsd,
                }
            )
        elif improved_rmsd < 0:
            # worse
            per_sample_log_dict.update(
                {
                    "algo:equivalent_permute": False,
                    "algo:worse_permute": True,
                    "algo:better_permute": False,
                    "algo:worse_rmsd": -improved_rmsd,
                }
            )
        elif per_sample_log_dict["is_permuted"]:
            # equivalent
            per_sample_log_dict.update(
                {
                    "algo:equivalent_permute": True,
                    "algo:worse_permute": False,
                    "algo:better_permute": False,
                }
            )

        # atom indices to permute coordinates
        best_aligned_coord = best_results.pop("aligned_pred_coord")
        atom_indices = torch.arange(
            best_aligned_coord.size(-2), device=best_aligned_coord.device
        )
        permute_asym_pair = [
            (best_results["pocket_asym_id"], pocket_asym_id),
            (best_results["ligand_asym_id"], ligand_asym_id),
        ]
        for asym_new, asym_old in permute_asym_pair:
            if asym_new == asym_old:
                continue

            # switch two chains; boolean-mask indexing already returns copies,
            # so both index tensors are read before either assignment lands.
            ori_indices = atom_indices[atom_asym_id == asym_old]
            new_indices = atom_indices[atom_asym_id == asym_new]
            atom_indices[ori_indices] = new_indices
            atom_indices[new_indices] = ori_indices

        permuted_coord = best_aligned_coord[atom_indices, :]
        per_sample_log_dict["rmsd"] = best_results["rmsd"].item()
        return atom_indices, permuted_coord, per_sample_log_dict

    N_sample = pred_coord.size(0)
    permute_pred_indices = []
    permuted_aligned_pred_coord = []
    sample_log_dicts = []
    for sample_coord in pred_coord:
        atom_indices, aligned_pred_coord, per_sample_log_dict = (
            _find_protein_ligand_chains_for_one_sample(sample_coord)
        )
        permute_pred_indices.append(atom_indices)
        permuted_aligned_pred_coord.append(aligned_pred_coord)
        sample_log_dicts.append(per_sample_log_dict)

    permuted_aligned_pred_coord = torch.stack(permuted_aligned_pred_coord, dim=0)

    log_dict.update(
        traverse_and_aggregate(
            sample_log_dicts, aggregation_func=lambda x_list: sum(x_list) / N_sample
        )
    )

    # rmsd variance
    all_sample_rmsd = torch.tensor([x["rmsd"] for x in sample_log_dicts]).float()
    # The default (Bessel-corrected) std of a single value is NaN and emits a
    # degrees-of-freedom warning; report zero spread for a single sample.
    rmsd_sample_std = all_sample_rmsd.std().item() if N_sample > 1 else 0.0
    log_dict.update(
        {
            "rmsd_sample_std": rmsd_sample_std,
            "rmsd_sample_gap": (all_sample_rmsd.max() - all_sample_rmsd.min()).item(),
        }
    )

    return permute_pred_indices, permuted_aligned_pred_coord, log_dict