|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from protenix.model import sample_confidence |
|
|
|
|
|
def get_complex_level_rankers(scores, keys): |
|
assert all([k in ["plddt", "gpde", "ranking_score"] for k in keys]) |
|
rankers = {} |
|
for key in keys: |
|
if key == "gpde": |
|
descending = False |
|
else: |
|
descending = True |
|
ranking = scores[key].argsort(dim=0, descending=descending) |
|
rankers[f"{key}.rank1"] = lambda x, rank1_idx=ranking[0].item(): x[ |
|
..., rank1_idx |
|
] |
|
return rankers |
|
|
|
|
|
def add_diff_metrics(scores, ranker_keys): |
|
diff_metrics = { |
|
"diff/best_worst": scores["best"] - scores["worst"], |
|
"diff/best_random": scores["best"] - scores["random"], |
|
"diff/best_median": scores["best"] - scores["median"], |
|
} |
|
|
|
for key in ranker_keys: |
|
diff_metrics.update( |
|
{ |
|
f"diff/best_{key}": scores["best"] - scores[f"{key}.rank1"], |
|
f"diff/{key}_median": scores[f"{key}.rank1"] - scores["median"], |
|
} |
|
) |
|
scores.update(diff_metrics) |
|
return scores |
|
|
|
|
|
class LDDTMetrics(nn.Module): |
|
"""LDDT: evaluated on chains and interfaces""" |
|
|
|
def __init__(self, configs): |
|
super(LDDTMetrics, self).__init__() |
|
self.eps = configs.metrics.lddt.eps |
|
self.configs = configs |
|
self.chunk_size = self.configs.infer_setting.lddt_metrics_chunk_size |
|
self.lddt_base = LDDT(eps=self.eps) |
|
|
|
self.complex_ranker_keys = configs.metrics.get( |
|
"complex_ranker_keys", ["plddt", "gpde", "ranking_score"] |
|
) |
|
|
|
def compute_lddt(self, pred_dict: dict, label_dict: dict): |
|
"""compute complex-level and chain/interface-level lddt |
|
|
|
Args: |
|
pred_dict (Dict): a dictionary containing |
|
coordinate: [N_sample, N_atom, 3] |
|
label_dict (Dict): a dictionary containing |
|
coordinate: [N_sample, N_atom, 3] |
|
lddt_mask: [N_atom, N_atom] |
|
""" |
|
|
|
out = {} |
|
|
|
|
|
lddt = self.lddt_base.forward( |
|
pred_coordinate=pred_dict["coordinate"], |
|
true_coordinate=label_dict["coordinate"], |
|
lddt_mask=label_dict["lddt_mask"], |
|
chunk_size=self.chunk_size, |
|
) |
|
out["complex"] = lddt |
|
|
|
return out |
|
|
|
def aggregate( |
|
self, |
|
vals, |
|
dim: int = -1, |
|
aggregators: dict = {}, |
|
): |
|
N_sample = vals.size(dim) |
|
median_index = N_sample // 2 |
|
basic_sample_aggregators = { |
|
"best": lambda x: x.max(dim=dim)[0], |
|
"worst": lambda x: x.min(dim=dim)[0], |
|
"random": lambda x: x.select(dim=dim, index=0), |
|
"mean": lambda x: x.mean(dim=dim), |
|
"median": lambda x: x.sort(dim=dim, descending=True)[0].select( |
|
dim=dim, index=median_index |
|
), |
|
} |
|
sample_aggregators = {**basic_sample_aggregators, **aggregators} |
|
|
|
return { |
|
agg_name: agg_func(vals) |
|
for agg_name, agg_func in sample_aggregators.items() |
|
} |
|
|
|
def aggregate_lddt(self, lddt_dict, per_sample_summary_confidence): |
|
|
|
|
|
confidence_scores = sample_confidence.merge_per_sample_confidence_scores( |
|
per_sample_summary_confidence |
|
) |
|
|
|
|
|
complex_level_ranker = get_complex_level_rankers( |
|
confidence_scores, self.complex_ranker_keys |
|
) |
|
|
|
complex_lddt = self.aggregate( |
|
lddt_dict["complex"], aggregators=complex_level_ranker |
|
) |
|
complex_lddt = add_diff_metrics(complex_lddt, self.complex_ranker_keys) |
|
|
|
complex_lddt = { |
|
f"lddt/complex/{name}": value for name, value in complex_lddt.items() |
|
} |
|
return complex_lddt, {} |
|
|
|
|
|
class LDDT(nn.Module): |
|
"""LDDT base metrics""" |
|
|
|
def __init__(self, eps: float = 1e-10): |
|
super(LDDT, self).__init__() |
|
self.eps = eps |
|
|
|
def _chunk_base_forward(self, pred_distance, true_distance) -> torch.Tensor: |
|
distance_error_l1 = torch.abs( |
|
pred_distance - true_distance |
|
) |
|
thresholds = [0.5, 1, 2, 4] |
|
sparse_pair_lddt = ( |
|
torch.stack([distance_error_l1 < t for t in thresholds], dim=-1) |
|
.to(dtype=distance_error_l1.dtype) |
|
.mean(dim=-1) |
|
) |
|
del distance_error_l1 |
|
|
|
if sparse_pair_lddt.numel() == 0: |
|
sparse_pair_lddt = torch.zeros_like(sparse_pair_lddt) |
|
lddt = torch.mean(sparse_pair_lddt, dim=-1) |
|
return lddt |
|
|
|
def _chunk_forward( |
|
self, pred_distance, true_distance, chunk_size: Optional[int] = None |
|
) -> torch.Tensor: |
|
if chunk_size is None: |
|
return self._chunk_base_forward(pred_distance, true_distance) |
|
else: |
|
lddt = [] |
|
N_sample = pred_distance.shape[-2] |
|
no_chunks = N_sample // chunk_size + (N_sample % chunk_size != 0) |
|
for i in range(no_chunks): |
|
lddt_i = self._chunk_base_forward( |
|
pred_distance[ |
|
..., |
|
i * chunk_size : (i + 1) * chunk_size, |
|
:, |
|
], |
|
true_distance, |
|
) |
|
lddt.append(lddt_i) |
|
lddt = torch.cat(lddt, dim=-1) |
|
return lddt |
|
|
|
def _calc_sparse_dist(self, pred_coordinate, true_coordinate, l_index, m_index): |
|
pred_coords_l = pred_coordinate.index_select( |
|
-2, l_index |
|
) |
|
pred_coords_m = pred_coordinate.index_select( |
|
-2, m_index |
|
) |
|
true_coords_l = true_coordinate.index_select( |
|
-2, l_index |
|
) |
|
true_coords_m = true_coordinate.index_select( |
|
-2, m_index |
|
) |
|
|
|
pred_distance_sparse_lm = torch.norm( |
|
pred_coords_l - pred_coords_m, p=2, dim=-1 |
|
) |
|
true_distance_sparse_lm = torch.norm( |
|
true_coords_l - true_coords_m, p=2, dim=-1 |
|
) |
|
return pred_distance_sparse_lm, true_distance_sparse_lm |
|
|
|
def forward( |
|
self, |
|
pred_coordinate: torch.Tensor, |
|
true_coordinate: torch.Tensor, |
|
lddt_mask: torch.Tensor, |
|
chunk_size: Optional[int] = None, |
|
) -> dict[str, torch.Tensor]: |
|
"""LDDT: evaluated on complex, chains and interfaces |
|
sparse implementation, which largely reduce cuda memory when atom num reaches 10^4 + |
|
|
|
Args: |
|
pred_coordinate (torch.Tensor): the pred coordinates |
|
[N_sample, N_atom, 3] |
|
true_coordinate (torch.Tensor): the ground truth atom coordinates |
|
[N_atom, 3] |
|
lddt_mask (torch.Tensor): |
|
sparse version of [N_atom, N_atom] atompair mask based on bespoke radius of true distance |
|
[N_nonzero_mask, 2] |
|
|
|
Returns: |
|
Dict[str, torch.Tensor]: |
|
"best": [N_eval] |
|
"worst": [N_eval] |
|
""" |
|
lddt_indices = torch.nonzero(lddt_mask, as_tuple=True) |
|
l_index = lddt_indices[0] |
|
m_index = lddt_indices[1] |
|
pred_distance_sparse_lm, true_distance_sparse_lm = self._calc_sparse_dist( |
|
pred_coordinate, true_coordinate, l_index, m_index |
|
) |
|
group_lddt = self._chunk_forward( |
|
pred_distance_sparse_lm, true_distance_sparse_lm, chunk_size=chunk_size |
|
) |
|
return group_lddt |
|
|
|
@staticmethod |
|
def compute_lddt_mask( |
|
true_coordinate: torch.Tensor, |
|
true_coordinate_mask: torch.Tensor, |
|
is_nucleotide: torch.Tensor = None, |
|
is_nucleotide_threshold: float = 30.0, |
|
threshold: float = 15.0, |
|
): |
|
|
|
distance_mask = ( |
|
true_coordinate_mask[..., None] * true_coordinate_mask[..., None, :] |
|
) |
|
|
|
|
|
distance = (torch.cdist(true_coordinate, true_coordinate) * distance_mask).to( |
|
true_coordinate.dtype |
|
) |
|
|
|
|
|
c_lm = distance < threshold |
|
if is_nucleotide is not None: |
|
|
|
is_nucleotide_mask = is_nucleotide.bool()[..., None] |
|
c_lm = (distance < is_nucleotide_threshold) * is_nucleotide_mask + c_lm * ( |
|
~is_nucleotide_mask |
|
) |
|
|
|
|
|
c_lm = c_lm * ( |
|
1 - torch.eye(n=c_lm.size(-1), device=c_lm.device, dtype=distance.dtype) |
|
) |
|
|
|
c_lm = c_lm * distance_mask |
|
return c_lm |
|
|