|
import logging |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import plotly.graph_objects as go |
|
from scipy.stats import pearsonr, spearmanr |
|
|
|
logger = logging.getLogger(__name__)

# Raw score key -> display label.
# Three groups of seven metrics each: Boltz (``*_boltz``), AlphaFold2 GapTrick
# (``*_monomer``) and AlphaFold2 Multimer (``*_multimer``).
# Insertion order matters: SCORE_COLUMNS below is derived from it.
SCORE_COLUMN_NAMES = {
    # --- Boltz ---
    "confidence_score_boltz": "Boltz Confidence Score",
    "ptm_boltz": "Boltz pTM Score",
    "iptm_boltz": "Boltz ipTM Score",
    "complex_plddt_boltz": "Boltz Complex pLDDT",
    "complex_iplddt_boltz": "Boltz Complex ipLDDT",
    "complex_pde_boltz": "Boltz Complex pDE",
    "complex_ipde_boltz": "Boltz Complex ipDE",
    # --- AlphaFold2 GapTrick ---
    "interchain_pae_monomer": "AlphaFold2 GapTrick Interchain PAE",
    "interface_pae_monomer": "AlphaFold2 GapTrick Interface PAE",
    "overall_pae_monomer": "AlphaFold2 GapTrick Overall PAE",
    "interface_plddt_monomer": "AlphaFold2 GapTrick Interface pLDDT",
    "average_plddt_monomer": "AlphaFold2 GapTrick Average pLDDT",
    "ptm_monomer": "AlphaFold2 GapTrick pTM Score",
    "interface_ptm_monomer": "AlphaFold2 GapTrick Interface pTM",
    # --- AlphaFold2 Multimer ---
    "interchain_pae_multimer": "AlphaFold2 Multimer Interchain PAE",
    "interface_pae_multimer": "AlphaFold2 Multimer Interface PAE",
    "overall_pae_multimer": "AlphaFold2 Multimer Overall PAE",
    "interface_plddt_multimer": "AlphaFold2 Multimer Interface pLDDT",
    "average_plddt_multimer": "AlphaFold2 Multimer Average pLDDT",
    "ptm_multimer": "AlphaFold2 Multimer pTM Score",
    "interface_ptm_multimer": "AlphaFold2 Multimer Interface pTM",
}

# Display labels of every score, in the declaration order above.
SCORE_COLUMNS = [*SCORE_COLUMN_NAMES.values()]
|
|
|
|
|
# Human-readable explanation of each score, keyed by its display label.
# Built once at import time instead of on every lookup.
_SCORE_DESCRIPTIONS = {
    "Boltz Confidence Score": "The Boltz model confidence score provides an overall assessment of prediction quality (0-1, higher is better).",
    "Boltz pTM Score": "The Boltz model predicted TM-score (pTM) assesses the overall fold accuracy of the predicted structure (0-1, higher is better).",
    "Boltz ipTM Score": "The Boltz model interface pTM score (ipTM) specifically evaluates the accuracy of interface regions (0-1, higher is better).",
    # Boltz reports pLDDT-style confidences on a 0-1 scale (not 0-100).
    "Boltz Complex pLDDT": "The Boltz model Complex pLDDT measures confidence in local structure predictions across the entire complex (0-1, higher is better).",
    "Boltz Complex ipLDDT": "The Boltz model Complex interface pLDDT (ipLDDT) focuses on confidence in interface region predictions (0-1, higher is better).",
    # pDE is a predicted *error*: it is not bounded to 0-1 and lower is better.
    "Boltz Complex pDE": "The Boltz model Complex predicted distance error (pDE) estimates the expected error in predicted distances between residues (lower is better).",
    "Boltz Complex ipDE": "The Boltz model Complex interface pDE (ipDE) estimates the expected error in predicted distances specifically at interfaces (lower is better).",
    "AlphaFold2 GapTrick Interchain PAE": "The AlphaFold2 GapTrick model interchain predicted aligned error (PAE) estimates position errors between chains in monomeric predictions (lower is better).",
    "AlphaFold2 GapTrick Interface PAE": "The AlphaFold2 GapTrick model interface PAE estimates position errors specifically at interfaces in monomeric predictions (lower is better).",
    "AlphaFold2 GapTrick Overall PAE": "The AlphaFold2 GapTrick model overall PAE estimates position errors across the entire structure in monomeric predictions (lower is better).",
    "AlphaFold2 GapTrick Interface pLDDT": "The AlphaFold2 GapTrick model interface pLDDT measures confidence in interface region predictions for monomeric models (0-100, higher is better).",
    "AlphaFold2 GapTrick Average pLDDT": "The AlphaFold2 GapTrick model average pLDDT provides the mean confidence across all residues in monomeric predictions (0-100, higher is better).",
    "AlphaFold2 GapTrick pTM Score": "The AlphaFold2 GapTrick model pTM score assesses overall fold accuracy in monomeric predictions (0-1, higher is better).",
    "AlphaFold2 GapTrick Interface pTM": "The AlphaFold2 GapTrick model interface pTM specifically evaluates accuracy of interface regions in monomeric predictions (0-1, higher is better).",
    # Previously missing, although SCORE_COLUMN_NAMES defines this label.
    "AlphaFold2 Multimer Interchain PAE": "The AlphaFold2 Multimer model interchain predicted aligned error (PAE) estimates position errors between chains in multimeric predictions (lower is better).",
    "AlphaFold2 Multimer Interface PAE": "The AlphaFold2 Multimer model interface PAE estimates position errors specifically at interfaces in multimeric predictions (lower is better).",
    "AlphaFold2 Multimer Overall PAE": "The AlphaFold2 Multimer model overall PAE estimates position errors across the entire structure in multimeric predictions (lower is better).",
    "AlphaFold2 Multimer Interface pLDDT": "The AlphaFold2 Multimer model interface pLDDT measures confidence in interface region predictions for multimeric models (0-100, higher is better).",
    "AlphaFold2 Multimer Average pLDDT": "The AlphaFold2 Multimer model average pLDDT provides the mean confidence across all residues in multimeric predictions (0-100, higher is better).",
    "AlphaFold2 Multimer pTM Score": "The AlphaFold2 Multimer model pTM score assesses overall fold accuracy in multimeric predictions (0-1, higher is better).",
    "AlphaFold2 Multimer Interface pTM": "The AlphaFold2 Multimer model interface pTM specifically evaluates accuracy of interface regions in multimeric predictions (0-1, higher is better).",
}


