# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from protenix.metrics.rmsd import rmsd
from protenix.model.utils import expand_at_dim
from protenix.utils.distributed import traverse_and_aggregate
from protenix.utils.logger import get_logger
from protenix.utils.permutation.chain_permutation.utils import (
    apply_transform,
    get_optimal_transform,
    num_unique_matches,
)

logger = get_logger(__name__)


def permute_pred_to_optimize_pocket_aligned_rmsd(
    pred_coord: torch.Tensor,  # [N_sample, N_atom, 3]
    true_coord: torch.Tensor,  # [N_atom, 3]
    true_coord_mask: torch.Tensor,
    true_pocket_mask: torch.Tensor,
    true_ligand_mask: torch.Tensor,
    atom_entity_id: torch.Tensor,  # [N_atom]
    atom_asym_id: torch.Tensor,  # [N_atom]
    mol_atom_index: torch.Tensor,  # [N_atom]
    use_center_rmsd: bool = False,
):
    """Pick the pocket/ligand chain pair minimizing the pocket-aligned ligand RMSD.

    For every predicted sample, each chain sharing the entity id of the true
    pocket chain is tried as the pocket: the prediction is aligned onto the
    true pocket atoms, and every chain sharing the entity id of the true
    ligand chain is scored against the true ligand. The (pocket, ligand) pair
    with the lowest ligand RMSD wins, and the winning chains are swapped with
    the chains at the ground-truth positions.

    Args:
        pred_coord: predicted coordinates. [N_sample, N_atom, 3]
        true_coord: ground-truth coordinates. [N_atom, 3]
        true_coord_mask: per-atom mask (cast to bool) used as the validity
            mask for alignment and RMSD.
            NOTE(review): presumably [N_atom], 1 = resolved atom - confirm.
        true_pocket_mask: per-atom mask (cast to bool) selecting the true
            pocket atoms. All selected atoms must share one asym_id and one
            entity_id (asserted).
        true_ligand_mask: per-atom mask (cast to bool) selecting the true
            ligand atoms. Same single-chain requirement as the pocket mask.
        atom_entity_id: entity id of each atom. [N_atom]
        atom_asym_id: asym (chain) id of each atom. [N_atom]
        mol_atom_index: per-atom index used to pick, inside a symmetric
            chain, the atoms corresponding to the true pocket/ligand atoms.
            [N_atom]
        use_center_rmsd: if True, score a ligand by the distance between the
            centers of its valid atoms and of the true ligand's valid atoms
            instead of the masked all-atom RMSD.

    Returns:
        permute_pred_indices (list[torch.Tensor]): A list of LongTensor.
            The list contains N_sample elements.
            Each element is a LongTensor of shape = [N_atom].
        permuted_aligned_pred_coord (torch.Tensor): permuted and aligned
            coordinates of pred_coord. [N_sample, N_atom, 3]
        log_dict (dict): "num_sym_pocket", "num_sym_ligand", "has_sym_chain",
            the per-sample statistics aggregated with `sum / N_sample`, plus
            "rmsd_sample_std" and "rmsd_sample_gap" over the best RMSD of
            each sample. "rmsd_sample_std" is 0.0 when there is one sample.

    Raises:
        AssertionError: if atom counts of `pred_coord` and `true_coord`
            differ, `pred_coord` is not 3-dimensional, a true mask spans more
            than one chain/entity, or a candidate chain mask does not have
            the same number of atoms as the corresponding true mask.
    """
    log_dict = {}

    atom_entity_id = atom_entity_id.long()
    atom_asym_id = atom_asym_id.long()
    mol_atom_index = mol_atom_index.long()
    true_coord_mask = true_coord_mask.bool()
    true_pocket_mask = true_pocket_mask.bool()
    true_ligand_mask = true_ligand_mask.bool()

    assert pred_coord.size(-2) == true_coord.size(-2), "Atom numbers are difference."
    assert pred_coord.dim() == 3

    # find entity_id/asym_id of pocket and ligand chains
    def _get_entity_and_asym_id(atom_mask):
        """Return (asym_id, entity_id) shared by all atoms selected by `atom_mask`."""
        masked_asym_id = atom_asym_id[atom_mask]
        masked_entity_id = atom_entity_id[atom_mask]
        assert (masked_asym_id[0] == masked_asym_id).all()
        assert (masked_entity_id[0] == masked_entity_id).all()
        return masked_asym_id[0].item(), masked_entity_id[0].item()

    pocket_asym_id, pocket_entity_id = _get_entity_and_asym_id(true_pocket_mask)
    ligand_asym_id, ligand_entity_id = _get_entity_and_asym_id(true_ligand_mask)

    # Every chain of the pocket entity is a candidate pocket; within each
    # chain keep only atoms whose mol_atom_index occurs in the true pocket.
    candidate_pockets = {}
    for i in torch.unique(atom_asym_id[atom_entity_id == pocket_entity_id]):
        i = i.item()
        pocket_mask = (atom_asym_id == i) & torch.isin(
            mol_atom_index, mol_atom_index[true_pocket_mask]
        )
        assert pocket_mask.sum() == true_pocket_mask.sum()
        candidate_pockets[i] = pocket_mask.clone()

    # Every chain of the ligand entity is a candidate ligand. The whole chain
    # is used unless its size differs from the true ligand mask, in which
    # case it is narrowed through mol_atom_index.
    candidate_ligands = {}
    for j in torch.unique(atom_asym_id[atom_entity_id == ligand_entity_id]):
        j = j.item()
        lig_mask_j = atom_asym_id == j
        if lig_mask_j.sum() != true_ligand_mask.sum():
            logger.warning(
                f"The ligand selected by 'mol_id' has {lig_mask_j.sum().item()} atoms. "
                + f"The true ligand selected by 'asym_id' has {true_ligand_mask.sum().item()} atoms."
            )
            lig_mask_j = lig_mask_j & torch.isin(
                mol_atom_index, mol_atom_index[true_ligand_mask]
            )
        assert lig_mask_j.sum() == true_ligand_mask.sum()
        candidate_ligands[j] = lig_mask_j

    log_dict["num_sym_pocket"] = len(candidate_pockets)
    log_dict["num_sym_ligand"] = len(candidate_ligands)
    log_dict["has_sym_chain"] = len(candidate_ligands) + len(candidate_pockets) > 2

    # Fixed ligand ordering shared by all samples and pockets; the position of
    # the ground-truth ligand chain gives the "unpermuted" RMSD.
    ordered_lig_asym_ids = list(candidate_ligands)
    ordered_lig_masks = [candidate_ligands[k] for k in ordered_lig_asym_ids]
    unpermuted_lig_idx = ordered_lig_asym_ids.index(ligand_asym_id)

    # Enumerate over the batch dimension of pred_coord
    # to find the optimal chain assignment for each sample.
    def _find_protein_ligand_chains_for_one_sample(
        coord: torch.Tensor,
    ):
        """Return (atom_indices, permuted aligned coords, stats) for one sample."""
        best_results = {}
        unpermuted_results = {}
        for poc_asym_id, pocket_mask in candidate_pockets.items():
            # Align this candidate pocket onto the true pocket
            rot, trans = get_optimal_transform(
                src_atoms=coord[pocket_mask].clone(),
                tgt_atoms=true_coord[true_pocket_mask],
                mask=true_coord_mask[true_pocket_mask],
            )

            # Transform all predicted coordinates with the pocket alignment
            aligned_pred_coord = apply_transform(coord.clone(), rot=rot, trans=trans)

            # Score every candidate ligand in the aligned frame
            aligned_lig_coords = torch.stack(
                [aligned_pred_coord[m] for m in ordered_lig_masks], dim=0
            )  # [N_lig, N_lig_atom, 3]
            n_lig = aligned_lig_coords.size(0)

            if use_center_rmsd:
                mask = true_coord_mask[true_ligand_mask].bool()  # [N_lig_atom]
                aligned_lig_center = aligned_lig_coords[:, mask, :].mean(
                    dim=-2, keepdim=True
                )  # [N_lig, 1, 3]
                true_coord_center = true_coord[true_ligand_mask][mask, :].mean(
                    dim=-2, keepdim=True
                )  # [1, 3]
                per_lig_rmsd = rmsd(
                    aligned_lig_center,  # [N_lig, 1, 3]
                    expand_at_dim(true_coord_center, dim=0, n=n_lig),
                    reduce=False,
                )  # [N_lig]
            else:
                per_lig_rmsd = rmsd(
                    aligned_lig_coords,
                    expand_at_dim(true_coord[true_ligand_mask], dim=0, n=n_lig),
                    mask=true_coord_mask[true_ligand_mask],
                    reduce=False,
                )  # [N_lig]

            lig_rmsd, idx = per_lig_rmsd.min(dim=0)
            lig_asym_id = ordered_lig_asym_ids[idx]

            if lig_rmsd < best_results.get("rmsd", torch.inf):
                best_results = {
                    "rmsd": lig_rmsd,
                    "pocket_asym_id": poc_asym_id,
                    "ligand_asym_id": lig_asym_id,
                    "aligned_pred_coord": aligned_pred_coord,
                }

            if poc_asym_id == pocket_asym_id:
                # record the unpermuted result (true pocket + true ligand chain)
                unpermuted_results = {
                    "rmsd": per_lig_rmsd[unpermuted_lig_idx].item(),
                }

        # record stats
        best_pocket_asym_id = best_results["pocket_asym_id"]
        best_ligand_asym_id = best_results["ligand_asym_id"]
        is_permuted_pocket = best_pocket_asym_id != pocket_asym_id
        is_permuted_ligand = best_ligand_asym_id != ligand_asym_id
        is_permuted = is_permuted_pocket or is_permuted_ligand
        per_sample_log_dict = {
            "is_permuted": is_permuted,
            "is_permuted_pocket": is_permuted_pocket,
            "is_permuted_ligand": is_permuted_ligand,
            "algo:no_permute": not is_permuted,
        }

        # Classify the outcome. When the improvement lies in [0, 1e-12) and
        # nothing was permuted, none of the three "algo:*_permute" keys is set.
        improved_rmsd = (unpermuted_results["rmsd"] - best_results["rmsd"]).item()
        if improved_rmsd >= 1e-12:
            # better
            per_sample_log_dict.update(
                {
                    "algo:equivalent_permute": False,
                    "algo:worse_permute": False,
                    "algo:better_permute": True,
                    "algo:better_rmsd": improved_rmsd,
                }
            )
        elif improved_rmsd < 0:
            # worse
            per_sample_log_dict.update(
                {
                    "algo:equivalent_permute": False,
                    "algo:worse_permute": True,
                    "algo:better_permute": False,
                    "algo:worse_rmsd": -improved_rmsd,
                }
            )
        elif is_permuted:
            # equivalent
            per_sample_log_dict.update(
                {
                    "algo:equivalent_permute": True,
                    "algo:worse_permute": False,
                    "algo:better_permute": False,
                }
            )

        # atom indices to permute coordinates
        best_aligned_coord = best_results.pop("aligned_pred_coord")
        atom_indices = torch.arange(
            best_aligned_coord.size(-2), device=best_aligned_coord.device
        )
        permute_asym_pair = [
            (best_pocket_asym_id, pocket_asym_id),
            (best_ligand_asym_id, ligand_asym_id),
        ]
        for asym_new, asym_old in permute_asym_pair:
            if asym_new == asym_old:
                continue
            # switch two chains; boolean-mask indexing returns copies, so the
            # two assignments below cannot alias each other
            ori_indices = atom_indices[atom_asym_id == asym_old]
            new_indices = atom_indices[atom_asym_id == asym_new]
            atom_indices[ori_indices] = new_indices
            atom_indices[new_indices] = ori_indices

        permuted_coord = best_aligned_coord[atom_indices, :]
        per_sample_log_dict["rmsd"] = best_results["rmsd"].item()
        return atom_indices, permuted_coord, per_sample_log_dict

    N_sample = pred_coord.size(0)
    permute_pred_indices = []
    permuted_aligned_pred_coord = []
    sample_log_dicts = []
    for sample_coord in pred_coord:
        atom_indices, aligned_pred_coord, per_sample_log_dict = (
            _find_protein_ligand_chains_for_one_sample(sample_coord)
        )
        permute_pred_indices.append(atom_indices)
        permuted_aligned_pred_coord.append(aligned_pred_coord)
        sample_log_dicts.append(per_sample_log_dict)

    permuted_aligned_pred_coord = torch.stack(permuted_aligned_pred_coord, dim=0)

    log_dict.update(
        traverse_and_aggregate(
            sample_log_dicts, aggregation_func=lambda x_list: sum(x_list) / N_sample
        )
    )

    # rmsd variance
    all_sample_rmsd = torch.tensor([x["rmsd"] for x in sample_log_dicts]).float()
    # The unbiased std of a single value is NaN; report 0.0 instead so that a
    # single-sample run does not inject NaN into the logged metrics.
    rmsd_sample_std = all_sample_rmsd.std().item() if N_sample > 1 else 0.0
    log_dict.update(
        {
            "rmsd_sample_std": rmsd_sample_std,
            "rmsd_sample_gap": (all_sample_rmsd.max() - all_sample_rmsd.min()).item(),
        }
    )

    return permute_pred_indices, permuted_aligned_pred_coord, log_dict