# protenix/model/sample_confidence.py
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch
from ml_collections.config_dict import ConfigDict

from protenix.metrics.clash import Clash
from protenix.utils.distributed import traverse_and_aggregate


def merge_per_sample_confidence_scores(summary_confidence_list: list[dict]) -> dict:
"""
Merge confidence scores from multiple samples into a single dictionary.
Args:
summary_confidence_list (list[dict]): List of dictionaries containing confidence scores for each sample.
Returns:
dict: Merged dictionary of confidence scores.
"""

    def stack_score(tensor_list: list) -> torch.Tensor:
if tensor_list[0].dim() == 0:
tensor_list = [x.unsqueeze(0) for x in tensor_list]
score = torch.stack(tensor_list, dim=0)
return score
return traverse_and_aggregate(summary_confidence_list, aggregation_func=stack_score)
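
# Example (illustrative, hypothetical values): merging two per-sample dicts such as
#   [{"ptm": torch.tensor(0.91)}, {"ptm": torch.tensor(0.87)}]
# unsqueezes each 0-dim score to shape [1] and stacks along a new leading sample
# dimension, yielding a "ptm" tensor of shape [2, 1]; traverse_and_aggregate
# (from protenix.utils.distributed) applies this leaf-wise across the dicts.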


def _compute_full_data_and_summary(
configs: ConfigDict,
pae_logits: torch.Tensor,
plddt_logits: torch.Tensor,
pde_logits: torch.Tensor,
contact_probs: torch.Tensor,
token_asym_id: torch.Tensor,
token_has_frame: torch.Tensor,
atom_coordinate: torch.Tensor,
atom_to_token_idx: torch.Tensor,
atom_is_polymer: torch.Tensor,
N_recycle: int,
interested_atom_mask: Optional[torch.Tensor] = None,
elements_one_hot: Optional[torch.Tensor] = None,
mol_id: Optional[torch.Tensor] = None,
return_full_data: bool = False,
) -> tuple[list[dict], list[dict]]:
"""
Compute full data and summary confidence scores for the given inputs.
Args:
configs: Configuration object.
pae_logits (torch.Tensor): Logits for PAE (Predicted Aligned Error).
plddt_logits (torch.Tensor): Logits for pLDDT (Predicted Local Distance Difference Test).
pde_logits (torch.Tensor): Logits for PDE (Predicted Distance Error).
contact_probs (torch.Tensor): Contact probabilities.
token_asym_id (torch.Tensor): Asymmetric ID for tokens.
token_has_frame (torch.Tensor): Indicator for tokens having a frame.
atom_coordinate (torch.Tensor): Atom coordinates.
atom_to_token_idx (torch.Tensor): Mapping from atoms to tokens.
atom_is_polymer (torch.Tensor): Indicator for atoms being part of a polymer.
N_recycle (int): Number of recycles.
interested_atom_mask (Optional[torch.Tensor]): Mask for interested atoms. Defaults to None.
elements_one_hot (Optional[torch.Tensor]): One-hot encoding for elements. Defaults to None.
mol_id (Optional[torch.Tensor]): Molecular ID. Defaults to None.
return_full_data (bool): Whether to return full data. Defaults to False.
Returns:
tuple[list[dict], list[dict]]:
- summary_confidence: List of dictionaries containing summary confidence scores.
- full_data: List of dictionaries containing full data if `return_full_data` is True.
"""
atom_is_ligand = (1 - atom_is_polymer).long()
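    # A token counts as a ligand token if any of its atoms is non-polymer:
    # scatter-add per-atom ligand flags onto tokens, then threshold at > 0.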
token_is_ligand = torch.zeros_like(token_asym_id).scatter_add(
0, atom_to_token_idx, atom_is_ligand
)
token_is_ligand = token_is_ligand > 0
full_data = {}
full_data["atom_plddt"] = logits_to_score(
plddt_logits, **get_bin_params(configs.loss.plddt)
) # [N_s, N_atom]
    # pde_logits may have been offloaded to CPU to save CUDA memory;
    # move it back to the compute device before converting to scores.
    pde_logits = pde_logits.to(plddt_logits.device)
full_data["token_pair_pde"] = logits_to_score(
pde_logits, **get_bin_params(configs.loss.pde)
) # [N_s, N_token, N_token]
del pde_logits
full_data["contact_probs"] = contact_probs.clone() # [N_token, N_token]
pae_logits = pae_logits.to(plddt_logits.device)
full_data["token_pair_pae"], pae_prob = logits_to_score(
pae_logits, **get_bin_params(configs.loss.pae), return_prob=True
) # [N_s, N_token, N_token]
del pae_logits
summary_confidence = {}
summary_confidence["plddt"] = full_data["atom_plddt"].mean(dim=-1) * 100 # [N_s, ]
summary_confidence["gpde"] = (
full_data["token_pair_pde"] * full_data["contact_probs"]
).sum(dim=[-1, -2]) / full_data["contact_probs"].sum(dim=[-1, -2])
summary_confidence["ptm"] = calculate_ptm(
pae_prob, has_frame=token_has_frame, **get_bin_params(configs.loss.pae)
) # [N_s, ]
summary_confidence["iptm"] = calculate_iptm(
pae_prob,
has_frame=token_has_frame,
asym_id=token_asym_id,
**get_bin_params(configs.loss.pae)
) # [N_s, ]
    # Add: 'chain_pair_iptm', 'chain_pair_iptm_global', 'chain_iptm', 'chain_ptm'
summary_confidence.update(
calculate_chain_based_ptm(
pae_prob,
has_frame=token_has_frame,
asym_id=token_asym_id,
token_is_ligand=token_is_ligand,
**get_bin_params(configs.loss.pae)
)
)
# Add: 'chain_plddt', 'chain_pair_plddt'
summary_confidence.update(
calculate_chain_based_plddt(
full_data["atom_plddt"], token_asym_id, atom_to_token_idx
)
)
del pae_prob
summary_confidence["has_clash"] = calculate_clash(
atom_coordinate,
token_asym_id,
atom_to_token_idx,
atom_is_polymer,
configs.metrics.clash.af3_clash_threshold,
)
summary_confidence["num_recycles"] = torch.tensor(
N_recycle, device=atom_coordinate.device
)
# TODO: disorder
summary_confidence["disorder"] = torch.zeros_like(summary_confidence["ptm"])
summary_confidence["ranking_score"] = (
0.8 * summary_confidence["iptm"]
+ 0.2 * summary_confidence["ptm"]
+ 0.5 * summary_confidence["disorder"]
- 100 * summary_confidence["has_clash"]
)
if interested_atom_mask is not None:
token_idx = atom_to_token_idx[interested_atom_mask[0].bool()].long()
asym_ids = token_asym_id[token_idx]
        assert (
            len(torch.unique(asym_ids)) == 1
        ), "interested atoms must belong to a single chain"
interested_asym_id = asym_ids[0].item()
N_chains = token_asym_id.max().long().item() + 1
pb_ranking_score = summary_confidence["chain_pair_iptm_global"][
:, interested_asym_id, torch.arange(N_chains) != interested_asym_id
] # [N_s, N_chain - 1]
summary_confidence["pb_ranking_score"] = pb_ranking_score[:, 0]
if elements_one_hot is not None and mol_id is not None:
vdw_clash = calculate_vdw_clash(
pred_coordinate=atom_coordinate,
asym_id=token_asym_id,
mol_id=mol_id,
is_polymer=atom_is_polymer,
atom_token_idx=atom_to_token_idx,
elements_one_hot=elements_one_hot,
threshold=configs.metrics.clash.vdw_clash_threshold,
)
N_sample = atom_coordinate.shape[0]
vdw_clash_per_sample_flag = (
vdw_clash[:, interested_asym_id, :].reshape(N_sample, -1).max(dim=-1)[0]
)
summary_confidence["has_vdw_pl_clash"] = vdw_clash_per_sample_flag
summary_confidence["pb_ranking_score_vdw_penalized"] = (
summary_confidence["pb_ranking_score"] - 100 * vdw_clash_per_sample_flag
)
summary_confidence = break_down_to_per_sample_dict(
summary_confidence, shared_keys=["num_recycles"]
)
torch.cuda.empty_cache()
if return_full_data:
# save extra inputs that are used for computing summary_confidence
full_data["token_has_frame"] = token_has_frame.clone()
full_data["token_asym_id"] = token_asym_id.clone()
full_data["atom_to_token_idx"] = atom_to_token_idx.clone()
full_data["atom_is_polymer"] = atom_is_polymer.clone()
full_data["atom_coordinate"] = atom_coordinate.clone()
full_data = break_down_to_per_sample_dict(
full_data,
shared_keys=[
"contact_probs",
"token_has_frame",
"token_asym_id",
"atom_to_token_idx",
"atom_is_polymer",
],
)
return summary_confidence, full_data
else:
return summary_confidence, [{}]
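
# For reference, each per-sample summary_confidence dict produced above contains:
#   plddt, gpde, ptm, iptm, chain_ptm, chain_iptm, chain_pair_iptm,
#   chain_pair_iptm_global, chain_plddt, chain_pair_plddt, has_clash,
#   num_recycles, disorder, ranking_score (plus pb_ranking_score and the VDW
#   variants when interested_atom_mask / elements_one_hot / mol_id are given).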


def get_bin_params(cfg: ConfigDict) -> dict:
"""
Extract bin parameters from the configuration object.
"""
return {"min_bin": cfg.min_bin, "max_bin": cfg.max_bin, "no_bins": cfg.no_bins}


def compute_contact_prob(
    distogram_logits: torch.Tensor,
    min_bin: float,
    max_bin: float,
    no_bins: int,
    thres: float = 8.0,
) -> torch.Tensor:
"""
Compute the contact probability from distogram logits.
Args:
distogram_logits (torch.Tensor): Logits for the distogram.
Shape: [N_token, N_token, N_bins]
min_bin (float): Minimum bin value.
max_bin (float): Maximum bin value.
no_bins (int): Number of bins.
thres (float): Threshold distance for contact probability. Defaults to 8.0.
Returns:
torch.Tensor: Contact probability.
Shape: [N_token, N_token]
"""
distogram_prob = torch.nn.functional.softmax(
distogram_logits, dim=-1
) # [N_token, N_token, N_bins]
distogram_bins = get_bin_centers(min_bin, max_bin, no_bins)
thres_idx = (distogram_bins < thres).sum()
contact_prob = distogram_prob[..., :thres_idx].sum(-1)
return contact_prob
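
# Sketch of the binning logic (hypothetical bin settings, not necessarily the
# values used by the configs in this repo): with min_bin=2.0, max_bin=22.0 and
# no_bins=40, the bin width is 0.5, so the first (8.0 - 2.0) / 0.5 = 12 bin
# centers lie below thres=8.0 and their probabilities sum to the contact prob.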


def get_bin_centers(min_bin: float, max_bin: float, no_bins: int) -> torch.Tensor:
"""
Calculate the centers of the bins for a given range and number of bins.
Args:
min_bin (float): The minimum value of the bin range.
max_bin (float): The maximum value of the bin range.
no_bins (int): The number of bins.
Returns:
torch.Tensor: The centers of the bins.
Shape: [no_bins]
"""
bin_width = (max_bin - min_bin) / no_bins
boundaries = torch.linspace(
start=min_bin,
end=max_bin - bin_width,
steps=no_bins,
)
bin_centers = boundaries + 0.5 * bin_width
return bin_centers
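
# Worked example (hypothetical arguments): get_bin_centers(0.0, 1.0, 4) uses a
# bin width of 0.25, boundaries [0.0, 0.25, 0.5, 0.75], and therefore returns
# centers tensor([0.1250, 0.3750, 0.6250, 0.8750]).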


def logits_to_prob(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
    return torch.nn.functional.softmax(logits, dim=dim)


def logits_to_score(
logits: torch.Tensor,
min_bin: float,
max_bin: float,
no_bins: int,
return_prob=False,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""
Convert logits to a score using bin centers.
Args:
logits (torch.Tensor): Logits tensor.
Shape: [..., no_bins]
min_bin (float): Minimum bin value.
max_bin (float): Maximum bin value.
no_bins (int): Number of bins.
return_prob (bool): Whether to return the probability distribution. Defaults to False.
Returns:
score (torch.Tensor): Converted score.
Shape: [...]
prob (torch.Tensor, optional): Probability distribution if `return_prob` is True.
Shape: [..., no_bins]
"""
prob = logits_to_prob(logits, dim=-1)
bin_centers = get_bin_centers(min_bin, max_bin, no_bins).to(logits.device)
score = prob @ bin_centers
if return_prob:
return score, prob
else:
return score
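
# Example (hypothetical shapes): pLDDT logits of shape [N_s, N_atom, no_bins]
# with min_bin=0.0, max_bin=1.0 become per-atom expected scores in (0, 1):
# a softmax over the last dim, then a dot product with the bin centers.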


def calculate_normalization(N: int) -> float:
    # TM-score normalization constant d0 (Zhang & Skolnick); N is clamped to
    # at least 19 so that d0 stays positive for short chains.
    return 1.24 * (max(N, 19) - 15) ** (1 / 3) - 1.8
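
# Worked example: for N = 100 tokens, d0 = 1.24 * (100 - 15)**(1/3) - 1.8
# ≈ 1.24 * 4.397 - 1.8 ≈ 3.65; the max(N, 19) clamp keeps d0 ≈ 0.17 for very
# short chains instead of letting it go negative.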


def calculate_vdw_clash(
pred_coordinate: torch.Tensor,
asym_id: torch.LongTensor,
mol_id: torch.LongTensor,
atom_token_idx: torch.LongTensor,
is_polymer: torch.BoolTensor,
elements_one_hot: torch.Tensor,
threshold: float,
) -> torch.Tensor:
"""
Calculate Van der Waals (VDW) clash for predicted coordinates.
Args:
pred_coordinate (torch.Tensor): Predicted coordinates of atoms.
Shape: [N_sample, N_atom, 3]
asym_id (torch.LongTensor): Asymmetric ID for tokens.
Shape: [N_token]
mol_id (torch.LongTensor): Molecular ID.
Shape: [N_atom]
atom_token_idx (torch.LongTensor): Mapping from atoms to tokens.
Shape: [N_atom]
is_polymer (torch.BoolTensor): Indicator for atoms being part of a polymer.
Shape: [N_atom]
elements_one_hot (torch.Tensor): One-hot encoding for elements.
Shape: [N_atom, N_elements]
threshold (float): Threshold for VDW clash detection.
Returns:
torch.Tensor: VDW clash summary.
Shape: [N_sample]
"""
clash_calculator = Clash(vdw_clash_threshold=threshold, compute_af3_clash=False)
    # Check ligand-polymer VDW clash; DNA/RNA flags are unused here, so pass zeros.
    dummy_is_dna = torch.zeros_like(is_polymer)
    dummy_is_rna = torch.zeros_like(is_polymer)
clash_dict = clash_calculator(
pred_coordinate=pred_coordinate,
asym_id=asym_id,
atom_to_token_idx=atom_token_idx,
mol_id=mol_id,
is_ligand=1 - is_polymer,
is_protein=is_polymer,
is_dna=dummy_is_dna,
is_rna=dummy_is_rna,
elements_one_hot=elements_one_hot,
)
return clash_dict["summary"]["vdw_clash"]


def calculate_clash(
pred_coordinate: torch.Tensor,
asym_id: torch.LongTensor,
atom_to_token_idx: torch.LongTensor,
is_polymer: torch.BoolTensor,
threshold: float,
) -> torch.Tensor:
"""Check complex clash
Args:
pred_coordinate (torch.Tensor): [N_sample, N_atom, 3]
asym_id (torch.LongTensor): [N_token, ]
atom_to_token_idx (torch.LongTensor): [N_atom, ]
is_polymer (torch.BoolTensor): [N_atom, ]
threshold: (float)
Returns:
torch.Tensor: [N_sample] whether there is a clash in the complex
"""
N_sample = pred_coordinate.shape[0]
dummy_is_dna = torch.zeros_like(is_polymer)
dummy_is_rna = torch.zeros_like(is_polymer)
clash_calculator = Clash(vdw_clash_threshold=threshold, compute_vdw_clash=False)
clash_dict = clash_calculator(
pred_coordinate,
asym_id,
atom_to_token_idx,
1 - is_polymer,
is_polymer,
dummy_is_dna,
dummy_is_rna,
)
return clash_dict["summary"]["af3_clash"].reshape(N_sample, -1).max(dim=-1)[0]
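
# Note: the af3_clash summary is reduced with a max over everything except the
# sample dimension, so each sample ends up with a single has-clash flag.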


def calculate_ptm(
pae_prob: torch.Tensor,
has_frame: torch.BoolTensor,
min_bin: float,
max_bin: float,
no_bins: int,
token_mask: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
"""Compute pTM score
Args:
pae_prob (torch.Tensor): Predicted probability from PAE loss head.
Shape: [..., N_token, N_token, N_bins]
has_frame (torch.BoolTensor): Indicator for tokens having a frame.
Shape: [N_token, ]
min_bin (float): Minimum bin value.
max_bin (float): Maximum bin value.
no_bins (int): Number of bins.
token_mask (Optional[torch.BoolTensor]): Mask for tokens.
Shape: [N_token, ] or None
Returns:
torch.Tensor: pTM score. Higher values indicate better ranking.
Shape: [...]
"""
has_frame = has_frame.bool()
if token_mask is not None:
token_mask = token_mask.bool()
pae_prob = pae_prob[..., token_mask, :, :][
..., :, token_mask, :
] # [..., N_d, N_d, N_bins]
has_frame = has_frame[token_mask] # [N_d, ]
if has_frame.sum() == 0:
return torch.zeros(size=pae_prob.shape[:-3], device=pae_prob.device)
N_d = has_frame.shape[-1]
ptm_norm = calculate_normalization(N_d)
bin_center = get_bin_centers(min_bin, max_bin, no_bins)
per_bin_weight = (1 / (1 + (bin_center / ptm_norm) ** 2)).to(
pae_prob.device
) # [N_bins]
token_token_ptm = (pae_prob * per_bin_weight).sum(dim=-1) # [..., N_d, N_d]
ptm = token_token_ptm.mean(dim=-1)[..., has_frame].max(dim=-1).values
return ptm
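
# In effect, calculate_ptm computes
#   pTM = max_i mean_j E[ 1 / (1 + (e_ij / d0)^2) ]
# where the expectation is over the binned PAE distribution, i ranges over
# tokens with a valid frame, and d0 is the TM-score normalization above.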


def calculate_chain_based_ptm(
pae_prob: torch.Tensor,
has_frame: torch.BoolTensor,
asym_id: torch.LongTensor,
token_is_ligand: torch.BoolTensor,
min_bin: float,
max_bin: float,
no_bins: int,
) -> dict[str, torch.Tensor]:
"""
Compute chain-based pTM scores.
Args:
pae_prob (torch.Tensor): Predicted probability from PAE loss head.
Shape: [..., N_token, N_token, N_bins]
has_frame (torch.BoolTensor): Indicator for tokens having a frame.
Shape: [N_token, ]
asym_id (torch.LongTensor): Asymmetric ID for tokens.
Shape: [N_token, ]
token_is_ligand (torch.BoolTensor): Indicator for tokens being ligands.
Shape: [N_token, ]
min_bin (float): Minimum bin value.
max_bin (float): Maximum bin value.
no_bins (int): Number of bins.
Returns:
dict: Dictionary containing chain-based pTM scores.
- chain_ptm (torch.Tensor): pTM scores for each chain.
- chain_iptm (torch.Tensor): ipTM scores for chain interface.
- chain_pair_iptm (torch.Tensor): Pairwise ipTM scores between chains.
- chain_pair_iptm_global (torch.Tensor): Global pairwise ipTM scores between chains.
"""
has_frame = has_frame.bool()
asym_id = asym_id.long()
asym_id_to_asym_mask = {aid.item(): asym_id == aid for aid in torch.unique(asym_id)}
chain_is_ligand = {
aid.item(): token_is_ligand[asym_id == aid].sum() >= (asym_id == aid).sum() // 2
for aid in torch.unique(asym_id)
}
batch_shape = pae_prob.shape[:-3]
# Chain_pair_iptm
    # Use dense tensors here; ragged/dict layouts are awkward to carry through
    # break_down_to_per_sample_dict and traverse_and_aggregate across devices.
    N_chain = len(asym_id_to_asym_mask)
    assert N_chain == asym_id.max() + 1  # asym ids must run contiguously from 0
chain_pair_iptm = torch.zeros(size=batch_shape + (N_chain, N_chain)).to(
pae_prob.device
)
for aid_1 in range(N_chain):
for aid_2 in range(N_chain):
if aid_1 == aid_2:
continue
if aid_1 > aid_2:
chain_pair_iptm[:, aid_1, aid_2] = chain_pair_iptm[:, aid_2, aid_1]
continue
            pair_mask = asym_id_to_asym_mask[aid_1] | asym_id_to_asym_mask[aid_2]
chain_pair_iptm[:, aid_1, aid_2] = calculate_iptm(
pae_prob,
has_frame,
asym_id,
min_bin,
max_bin,
no_bins,
token_mask=pair_mask,
)
# chain_ptm
chain_ptm = torch.zeros(size=batch_shape + (N_chain,)).to(pae_prob.device)
for aid, asym_mask in asym_id_to_asym_mask.items():
chain_ptm[:, aid] = calculate_ptm(
pae_prob,
has_frame,
min_bin,
max_bin,
no_bins,
token_mask=asym_mask,
)
# Chain iptm
chain_has_frame = [
(asym_id_to_asym_mask[i] * has_frame).any() for i in range(N_chain)
]
chain_iptm = torch.zeros(size=batch_shape + (N_chain,)).to(pae_prob.device)
for aid, asym_mask in asym_id_to_asym_mask.items():
pairs = [
(i, j)
for i in range(N_chain)
for j in range(N_chain)
if (i == aid or j == aid) and (i != j) and chain_has_frame[i]
]
vals = [chain_pair_iptm[:, i, j] for (i, j) in pairs]
if len(vals) > 0:
chain_iptm[:, aid] = torch.stack(vals, dim=-1).mean(dim=-1)
# Chain_pair_iptm_global
chain_pair_iptm_global = torch.zeros(size=batch_shape + (N_chain, N_chain)).to(
pae_prob.device
)
for aid_1 in range(N_chain):
for aid_2 in range(N_chain):
if aid_1 == aid_2:
continue
if chain_is_ligand[aid_1]:
chain_pair_iptm_global[:, aid_1, aid_2] = chain_iptm[:, aid_1]
elif chain_is_ligand[aid_2]:
chain_pair_iptm_global[:, aid_1, aid_2] = chain_iptm[:, aid_2]
else:
chain_pair_iptm_global[:, aid_1, aid_2] = (
chain_iptm[:, aid_1] + chain_iptm[:, aid_2]
) * 0.5
return {
"chain_ptm": chain_ptm,
"chain_iptm": chain_iptm,
"chain_pair_iptm": chain_pair_iptm,
"chain_pair_iptm_global": chain_pair_iptm_global,
}
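
# Shape summary (with batch shape [N_s] and N_chain chains): chain_ptm and
# chain_iptm are [N_s, N_chain]; chain_pair_iptm and chain_pair_iptm_global
# are [N_s, N_chain, N_chain] with zero diagonals.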


def calculate_chain_based_plddt(
atom_plddt: torch.Tensor,
asym_id: torch.LongTensor,
atom_to_token_idx: torch.LongTensor,
) -> dict[str, torch.Tensor]:
"""
Calculate chain-based pLDDT scores.
Args:
atom_plddt (torch.Tensor): Predicted pLDDT scores for atoms.
Shape: [N_sample, N_atom]
asym_id (torch.LongTensor): Asymmetric ID for tokens.
Shape: [N_token]
atom_to_token_idx (torch.LongTensor): Mapping from atoms to tokens.
Shape: [N_atom]
Returns:
dict: Dictionary containing chain-based pLDDT scores.
- chain_plddt (torch.Tensor): pLDDT scores for each chain.
- chain_pair_plddt (torch.Tensor): Pairwise pLDDT scores between chains.
"""
asym_id = asym_id.long()
asym_id_to_asym_mask = {aid.item(): asym_id == aid for aid in torch.unique(asym_id)}
N_chain = len(asym_id_to_asym_mask)
assert N_chain == asym_id.max() + 1 # make sure it is from 0 to N_chain-1

    def _calculate_lddt_with_token_mask(token_mask: torch.BoolTensor) -> torch.Tensor:
atom_mask = token_mask[atom_to_token_idx]
sub_plddt = atom_plddt[:, atom_mask].mean(-1)
return sub_plddt
batch_shape = atom_plddt.shape[:-1]
# Chain_plddt
chain_plddt = torch.zeros(size=batch_shape + (N_chain,)).to(atom_plddt.device)
for aid, asym_mask in asym_id_to_asym_mask.items():
chain_plddt[:, aid] = _calculate_lddt_with_token_mask(token_mask=asym_mask)
# Chain_pair_plddt
chain_pair_plddt = torch.zeros(size=batch_shape + (N_chain, N_chain)).to(
atom_plddt.device
)
for aid_1 in asym_id_to_asym_mask:
for aid_2 in asym_id_to_asym_mask:
if aid_1 == aid_2:
continue
            pair_mask = asym_id_to_asym_mask[aid_1] | asym_id_to_asym_mask[aid_2]
chain_pair_plddt[:, aid_1, aid_2] = _calculate_lddt_with_token_mask(
token_mask=pair_mask
)
return {"chain_plddt": chain_plddt, "chain_pair_plddt": chain_pair_plddt}


def calculate_iptm(
pae_prob: torch.Tensor,
has_frame: torch.BoolTensor,
asym_id: torch.LongTensor,
min_bin: float,
max_bin: float,
no_bins: int,
token_mask: Optional[torch.BoolTensor] = None,
eps: float = 1e-8,
):
"""
Compute ipTM score.
Args:
pae_prob (torch.Tensor): Predicted probability from PAE loss head.
Shape: [..., N_token, N_token, N_bins]
has_frame (torch.BoolTensor): Indicator for tokens having a frame.
Shape: [N_token, ]
asym_id (torch.LongTensor): Asymmetric ID for tokens.
Shape: [N_token, ]
min_bin (float): Minimum bin value.
max_bin (float): Maximum bin value.
no_bins (int): Number of bins.
token_mask (Optional[torch.BoolTensor]): Mask for tokens.
Shape: [N_token, ] or None
eps (float): Small value to avoid division by zero. Defaults to 1e-8.
Returns:
torch.Tensor: ipTM score. Higher values indicate better ranking.
Shape: [...]
"""
has_frame = has_frame.bool()
if token_mask is not None:
token_mask = token_mask.bool()
pae_prob = pae_prob[..., token_mask, :, :][
..., :, token_mask, :
] # [..., N_d, N_d, N_bins]
has_frame = has_frame[token_mask] # [N_d, ]
asym_id = asym_id[token_mask] # [N_d, ]
if has_frame.sum() == 0:
return torch.zeros(size=pae_prob.shape[:-3], device=pae_prob.device)
N_d = has_frame.shape[-1]
ptm_norm = calculate_normalization(N_d)
bin_center = get_bin_centers(min_bin, max_bin, no_bins)
per_bin_weight = (1 / (1 + (bin_center / ptm_norm) ** 2)).to(
pae_prob.device
) # [N_bins]
token_token_ptm = (pae_prob * per_bin_weight).sum(dim=-1) # [..., N_d, N_d]
is_diff_chain = asym_id[None, :] != asym_id[:, None] # [N_d, N_d]
iptm = (token_token_ptm * is_diff_chain).sum(dim=-1) / (
eps + is_diff_chain.sum(dim=-1)
) # [..., N_d]
iptm = iptm[..., has_frame].max(dim=-1).values
return iptm
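
# calculate_iptm mirrors calculate_ptm but averages only over inter-chain token
# pairs (asym_id differs), matching the ipTM definition used for ranking.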


def break_down_to_per_sample_dict(
    input_dict: dict, shared_keys: Optional[list] = None
) -> list[dict]:
"""
Break down a dictionary containing tensors into a list of dictionaries, each corresponding to a sample.
Args:
input_dict (dict): Dictionary containing tensors.
        shared_keys (Optional[list]): Keys whose values are shared across all samples
            rather than stored per sample. Defaults to None (no shared keys).
Returns:
list[dict]: List of dictionaries, each containing data for a single sample.
"""
    shared_keys = shared_keys or []
    per_sample_keys = [key for key in input_dict if key not in shared_keys]
assert len(per_sample_keys) > 0
N_sample = input_dict[per_sample_keys[0]].size(0)
for key in per_sample_keys:
assert input_dict[key].size(0) == N_sample
per_sample_dict_list = []
for i in range(N_sample):
sample_dict = {key: input_dict[key][i] for key in per_sample_keys}
sample_dict.update({key: input_dict[key] for key in shared_keys})
per_sample_dict_list.append(sample_dict)
return per_sample_dict_list
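
# Example (hypothetical tensors): for
#   input_dict = {"score": torch.zeros(2, 5), "mask": torch.ones(7)}
# break_down_to_per_sample_dict(input_dict, shared_keys=["mask"]) returns two
# dicts, each with a per-sample "score" of shape [5] and the same shared "mask".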


@torch.no_grad()
def compute_full_data_and_summary(
configs,
pae_logits,
plddt_logits,
pde_logits,
contact_probs,
token_asym_id,
token_has_frame,
atom_coordinate,
atom_to_token_idx,
atom_is_polymer,
N_recycle,
return_full_data: bool = False,
interested_atom_mask=None,
mol_id=None,
elements_one_hot=None,
):
"""Wrapper of `_compute_full_data_and_summary` by enumerating over N samples"""
N_sample = pae_logits.size(0)
if contact_probs.dim() == 2:
# Convert to [N_sample, N_token, N_token]
contact_probs = contact_probs.unsqueeze(dim=0).expand(N_sample, -1, -1)
else:
assert contact_probs.dim() == 3
assert (
contact_probs.size(0) == plddt_logits.size(0) == pde_logits.size(0) == N_sample
)
summary_confidence = []
full_data = []
for i in range(N_sample):
summary_confidence_i, full_data_i = _compute_full_data_and_summary(
configs=configs,
pae_logits=pae_logits[i : i + 1],
plddt_logits=plddt_logits[i : i + 1],
pde_logits=pde_logits[i : i + 1],
contact_probs=contact_probs[i],
token_asym_id=token_asym_id,
token_has_frame=token_has_frame,
atom_coordinate=atom_coordinate[i : i + 1],
atom_to_token_idx=atom_to_token_idx,
atom_is_polymer=atom_is_polymer,
N_recycle=N_recycle,
interested_atom_mask=interested_atom_mask,
return_full_data=return_full_data,
mol_id=mol_id,
elements_one_hot=elements_one_hot,
)
summary_confidence.extend(summary_confidence_i)
full_data.extend(full_data_i)
return summary_confidence, full_data
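

# Minimal usage sketch (assumptions: the bin settings and threshold values below
# are hypothetical; real values come from the Protenix configs):
#
#   from ml_collections.config_dict import ConfigDict
#   configs = ConfigDict(
#       {
#           "loss": {
#               "pae": {"min_bin": 0.0, "max_bin": 32.0, "no_bins": 64},
#               "pde": {"min_bin": 0.0, "max_bin": 32.0, "no_bins": 64},
#               "plddt": {"min_bin": 0.0, "max_bin": 1.0, "no_bins": 50},
#           },
#           "metrics": {
#               "clash": {"af3_clash_threshold": 1.1, "vdw_clash_threshold": 0.75}
#           },
#       }
#   )
#   summary, full = compute_full_data_and_summary(
#       configs, pae_logits, plddt_logits, pde_logits, contact_probs,
#       token_asym_id, token_has_frame, atom_coordinate, atom_to_token_idx,
#       atom_is_polymer, N_recycle=4, return_full_data=True,
#   )
#   # summary is a list of N_sample dicts; summary[i]["ranking_score"] ranks samples.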