def get_score_description(score: str) -> str:
    """Return a one-sentence, human-readable explanation of a score.

    Args:
        score: Display label of the score (one of the values of
            ``SCORE_COLUMN_NAMES``, e.g. ``"Boltz pTM Score"``).

    Returns:
        The description, or a generic placeholder message when ``score`` is
        not a known label.
    """
    return _SCORE_DESCRIPTIONS.get(score, "No description available for this score.")
|
|
|
|
|
def compute_correlation_data(
    spr_data_with_scores: pd.DataFrame,
    score_cols: list[str],
    cache_file: "str | Path | None" = "corr_data.csv",
) -> pd.DataFrame:
    """Correlate every score column with binding affinity (KD).

    Correlations are computed against ``"KD (nM)"`` and against its base-10
    logarithm (reported with the ``kd_col`` label ``"log_kd"``), using both
    Spearman and Pearson correlation, for every column in ``score_cols``.

    Args:
        spr_data_with_scores: Table with a ``"KD (nM)"`` column and every
            column listed in ``score_cols``. It is not modified.
        score_cols: Names of the score columns to correlate with KD.
        cache_file: CSV file used to cache the result. When it exists and
            lists exactly the scores in ``score_cols``, it is returned as-is.
            Otherwise the correlations are recomputed and written there.
            ``None`` disables caching. The cache is only validated by score
            names; changes to the underlying data are not detected, so delete
            the file to force a recomputation.

    Returns:
        DataFrame with columns ``correlation_type``, ``kd_col``, ``score``,
        ``correlation`` and ``p-value``. Rows whose correlation is NaN are
        dropped, and rows are sorted by ascending correlation.
    """
    cache_path = Path(cache_file) if cache_file is not None else None

    if cache_path is not None and cache_path.exists():
        cached = pd.read_csv(cache_path)
        # Only trust the cache if it was built for the same set of scores.
        if "score" in cached.columns and set(cached["score"]) == set(score_cols):
            logger.info("Loading correlation data from %s", cache_path)
            return cached
        logger.info(
            "Ignoring stale correlation cache %s (score columns differ)", cache_path
        )

    kd = spr_data_with_scores["KD (nM)"]
    # The log-transformed KD is kept local so the caller's DataFrame is not
    # mutated (non-positive KD values yield -inf/NaN and hence NaN correlations).
    kd_variants = {"KD (nM)": kd, "log_kd": np.log10(kd)}
    corr_funcs = {"Spearman": spearmanr, "Pearson": pearsonr}

    rows = []
    for kd_col, kd_values in kd_variants.items():
        for correlation_type, corr_func in corr_funcs.items():
            for score_col in score_cols:
                logger.info(
                    "Computing %s correlation between %s and %s",
                    correlation_type,
                    score_col,
                    kd_col,
                )
                res = corr_func(kd_values, spr_data_with_scores[score_col])
                rows.append(
                    {
                        "correlation_type": correlation_type,
                        "kd_col": kd_col,
                        "score": score_col,
                        "correlation": res.statistic,
                        "p-value": res.pvalue,
                    }
                )

    # Explicit columns keep the frame well-formed even when score_cols is empty.
    corr_data = pd.DataFrame(
        rows, columns=["correlation_type", "kd_col", "score", "correlation", "p-value"]
    )
    corr_data = (
        corr_data[corr_data["correlation"].notna()]
        .sort_values("correlation", ascending=True)
        .reset_index(drop=True)
    )

    if cache_path is not None:
        corr_data.to_csv(cache_path, index=False)

    return corr_data
