from typing import Optional, Union |
|
|
|
import torch |
|
from ml_collections.config_dict import ConfigDict |
|
|
|
from protenix.metrics.clash import Clash |
|
from protenix.utils.distributed import traverse_and_aggregate |
|
|
|
|
|
def merge_per_sample_confidence_scores(summary_confidence_list: list[dict]) -> dict: |
|
""" |
|
Merge confidence scores from multiple samples into a single dictionary. |
|
|
|
Args: |
|
summary_confidence_list (list[dict]): List of dictionaries containing confidence scores for each sample. |
|
|
|
Returns: |
|
        dict: Dictionary with the same (nested) keys, where each per-sample score is stacked along a new leading sample dimension.
|
""" |
|
|
|
def stack_score(tensor_list: list): |
|
if tensor_list[0].dim() == 0: |
|
tensor_list = [x.unsqueeze(0) for x in tensor_list] |
|
score = torch.stack(tensor_list, dim=0) |
|
return score |
|
|
|
return traverse_and_aggregate(summary_confidence_list, aggregation_func=stack_score) |
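
# Example (illustrative; assumes `traverse_and_aggregate` applies `stack_score` to the
# list of values gathered at each, possibly nested, key):
#   merged = merge_per_sample_confidence_scores(
#       [{"ptm": torch.tensor(0.80)}, {"ptm": torch.tensor(0.60)}]
#   )
#   # merged["ptm"] -> tensor([[0.80], [0.60]])  (0-dim scores gain a sample axis)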
|
|
|
|
|
def _compute_full_data_and_summary( |
|
configs: ConfigDict, |
|
pae_logits: torch.Tensor, |
|
plddt_logits: torch.Tensor, |
|
pde_logits: torch.Tensor, |
|
contact_probs: torch.Tensor, |
|
token_asym_id: torch.Tensor, |
|
token_has_frame: torch.Tensor, |
|
atom_coordinate: torch.Tensor, |
|
atom_to_token_idx: torch.Tensor, |
|
atom_is_polymer: torch.Tensor, |
|
N_recycle: int, |
|
interested_atom_mask: Optional[torch.Tensor] = None, |
|
elements_one_hot: Optional[torch.Tensor] = None, |
|
mol_id: Optional[torch.Tensor] = None, |
|
return_full_data: bool = False, |
|
) -> tuple[list[dict], list[dict]]: |
|
""" |
|
Compute full data and summary confidence scores for the given inputs. |
|
|
|
Args: |
|
configs: Configuration object. |
|
pae_logits (torch.Tensor): Logits for PAE (Predicted Aligned Error). |
|
plddt_logits (torch.Tensor): Logits for pLDDT (Predicted Local Distance Difference Test). |
|
pde_logits (torch.Tensor): Logits for PDE (Predicted Distance Error). |
|
contact_probs (torch.Tensor): Contact probabilities. |
|
token_asym_id (torch.Tensor): Asymmetric ID for tokens. |
|
token_has_frame (torch.Tensor): Indicator for tokens having a frame. |
|
atom_coordinate (torch.Tensor): Atom coordinates. |
|
atom_to_token_idx (torch.Tensor): Mapping from atoms to tokens. |
|
atom_is_polymer (torch.Tensor): Indicator for atoms being part of a polymer. |
|
N_recycle (int): Number of recycles. |
|
interested_atom_mask (Optional[torch.Tensor]): Mask for interested atoms. Defaults to None. |
|
elements_one_hot (Optional[torch.Tensor]): One-hot encoding for elements. Defaults to None. |
|
mol_id (Optional[torch.Tensor]): Molecular ID. Defaults to None. |
|
return_full_data (bool): Whether to return full data. Defaults to False. |
|
|
|
Returns: |
|
tuple[list[dict], list[dict]]: |
|
- summary_confidence: List of dictionaries containing summary confidence scores. |
|
            - full_data: List of dictionaries containing full data if `return_full_data` is True; otherwise a list containing a single empty dict.
|
""" |
|
atom_is_ligand = (1 - atom_is_polymer).long() |
|
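    # Count ligand atoms per token; a token is treated as a ligand token if it
    # contains at least one non-polymer (ligand) atom.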
token_is_ligand = torch.zeros_like(token_asym_id).scatter_add( |
|
0, atom_to_token_idx, atom_is_ligand |
|
) |
|
token_is_ligand = token_is_ligand > 0 |
|
|
|
full_data = {} |
|
full_data["atom_plddt"] = logits_to_score( |
|
plddt_logits, **get_bin_params(configs.loss.plddt) |
|
) |
|
|
|
pde_logits = pde_logits.to(plddt_logits.device) |
|
full_data["token_pair_pde"] = logits_to_score( |
|
pde_logits, **get_bin_params(configs.loss.pde) |
|
) |
|
del pde_logits |
|
full_data["contact_probs"] = contact_probs.clone() |
|
pae_logits = pae_logits.to(plddt_logits.device) |
|
full_data["token_pair_pae"], pae_prob = logits_to_score( |
|
pae_logits, **get_bin_params(configs.loss.pae), return_prob=True |
|
) |
|
del pae_logits |
|
|
|
summary_confidence = {} |
|
summary_confidence["plddt"] = full_data["atom_plddt"].mean(dim=-1) * 100 |
|
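    # gPDE: global PDE, i.e. the token-pair PDE averaged with the predicted
    # contact probabilities as weights.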
summary_confidence["gpde"] = ( |
|
full_data["token_pair_pde"] * full_data["contact_probs"] |
|
).sum(dim=[-1, -2]) / full_data["contact_probs"].sum(dim=[-1, -2]) |
|
|
|
summary_confidence["ptm"] = calculate_ptm( |
|
pae_prob, has_frame=token_has_frame, **get_bin_params(configs.loss.pae) |
|
) |
|
summary_confidence["iptm"] = calculate_iptm( |
|
pae_prob, |
|
has_frame=token_has_frame, |
|
asym_id=token_asym_id, |
|
**get_bin_params(configs.loss.pae) |
|
) |
|
|
|
|
|
summary_confidence.update( |
|
calculate_chain_based_ptm( |
|
pae_prob, |
|
has_frame=token_has_frame, |
|
asym_id=token_asym_id, |
|
token_is_ligand=token_is_ligand, |
|
**get_bin_params(configs.loss.pae) |
|
) |
|
) |
|
|
|
summary_confidence.update( |
|
calculate_chain_based_plddt( |
|
full_data["atom_plddt"], token_asym_id, atom_to_token_idx |
|
) |
|
) |
|
del pae_prob |
|
summary_confidence["has_clash"] = calculate_clash( |
|
atom_coordinate, |
|
token_asym_id, |
|
atom_to_token_idx, |
|
atom_is_polymer, |
|
configs.metrics.clash.af3_clash_threshold, |
|
) |
|
summary_confidence["num_recycles"] = torch.tensor( |
|
N_recycle, device=atom_coordinate.device |
|
) |
|
|
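    # AF3-style ranking score: 0.8 * ipTM + 0.2 * pTM + 0.5 * disorder - 100 * has_clash.
    # The disorder term is not computed here and is left at zero.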
|
summary_confidence["disorder"] = torch.zeros_like(summary_confidence["ptm"]) |
|
summary_confidence["ranking_score"] = ( |
|
0.8 * summary_confidence["iptm"] |
|
+ 0.2 * summary_confidence["ptm"] |
|
+ 0.5 * summary_confidence["disorder"] |
|
- 100 * summary_confidence["has_clash"] |
|
) |
|
if interested_atom_mask is not None: |
|
token_idx = atom_to_token_idx[interested_atom_mask[0].bool()].long() |
|
asym_ids = token_asym_id[token_idx] |
|
assert len(torch.unique(asym_ids)) == 1 |
|
interested_asym_id = asym_ids[0].item() |
|
N_chains = token_asym_id.max().long().item() + 1 |
|
pb_ranking_score = summary_confidence["chain_pair_iptm_global"][ |
|
:, interested_asym_id, torch.arange(N_chains) != interested_asym_id |
|
] |
|
summary_confidence["pb_ranking_score"] = pb_ranking_score[:, 0] |
|
if elements_one_hot is not None and mol_id is not None: |
|
vdw_clash = calculate_vdw_clash( |
|
pred_coordinate=atom_coordinate, |
|
asym_id=token_asym_id, |
|
mol_id=mol_id, |
|
is_polymer=atom_is_polymer, |
|
atom_token_idx=atom_to_token_idx, |
|
elements_one_hot=elements_one_hot, |
|
threshold=configs.metrics.clash.vdw_clash_threshold, |
|
) |
|
N_sample = atom_coordinate.shape[0] |
|
vdw_clash_per_sample_flag = ( |
|
vdw_clash[:, interested_asym_id, :].reshape(N_sample, -1).max(dim=-1)[0] |
|
) |
|
summary_confidence["has_vdw_pl_clash"] = vdw_clash_per_sample_flag |
|
summary_confidence["pb_ranking_score_vdw_penalized"] = ( |
|
summary_confidence["pb_ranking_score"] - 100 * vdw_clash_per_sample_flag |
|
) |
|
|
|
summary_confidence = break_down_to_per_sample_dict( |
|
summary_confidence, shared_keys=["num_recycles"] |
|
) |
|
torch.cuda.empty_cache() |
|
if return_full_data: |
|
|
|
full_data["token_has_frame"] = token_has_frame.clone() |
|
full_data["token_asym_id"] = token_asym_id.clone() |
|
full_data["atom_to_token_idx"] = atom_to_token_idx.clone() |
|
full_data["atom_is_polymer"] = atom_is_polymer.clone() |
|
full_data["atom_coordinate"] = atom_coordinate.clone() |
|
|
|
full_data = break_down_to_per_sample_dict( |
|
full_data, |
|
shared_keys=[ |
|
"contact_probs", |
|
"token_has_frame", |
|
"token_asym_id", |
|
"atom_to_token_idx", |
|
"atom_is_polymer", |
|
], |
|
) |
|
return summary_confidence, full_data |
|
else: |
|
return summary_confidence, [{}] |
|
|
|
|
|
def get_bin_params(cfg: ConfigDict) -> dict: |
|
""" |
|
Extract bin parameters from the configuration object. |
|
""" |
|
return {"min_bin": cfg.min_bin, "max_bin": cfg.max_bin, "no_bins": cfg.no_bins} |
|
|
|
|
|
def compute_contact_prob( |
|
distogram_logits: torch.Tensor, |
|
min_bin: float, |
|
max_bin: float, |
|
no_bins: int, |
|
    thres: float = 8.0,
|
) -> torch.Tensor: |
|
""" |
|
Compute the contact probability from distogram logits. |
|
|
|
Args: |
|
distogram_logits (torch.Tensor): Logits for the distogram. |
|
Shape: [N_token, N_token, N_bins] |
|
min_bin (float): Minimum bin value. |
|
max_bin (float): Maximum bin value. |
|
no_bins (int): Number of bins. |
|
        thres (float): Distance threshold in Å below which two tokens are counted as in contact. Defaults to 8.0.
|
|
|
Returns: |
|
torch.Tensor: Contact probability. |
|
Shape: [N_token, N_token] |
|
""" |
|
distogram_prob = torch.nn.functional.softmax( |
|
distogram_logits, dim=-1 |
|
) |
|
distogram_bins = get_bin_centers(min_bin, max_bin, no_bins) |
|
thres_idx = (distogram_bins < thres).sum() |
|
contact_prob = distogram_prob[..., :thres_idx].sum(-1) |
|
return contact_prob |
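
# Illustrative example (the bin range below is a placeholder; the real values come
# from the distogram section of the config):
#   logits = torch.randn(N_token, N_token, 64)
#   p_contact = compute_contact_prob(logits, min_bin=2.0, max_bin=22.0, no_bins=64)
#   # p_contact[i, j] is the probability mass of all bins whose centers lie below
#   # 8 Å, i.e. an estimate of P(d_ij < 8 Å).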
|
|
|
|
|
def get_bin_centers(min_bin: float, max_bin: float, no_bins: int) -> torch.Tensor: |
|
""" |
|
Calculate the centers of the bins for a given range and number of bins. |
|
|
|
Args: |
|
min_bin (float): The minimum value of the bin range. |
|
max_bin (float): The maximum value of the bin range. |
|
no_bins (int): The number of bins. |
|
|
|
Returns: |
|
torch.Tensor: The centers of the bins. |
|
Shape: [no_bins] |
|
""" |
|
bin_width = (max_bin - min_bin) / no_bins |
|
boundaries = torch.linspace( |
|
start=min_bin, |
|
end=max_bin - bin_width, |
|
steps=no_bins, |
|
) |
|
bin_centers = boundaries + 0.5 * bin_width |
|
return bin_centers |
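
# Example: get_bin_centers(0.0, 1.0, 4) -> tensor([0.1250, 0.3750, 0.6250, 0.8750]),
# i.e. the left bin edges (0.00, 0.25, 0.50, 0.75) shifted by half a bin width.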
|
|
|
|
|
def logits_to_prob(logits: torch.Tensor, dim=-1) -> torch.Tensor: |
|
return torch.nn.functional.softmax(logits, dim=dim) |
|
|
|
|
|
def logits_to_score( |
|
logits: torch.Tensor, |
|
min_bin: float, |
|
max_bin: float, |
|
no_bins: int, |
|
return_prob=False, |
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: |
|
""" |
|
Convert logits to a score using bin centers. |
|
|
|
Args: |
|
logits (torch.Tensor): Logits tensor. |
|
Shape: [..., no_bins] |
|
min_bin (float): Minimum bin value. |
|
max_bin (float): Maximum bin value. |
|
no_bins (int): Number of bins. |
|
return_prob (bool): Whether to return the probability distribution. Defaults to False. |
|
|
|
Returns: |
|
score (torch.Tensor): Converted score. |
|
Shape: [...] |
|
prob (torch.Tensor, optional): Probability distribution if `return_prob` is True. |
|
Shape: [..., no_bins] |
|
""" |
|
prob = logits_to_prob(logits, dim=-1) |
|
bin_centers = get_bin_centers(min_bin, max_bin, no_bins).to(logits.device) |
|
score = prob @ bin_centers |
|
if return_prob: |
|
return score, prob |
|
else: |
|
return score |
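
# Example: with uniform logits the probability mass is spread evenly over the bins,
# so the score is the mean of the bin centers, i.e. (min_bin + max_bin) / 2:
#   logits_to_score(torch.zeros(64), min_bin=0.0, max_bin=32.0, no_bins=64) -> ~16.0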
|
|
|
|
|
def calculate_normalization(N: int) -> float:
    """TM-score normalization constant d0 for a structure or interface with N tokens."""
    return 1.24 * (max(N, 19) - 15) ** (1 / 3) - 1.8
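
# This is the standard TM-score d0 normalization (Zhang & Skolnick, 2004), clamped
# at N = 19 so the constant stays positive for very small inputs. For example:
#   calculate_normalization(19)  ~ 0.17
#   calculate_normalization(100) ~ 3.65
#   calculate_normalization(500) ~ 7.94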
|
|
|
|
|
def calculate_vdw_clash( |
|
pred_coordinate: torch.Tensor, |
|
asym_id: torch.LongTensor, |
|
mol_id: torch.LongTensor, |
|
atom_token_idx: torch.LongTensor, |
|
is_polymer: torch.BoolTensor, |
|
elements_one_hot: torch.Tensor, |
|
threshold: float, |
|
) -> torch.Tensor: |
|
""" |
|
Calculate Van der Waals (VDW) clash for predicted coordinates. |
|
|
|
Args: |
|
pred_coordinate (torch.Tensor): Predicted coordinates of atoms. |
|
Shape: [N_sample, N_atom, 3] |
|
asym_id (torch.LongTensor): Asymmetric ID for tokens. |
|
Shape: [N_token] |
|
mol_id (torch.LongTensor): Molecular ID. |
|
Shape: [N_atom] |
|
atom_token_idx (torch.LongTensor): Mapping from atoms to tokens. |
|
Shape: [N_atom] |
|
is_polymer (torch.BoolTensor): Indicator for atoms being part of a polymer. |
|
Shape: [N_atom] |
|
elements_one_hot (torch.Tensor): One-hot encoding for elements. |
|
Shape: [N_atom, N_elements] |
|
threshold (float): Threshold for VDW clash detection. |
|
|
|
Returns: |
|
        torch.Tensor: Per-chain-pair VDW clash indicator.

            Shape: [N_sample, N_chain, N_chain]
|
""" |
|
clash_calculator = Clash(vdw_clash_threshold=threshold, compute_af3_clash=False) |
|
|
|
dummy_is_dna = torch.zeros_like(is_polymer) |
|
dummy_is_rna = torch.zeros_like(is_polymer) |
|
clash_dict = clash_calculator( |
|
pred_coordinate=pred_coordinate, |
|
asym_id=asym_id, |
|
atom_to_token_idx=atom_token_idx, |
|
mol_id=mol_id, |
|
is_ligand=1 - is_polymer, |
|
is_protein=is_polymer, |
|
is_dna=dummy_is_dna, |
|
is_rna=dummy_is_rna, |
|
elements_one_hot=elements_one_hot, |
|
) |
|
return clash_dict["summary"]["vdw_clash"] |
|
|
|
|
|
def calculate_clash( |
|
pred_coordinate: torch.Tensor, |
|
asym_id: torch.LongTensor, |
|
atom_to_token_idx: torch.LongTensor, |
|
is_polymer: torch.BoolTensor, |
|
threshold: float, |
|
) -> torch.Tensor: |
|
"""Check complex clash |
|
|
|
Args: |
|
pred_coordinate (torch.Tensor): [N_sample, N_atom, 3] |
|
asym_id (torch.LongTensor): [N_token, ] |
|
atom_to_token_idx (torch.LongTensor): [N_atom, ] |
|
is_polymer (torch.BoolTensor): [N_atom, ] |
|
threshold: (float) |
|
|
|
Returns: |
|
torch.Tensor: [N_sample] whether there is a clash in the complex |
|
""" |
|
N_sample = pred_coordinate.shape[0] |
|
dummy_is_dna = torch.zeros_like(is_polymer) |
|
dummy_is_rna = torch.zeros_like(is_polymer) |
|
    clash_calculator = Clash(af3_clash_threshold=threshold, compute_vdw_clash=False)
    clash_dict = clash_calculator(
        pred_coordinate=pred_coordinate,
        asym_id=asym_id,
        atom_to_token_idx=atom_to_token_idx,
        is_ligand=1 - is_polymer,
        is_protein=is_polymer,
        is_dna=dummy_is_dna,
        is_rna=dummy_is_rna,
    )
    # Reduce the per-chain-pair clash flags to one boolean flag per sample.
    return clash_dict["summary"]["af3_clash"].reshape(N_sample, -1).max(dim=-1)[0]
|
|
|
|
|
def calculate_ptm( |
|
pae_prob: torch.Tensor, |
|
has_frame: torch.BoolTensor, |
|
min_bin: float, |
|
max_bin: float, |
|
no_bins: int, |
|
token_mask: Optional[torch.BoolTensor] = None, |
|
) -> torch.Tensor: |
|
"""Compute pTM score |
|
|
|
Args: |
|
pae_prob (torch.Tensor): Predicted probability from PAE loss head. |
|
Shape: [..., N_token, N_token, N_bins] |
|
has_frame (torch.BoolTensor): Indicator for tokens having a frame. |
|
Shape: [N_token, ] |
|
min_bin (float): Minimum bin value. |
|
max_bin (float): Maximum bin value. |
|
no_bins (int): Number of bins. |
|
token_mask (Optional[torch.BoolTensor]): Mask for tokens. |
|
Shape: [N_token, ] or None |
|
|
|
Returns: |
|
torch.Tensor: pTM score. Higher values indicate better ranking. |
|
Shape: [...] |
|
""" |
|
has_frame = has_frame.bool() |
|
|
|
if token_mask is not None: |
|
token_mask = token_mask.bool() |
|
pae_prob = pae_prob[..., token_mask, :, :][ |
|
..., :, token_mask, : |
|
] |
|
has_frame = has_frame[token_mask] |
|
|
|
if has_frame.sum() == 0: |
|
return torch.zeros(size=pae_prob.shape[:-3], device=pae_prob.device) |
|
|
|
N_d = has_frame.shape[-1] |
|
ptm_norm = calculate_normalization(N_d) |
|
|
|
bin_center = get_bin_centers(min_bin, max_bin, no_bins) |
|
per_bin_weight = (1 / (1 + (bin_center / ptm_norm) ** 2)).to( |
|
pae_prob.device |
|
) |
|
|
|
token_token_ptm = (pae_prob * per_bin_weight).sum(dim=-1) |
|
|
|
ptm = token_token_ptm.mean(dim=-1)[..., has_frame].max(dim=-1).values |
|
return ptm |
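
# In equation form, the code above computes
#   pTM = max_{i : has_frame[i]}  (1 / N) * sum_j  E_{e ~ PAE_ij} [ 1 / (1 + (e / d0(N))^2) ],
# where the expectation is over the binned PAE distribution for token pair (i, j)
# and d0(N) is calculate_normalization(N).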
|
|
|
|
|
def calculate_chain_based_ptm( |
|
pae_prob: torch.Tensor, |
|
has_frame: torch.BoolTensor, |
|
asym_id: torch.LongTensor, |
|
token_is_ligand: torch.BoolTensor, |
|
min_bin: float, |
|
max_bin: float, |
|
no_bins: int, |
|
) -> dict[str, torch.Tensor]: |
|
""" |
|
Compute chain-based pTM scores. |
|
|
|
Args: |
|
pae_prob (torch.Tensor): Predicted probability from PAE loss head. |
|
Shape: [..., N_token, N_token, N_bins] |
|
has_frame (torch.BoolTensor): Indicator for tokens having a frame. |
|
Shape: [N_token, ] |
|
asym_id (torch.LongTensor): Asymmetric ID for tokens. |
|
Shape: [N_token, ] |
|
token_is_ligand (torch.BoolTensor): Indicator for tokens being ligands. |
|
Shape: [N_token, ] |
|
min_bin (float): Minimum bin value. |
|
max_bin (float): Maximum bin value. |
|
no_bins (int): Number of bins. |
|
|
|
Returns: |
|
dict: Dictionary containing chain-based pTM scores. |
|
            - chain_ptm (torch.Tensor): pTM computed on each single chain.

            - chain_iptm (torch.Tensor): Interface ipTM for each chain, averaged over all chain pairs involving that chain.

            - chain_pair_iptm (torch.Tensor): ipTM computed for each pair of chains.

            - chain_pair_iptm_global (torch.Tensor): Chain-pair score used for ranking; pairs involving a ligand chain use the ligand's chain_iptm, other pairs use the mean of the two chains' chain_iptm.
|
""" |
|
|
|
has_frame = has_frame.bool() |
|
asym_id = asym_id.long() |
|
asym_id_to_asym_mask = {aid.item(): asym_id == aid for aid in torch.unique(asym_id)} |
|
chain_is_ligand = { |
|
aid.item(): token_is_ligand[asym_id == aid].sum() >= (asym_id == aid).sum() // 2 |
|
for aid in torch.unique(asym_id) |
|
} |
|
|
|
batch_shape = pae_prob.shape[:-3] |
|
|
|
|
|
|
|
N_chain = len(asym_id_to_asym_mask) |
|
chain_pair_iptm = torch.zeros(size=batch_shape + (N_chain, N_chain)).to( |
|
pae_prob.device |
|
) |
|
for aid_1 in range(N_chain): |
|
for aid_2 in range(N_chain): |
|
if aid_1 == aid_2: |
|
continue |
|
if aid_1 > aid_2: |
|
chain_pair_iptm[:, aid_1, aid_2] = chain_pair_iptm[:, aid_2, aid_1] |
|
continue |
|
pair_mask = asym_id_to_asym_mask[aid_1] + asym_id_to_asym_mask[aid_2] |
|
chain_pair_iptm[:, aid_1, aid_2] = calculate_iptm( |
|
pae_prob, |
|
has_frame, |
|
asym_id, |
|
min_bin, |
|
max_bin, |
|
no_bins, |
|
token_mask=pair_mask, |
|
) |
|
|
|
|
|
chain_ptm = torch.zeros(size=batch_shape + (N_chain,)).to(pae_prob.device) |
|
for aid, asym_mask in asym_id_to_asym_mask.items(): |
|
chain_ptm[:, aid] = calculate_ptm( |
|
pae_prob, |
|
has_frame, |
|
min_bin, |
|
max_bin, |
|
no_bins, |
|
token_mask=asym_mask, |
|
) |
|
|
|
|
|
chain_has_frame = [ |
|
(asym_id_to_asym_mask[i] * has_frame).any() for i in range(N_chain) |
|
] |
|
|
|
chain_iptm = torch.zeros(size=batch_shape + (N_chain,)).to(pae_prob.device) |
|
for aid, asym_mask in asym_id_to_asym_mask.items(): |
|
pairs = [ |
|
(i, j) |
|
for i in range(N_chain) |
|
for j in range(N_chain) |
|
if (i == aid or j == aid) and (i != j) and chain_has_frame[i] |
|
] |
|
vals = [chain_pair_iptm[:, i, j] for (i, j) in pairs] |
|
if len(vals) > 0: |
|
chain_iptm[:, aid] = torch.stack(vals, dim=-1).mean(dim=-1) |
|
|
|
|
|
chain_pair_iptm_global = torch.zeros(size=batch_shape + (N_chain, N_chain)).to( |
|
pae_prob.device |
|
) |
|
for aid_1 in range(N_chain): |
|
for aid_2 in range(N_chain): |
|
if aid_1 == aid_2: |
|
continue |
|
if chain_is_ligand[aid_1]: |
|
chain_pair_iptm_global[:, aid_1, aid_2] = chain_iptm[:, aid_1] |
|
elif chain_is_ligand[aid_2]: |
|
chain_pair_iptm_global[:, aid_1, aid_2] = chain_iptm[:, aid_2] |
|
else: |
|
chain_pair_iptm_global[:, aid_1, aid_2] = ( |
|
chain_iptm[:, aid_1] + chain_iptm[:, aid_2] |
|
) * 0.5 |
|
|
|
return { |
|
"chain_ptm": chain_ptm, |
|
"chain_iptm": chain_iptm, |
|
"chain_pair_iptm": chain_pair_iptm, |
|
"chain_pair_iptm_global": chain_pair_iptm_global, |
|
} |
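
# Example (illustrative): for a protein chain P and a ligand chain L,
# chain_pair_iptm_global[:, P, L] == chain_iptm[:, L] (the ligand contributes its own
# interface ipTM), whereas for two polymer chains the entry is the mean of the two
# chains' chain_iptm values.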
|
|
|
|
|
def calculate_chain_based_plddt( |
|
atom_plddt: torch.Tensor, |
|
asym_id: torch.LongTensor, |
|
atom_to_token_idx: torch.LongTensor, |
|
) -> dict[str, torch.Tensor]: |
|
""" |
|
Calculate chain-based pLDDT scores. |
|
|
|
Args: |
|
atom_plddt (torch.Tensor): Predicted pLDDT scores for atoms. |
|
Shape: [N_sample, N_atom] |
|
asym_id (torch.LongTensor): Asymmetric ID for tokens. |
|
Shape: [N_token] |
|
atom_to_token_idx (torch.LongTensor): Mapping from atoms to tokens. |
|
Shape: [N_atom] |
|
|
|
Returns: |
|
dict: Dictionary containing chain-based pLDDT scores. |
|
- chain_plddt (torch.Tensor): pLDDT scores for each chain. |
|
- chain_pair_plddt (torch.Tensor): Pairwise pLDDT scores between chains. |
|
""" |
|
|
|
asym_id = asym_id.long() |
|
asym_id_to_asym_mask = {aid.item(): asym_id == aid for aid in torch.unique(asym_id)} |
|
N_chain = len(asym_id_to_asym_mask) |
|
assert N_chain == asym_id.max() + 1 |
|
|
|
def _calculate_lddt_with_token_mask(token_mask): |
|
atom_mask = token_mask[atom_to_token_idx] |
|
sub_plddt = atom_plddt[:, atom_mask].mean(-1) |
|
return sub_plddt |
|
|
|
batch_shape = atom_plddt.shape[:-1] |
|
|
|
chain_plddt = torch.zeros(size=batch_shape + (N_chain,)).to(atom_plddt.device) |
|
for aid, asym_mask in asym_id_to_asym_mask.items(): |
|
chain_plddt[:, aid] = _calculate_lddt_with_token_mask(token_mask=asym_mask) |
|
|
|
|
|
chain_pair_plddt = torch.zeros(size=batch_shape + (N_chain, N_chain)).to( |
|
atom_plddt.device |
|
) |
|
for aid_1 in asym_id_to_asym_mask: |
|
for aid_2 in asym_id_to_asym_mask: |
|
if aid_1 == aid_2: |
|
continue |
|
pair_mask = asym_id_to_asym_mask[aid_1] + asym_id_to_asym_mask[aid_2] |
|
chain_pair_plddt[:, aid_1, aid_2] = _calculate_lddt_with_token_mask( |
|
token_mask=pair_mask |
|
) |
|
|
|
return {"chain_plddt": chain_plddt, "chain_pair_plddt": chain_pair_plddt} |
|
|
|
|
|
def calculate_iptm( |
|
pae_prob: torch.Tensor, |
|
has_frame: torch.BoolTensor, |
|
asym_id: torch.LongTensor, |
|
min_bin: float, |
|
max_bin: float, |
|
no_bins: int, |
|
token_mask: Optional[torch.BoolTensor] = None, |
|
eps: float = 1e-8, |
|
): |
|
""" |
|
Compute ipTM score. |
|
|
|
Args: |
|
pae_prob (torch.Tensor): Predicted probability from PAE loss head. |
|
Shape: [..., N_token, N_token, N_bins] |
|
has_frame (torch.BoolTensor): Indicator for tokens having a frame. |
|
Shape: [N_token, ] |
|
asym_id (torch.LongTensor): Asymmetric ID for tokens. |
|
Shape: [N_token, ] |
|
min_bin (float): Minimum bin value. |
|
max_bin (float): Maximum bin value. |
|
no_bins (int): Number of bins. |
|
token_mask (Optional[torch.BoolTensor]): Mask for tokens. |
|
Shape: [N_token, ] or None |
|
eps (float): Small value to avoid division by zero. Defaults to 1e-8. |
|
|
|
Returns: |
|
torch.Tensor: ipTM score. Higher values indicate better ranking. |
|
Shape: [...] |
|
""" |
|
has_frame = has_frame.bool() |
|
if token_mask is not None: |
|
token_mask = token_mask.bool() |
|
pae_prob = pae_prob[..., token_mask, :, :][ |
|
..., :, token_mask, : |
|
] |
|
has_frame = has_frame[token_mask] |
|
asym_id = asym_id[token_mask] |
|
|
|
if has_frame.sum() == 0: |
|
return torch.zeros(size=pae_prob.shape[:-3], device=pae_prob.device) |
|
|
|
N_d = has_frame.shape[-1] |
|
ptm_norm = calculate_normalization(N_d) |
|
|
|
bin_center = get_bin_centers(min_bin, max_bin, no_bins) |
|
per_bin_weight = (1 / (1 + (bin_center / ptm_norm) ** 2)).to( |
|
pae_prob.device |
|
) |
|
|
|
token_token_ptm = (pae_prob * per_bin_weight).sum(dim=-1) |
|
|
|
is_diff_chain = asym_id[None, :] != asym_id[:, None] |
|
|
|
iptm = (token_token_ptm * is_diff_chain).sum(dim=-1) / ( |
|
eps + is_diff_chain.sum(dim=-1) |
|
) |
|
iptm = iptm[..., has_frame].max(dim=-1).values |
|
|
|
return iptm |
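
# ipTM uses the same construction as calculate_ptm, except that the average over
# aligned tokens j is restricted to tokens on a different chain than the frame
# token i (the `is_diff_chain` mask), so the score only reflects inter-chain accuracy.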
|
|
|
|
|
def break_down_to_per_sample_dict(input_dict: dict, shared_keys=[]) -> list[dict]: |
|
""" |
|
Break down a dictionary containing tensors into a list of dictionaries, each corresponding to a sample. |
|
|
|
Args: |
|
input_dict (dict): Dictionary containing tensors. |
|
shared_keys (list): List of keys that are shared across all samples. Defaults to an empty list. |
|
|
|
Returns: |
|
list[dict]: List of dictionaries, each containing data for a single sample. |
|
""" |
|
per_sample_keys = [key for key in input_dict if key not in shared_keys] |
|
assert len(per_sample_keys) > 0 |
|
N_sample = input_dict[per_sample_keys[0]].size(0) |
|
for key in per_sample_keys: |
|
assert input_dict[key].size(0) == N_sample |
|
|
|
per_sample_dict_list = [] |
|
for i in range(N_sample): |
|
sample_dict = {key: input_dict[key][i] for key in per_sample_keys} |
|
sample_dict.update({key: input_dict[key] for key in shared_keys}) |
|
per_sample_dict_list.append(sample_dict) |
|
|
|
return per_sample_dict_list |
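
# Example (illustrative):
#   break_down_to_per_sample_dict(
#       {"ptm": torch.rand(3), "num_recycles": torch.tensor(4)},
#       shared_keys=["num_recycles"],
#   )
#   # -> list of 3 dicts; dict i holds ptm[i] plus the shared num_recycles tensor.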
|
|
|
|
|
@torch.no_grad() |
|
def compute_full_data_and_summary( |
|
configs, |
|
pae_logits, |
|
plddt_logits, |
|
pde_logits, |
|
contact_probs, |
|
token_asym_id, |
|
token_has_frame, |
|
atom_coordinate, |
|
atom_to_token_idx, |
|
atom_is_polymer, |
|
N_recycle, |
|
return_full_data: bool = False, |
|
interested_atom_mask=None, |
|
mol_id=None, |
|
elements_one_hot=None, |
|
): |
|
"""Wrapper of `_compute_full_data_and_summary` by enumerating over N samples""" |
|
|
|
N_sample = pae_logits.size(0) |
|
if contact_probs.dim() == 2: |
|
|
|
contact_probs = contact_probs.unsqueeze(dim=0).expand(N_sample, -1, -1) |
|
else: |
|
assert contact_probs.dim() == 3 |
|
assert ( |
|
contact_probs.size(0) == plddt_logits.size(0) == pde_logits.size(0) == N_sample |
|
) |
|
|
|
summary_confidence = [] |
|
full_data = [] |
|
for i in range(N_sample): |
|
summary_confidence_i, full_data_i = _compute_full_data_and_summary( |
|
configs=configs, |
|
pae_logits=pae_logits[i : i + 1], |
|
plddt_logits=plddt_logits[i : i + 1], |
|
pde_logits=pde_logits[i : i + 1], |
|
contact_probs=contact_probs[i], |
|
token_asym_id=token_asym_id, |
|
token_has_frame=token_has_frame, |
|
atom_coordinate=atom_coordinate[i : i + 1], |
|
atom_to_token_idx=atom_to_token_idx, |
|
atom_is_polymer=atom_is_polymer, |
|
N_recycle=N_recycle, |
|
interested_atom_mask=interested_atom_mask, |
|
return_full_data=return_full_data, |
|
mol_id=mol_id, |
|
elements_one_hot=elements_one_hot, |
|
) |
|
summary_confidence.extend(summary_confidence_i) |
|
full_data.extend(full_data_i) |
|
return summary_confidence, full_data |
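

if __name__ == "__main__":
    # Minimal smoke test of the pure scoring helpers (illustrative only: the bin
    # range and token/chain layout below are placeholders, not the values used by
    # the actual Protenix configs).
    torch.manual_seed(0)
    N_sample, N_token, no_bins = 2, 6, 64
    min_bin, max_bin = 0.0, 32.0

    pae_logits = torch.randn(N_sample, N_token, N_token, no_bins)
    pae, pae_prob = logits_to_score(
        pae_logits, min_bin=min_bin, max_bin=max_bin, no_bins=no_bins, return_prob=True
    )

    has_frame = torch.ones(N_token, dtype=torch.bool)
    asym_id = torch.tensor([0, 0, 0, 1, 1, 1])  # two chains of three tokens each

    ptm = calculate_ptm(pae_prob, has_frame, min_bin, max_bin, no_bins)
    iptm = calculate_iptm(pae_prob, has_frame, asym_id, min_bin, max_bin, no_bins)

    print("expected PAE shape:", tuple(pae.shape))  # (N_sample, N_token, N_token)
    print("pTM per sample:", ptm)                   # values in (0, 1)
    print("ipTM per sample:", iptm)                 # inter-chain restricted pTM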
|
|