|
|
|
|
|
def plot_correlation_ranking(
    corr_data: pd.DataFrame, correlation_type: str, kd_col: str
) -> go.Figure:
    """Build a horizontal bar chart ranking scores by correlation with KD.

    Args:
        corr_data: Correlation table with at least the columns
            ``correlation_type``, ``kd_col``, ``score`` and ``correlation``.
        correlation_type: Correlation rows to keep (e.g. ``"Spearman"``); also
            used as the x-axis title.
        kd_col: Affinity column the kept correlations were computed against.

    Returns:
        A Plotly figure with one horizontal bar per score, drawn in the row
        order of the filtered ``corr_data``.
    """
    is_type = corr_data["correlation_type"] == correlation_type
    is_kd = corr_data["kd_col"] == kd_col
    selected = corr_data.loc[is_type & is_kd]

    bars = go.Bar(
        x=selected["correlation"],
        y=selected["score"],
        name=correlation_type,
        orientation="h",
        hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
    )

    figure = go.Figure(data=[bars])
    figure.update_layout(
        title="Correlation with Binding Affinity",
        yaxis_title="Score",
        xaxis_title=correlation_type,
        template="simple_white",
        showlegend=False,
    )
    return figure
|
|
|
|
|
def fake_predict_and_correlate(
    spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]
) -> tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Fake predict structures of all complexes and correlate the results.

    No structure prediction is run here: the scores are read directly from
    ``spr_data_with_scores``. That table must contain ``"KD (nM)"``, every
    column in ``main_cols`` and every column in ``score_cols``. Correlation
    results may be loaded from a cached ``corr_data.csv`` in the working
    directory.

    Args:
        spr_data_with_scores: SPR measurements with precomputed score columns.
        score_cols: Score columns to correlate with KD; the first one is used
            for the regression plot.
        main_cols: Columns shown before the score columns in the returned table.

    Returns:
        A 3-tuple of:
        - the ``main_cols`` + ``score_cols`` columns, rounded to 2 decimals;
        - a bar chart ranking scores by Spearman correlation with ``"KD (nM)"``;
        - a linear-axis scatter/regression plot of the first score against KD.
    """
    corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
    corr_ranking_plot = plot_correlation_ranking(corr_data, "Spearman", kd_col="KD (nM)")
    corr_plot = make_regression_plot(spr_data_with_scores, score_cols[0], use_log=False)

    cols_to_show = [*main_cols, *score_cols]
    return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
|
|
|
|
|
def _regression_line(kd_values, scores, use_log: bool, n_points: int = 100):
    """Return ``(x, y)`` points of a least-squares line through (KD, score) pairs.

    Pairs where either value is non-finite (NaN, or -inf from ``log10`` of a
    non-positive KD) are ignored. When ``use_log`` is True, the line is fitted
    against ``log10(KD)`` and the returned x values are transformed back to KD,
    so the line is straight on a log-scaled axis.

    Returns:
        ``(line_x, line_y)`` arrays of length ``n_points``, or ``None`` when
        fewer than two finite pairs remain or all remaining x values are equal
        (no line can be fitted).
    """
    x = np.asarray(kd_values, dtype=float)
    y = np.asarray(scores, dtype=float)
    if use_log:
        with np.errstate(divide="ignore", invalid="ignore"):
            x = np.log10(x)

    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]
    if x.size < 2 or x.min() == x.max():
        return None

    slope, intercept = np.polyfit(x, y, 1)
    line_x = np.linspace(x.min(), x.max(), n_points)
    line_y = slope * line_x + intercept
    if use_log:
        line_x = 10**line_x
    return line_x, line_y


def make_regression_plot(
    spr_data_with_scores: pd.DataFrame, score: str, use_log: bool
) -> go.Figure:
    """Scatter a score against KD and overlay a least-squares regression line.

    Args:
        spr_data_with_scores: Table with a ``"KD (nM)"`` column and ``score``.
        score: Name of the score column to plot on the y-axis.
        use_log: If True, the KD axis is log-scaled and the line is fitted
            against ``log10(KD)``; otherwise both are linear.

    Returns:
        A Plotly figure with the sample scatter and, when at least two finite
        (KD, score) pairs with distinct KD exist, the regression line.
    """
    kd = spr_data_with_scores["KD (nM)"]
    values = spr_data_with_scores[score]

    scatter = go.Scatter(
        x=kd,
        y=values,
        name="Samples",
        mode="markers",
        hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
        marker=dict(color="#1f77b4"),
    )
    corr_plot = go.Figure(data=scatter)
    corr_plot.update_layout(
        xaxis_title="KD (nM)",
        yaxis_title=score,
        template="simple_white",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        xaxis_type="log" if use_log else "linear",
    )

    line = _regression_line(kd, values, use_log)
    if line is None:
        # Missing scores/KDs or a single distinct KD: plot the samples only.
        logger.warning(
            "Not enough finite (KD, %s) pairs to fit a regression line", score
        )
        return corr_plot

    corr_line_x, corr_line_y = line
    corr_plot.add_trace(
        go.Scatter(
            x=corr_line_x,
            y=corr_line_y,
            mode="lines",
            name="Regression line",
            line=dict(color="#1f77b4"),
        )
    )
    return corr_plot
|
|