# Author: AchilleSoulieID
# Last commit: "Add more story to correlation tab (#17)" (250a4a2, verified)
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from scipy.stats import pearsonr, spearmanr
logger = logging.getLogger(__name__)

# Raw score column name -> human-readable label. Keyword arguments preserve
# insertion order, which in turn fixes the order of SCORE_COLUMNS below.
SCORE_COLUMN_NAMES = dict(
    # Boltz outputs (``*_boltz``).
    confidence_score_boltz="Boltz Confidence Score",
    ptm_boltz="Boltz pTM Score",
    iptm_boltz="Boltz ipTM Score",
    complex_plddt_boltz="Boltz Complex pLDDT",
    complex_iplddt_boltz="Boltz Complex ipLDDT",
    complex_pde_boltz="Boltz Complex pDE",
    complex_ipde_boltz="Boltz Complex ipDE",
    # AlphaFold2 GapTrick outputs (``*_monomer``).
    interchain_pae_monomer="AlphaFold2 GapTrick Interchain PAE",
    interface_pae_monomer="AlphaFold2 GapTrick Interface PAE",
    overall_pae_monomer="AlphaFold2 GapTrick Overall PAE",
    interface_plddt_monomer="AlphaFold2 GapTrick Interface pLDDT",
    average_plddt_monomer="AlphaFold2 GapTrick Average pLDDT",
    ptm_monomer="AlphaFold2 GapTrick pTM Score",
    interface_ptm_monomer="AlphaFold2 GapTrick Interface pTM",
    # AlphaFold2 Multimer outputs (``*_multimer``).
    interchain_pae_multimer="AlphaFold2 Multimer Interchain PAE",
    interface_pae_multimer="AlphaFold2 Multimer Interface PAE",
    overall_pae_multimer="AlphaFold2 Multimer Overall PAE",
    interface_plddt_multimer="AlphaFold2 Multimer Interface pLDDT",
    average_plddt_multimer="AlphaFold2 Multimer Average pLDDT",
    ptm_multimer="AlphaFold2 Multimer pTM Score",
    interface_ptm_multimer="AlphaFold2 Multimer Interface pTM",
)

# Human-readable labels only, in the same order as SCORE_COLUMN_NAMES.
SCORE_COLUMNS = [*SCORE_COLUMN_NAMES.values()]
def get_score_description(score: str) -> str:
    """Return an explanatory sentence for a score's human-readable label.

    Args:
        score: A display label such as ``"Boltz pTM Score"`` (the values of
            ``SCORE_COLUMN_NAMES``, not the raw column names).

    Returns:
        The description for that label, or
        ``"No description available for this score."`` if the label is unknown.
    """
    descriptions = {
        "Boltz Confidence Score": "The Boltz model confidence score provides an overall assessment of prediction quality (0-1, higher is better).",
        "Boltz pTM Score": "The Boltz model predicted TM-score (pTM) assesses the overall fold accuracy of the predicted structure (0-1, higher is better).",
        "Boltz ipTM Score": "The Boltz model interface pTM score (ipTM) specifically evaluates the accuracy of interface regions (0-1, higher is better).",
        "Boltz Complex pLDDT": "The Boltz model Complex pLDDT measures confidence in local structure predictions across the entire complex (0-100, higher is better).",
        "Boltz Complex ipLDDT": "The Boltz model Complex interface pLDDT (ipLDDT) focuses on confidence in interface region predictions (0-100, higher is better).",
        # pDE/ipDE are predicted *errors*: smaller values mean more confident
        # distance predictions, and they are not bounded to 0-1.
        "Boltz Complex pDE": "The Boltz model Complex predicted distance error (pDE) estimates the expected error in predicted distances between residues (lower is better).",
        "Boltz Complex ipDE": "The Boltz model Complex interface pDE (ipDE) estimates the expected error in predicted distances specifically at interfaces (lower is better).",
        "AlphaFold2 GapTrick Interchain PAE": "The AlphaFold2 GapTrick model interchain predicted aligned error (PAE) estimates position errors between chains in monomeric predictions (lower is better).",
        "AlphaFold2 GapTrick Interface PAE": "The AlphaFold2 GapTrick model interface PAE estimates position errors specifically at interfaces in monomeric predictions (lower is better).",
        "AlphaFold2 GapTrick Overall PAE": "The AlphaFold2 GapTrick model overall PAE estimates position errors across the entire structure in monomeric predictions (lower is better).",
        "AlphaFold2 GapTrick Interface pLDDT": "The AlphaFold2 GapTrick model interface pLDDT measures confidence in interface region predictions for monomeric models (0-100, higher is better).",
        "AlphaFold2 GapTrick Average pLDDT": "The AlphaFold2 GapTrick model average pLDDT provides the mean confidence across all residues in monomeric predictions (0-100, higher is better).",
        "AlphaFold2 GapTrick pTM Score": "The AlphaFold2 GapTrick model pTM score assesses overall fold accuracy in monomeric predictions (0-1, higher is better).",
        "AlphaFold2 GapTrick Interface pTM": "The AlphaFold2 GapTrick model interface pTM specifically evaluates accuracy of interface regions in monomeric predictions (0-1, higher is better).",
        # Previously missing although the label is listed in SCORE_COLUMN_NAMES.
        "AlphaFold2 Multimer Interchain PAE": "The AlphaFold2 Multimer model interchain predicted aligned error (PAE) estimates position errors between chains in multimeric predictions (lower is better).",
        "AlphaFold2 Multimer Interface PAE": "The AlphaFold2 Multimer model interface PAE estimates position errors specifically at interfaces in multimeric predictions (lower is better).",
        "AlphaFold2 Multimer Overall PAE": "The AlphaFold2 Multimer model overall PAE estimates position errors across the entire structure in multimeric predictions (lower is better).",
        "AlphaFold2 Multimer Interface pLDDT": "The AlphaFold2 Multimer model interface pLDDT measures confidence in interface region predictions for multimeric models (0-100, higher is better).",
        "AlphaFold2 Multimer Average pLDDT": "The AlphaFold2 Multimer model average pLDDT provides the mean confidence across all residues in multimeric predictions (0-100, higher is better).",
        "AlphaFold2 Multimer pTM Score": "The AlphaFold2 Multimer model pTM score assesses overall fold accuracy in multimeric predictions (0-1, higher is better).",
        "AlphaFold2 Multimer Interface pTM": "The AlphaFold2 Multimer model interface pTM specifically evaluates accuracy of interface regions in multimeric predictions (0-1, higher is better).",
    }
    return descriptions.get(score, "No description available for this score.")
def compute_correlation_data(
    spr_data_with_scores: pd.DataFrame,
    score_cols: list[str],
    cache_path: "Path | str | None" = Path("corr_data.csv"),
) -> pd.DataFrame:
    """Correlate each score column with binding affinity (KD).

    One row is produced for every combination of affinity column (raw
    ``"KD (nM)"`` and its log10, labelled ``"log_kd"``), correlation method
    (``"Spearman"``, ``"Pearson"``) and score column. Rows whose correlation
    is NaN are dropped and the table is sorted by ascending correlation.

    Args:
        spr_data_with_scores: Table with a ``"KD (nM)"`` column and one column
            per entry of ``score_cols``. It is not modified.
        score_cols: Names of the score columns to correlate against KD.
        cache_path: CSV file used as a cache. If it exists, its contents are
            returned as-is without recomputation (so they can be stale if the
            inputs changed); otherwise the fresh result is written there.
            ``None`` disables caching entirely.

    Returns:
        DataFrame with columns ``correlation_type``, ``kd_col``, ``score``,
        ``correlation`` and ``p-value`` and a fresh 0..n-1 index.
    """
    if cache_path is not None:
        cache_path = Path(cache_path)
        if cache_path.exists():
            logger.info("Loading correlation data from %s", cache_path)
            return pd.read_csv(cache_path)

    # log10(KD) is computed locally rather than written back into the caller's
    # DataFrame, so the input is left untouched.
    kd = spr_data_with_scores["KD (nM)"]
    kd_columns = {"KD (nM)": kd, "log_kd": np.log10(kd)}
    corr_funcs = {"Spearman": spearmanr, "Pearson": pearsonr}

    rows = []
    for kd_col, kd_values in kd_columns.items():
        for correlation_type, corr_func in corr_funcs.items():
            for score_col in score_cols:
                logger.info(
                    "Computing %s correlation between %s and %s",
                    correlation_type,
                    score_col,
                    kd_col,
                )
                res = corr_func(kd_values, spr_data_with_scores[score_col])
                rows.append(
                    {
                        "correlation_type": correlation_type,
                        "kd_col": kd_col,
                        "score": score_col,
                        "correlation": res.statistic,
                        "p-value": res.pvalue,
                    }
                )

    # Explicit columns keep the schema (and the "correlation" column used
    # below) even when score_cols is empty.
    corr_data = pd.DataFrame(
        rows, columns=["correlation_type", "kd_col", "score", "correlation", "p-value"]
    )
    # scipy yields NaN e.g. for constant inputs or inputs containing NaN;
    # such rows carry no usable ranking information.
    corr_data = corr_data[corr_data["correlation"].notna()]
    # Reset the index so a fresh result matches one re-read from the cache.
    corr_data = corr_data.sort_values("correlation", ascending=True).reset_index(
        drop=True
    )
    if cache_path is not None:
        corr_data.to_csv(cache_path, index=False)
    return corr_data
def plot_correlation_ranking(
    corr_data: pd.DataFrame, correlation_type: str, kd_col: str
) -> go.Figure:
    """Build a horizontal bar chart of score-vs-affinity correlations.

    Args:
        corr_data: Correlation table with ``correlation_type``, ``kd_col``,
            ``score`` and ``correlation`` columns.
        correlation_type: Correlation method to display, matched against
            ``corr_data["correlation_type"]`` (e.g. ``"Spearman"``).
        kd_col: Affinity column to display, matched against
            ``corr_data["kd_col"]``.

    Returns:
        A figure with one horizontal bar per matching score, in the row order
        of ``corr_data``.
    """
    matches_type = corr_data["correlation_type"] == correlation_type
    matches_kd = corr_data["kd_col"] == kd_col
    selected = corr_data.loc[matches_type & matches_kd]

    bars = go.Bar(
        orientation="h",
        x=selected["correlation"],
        y=selected["score"],
        name=correlation_type,
        hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
    )
    figure = go.Figure(data=[bars])
    figure.update_layout(
        title="Correlation with Binding Affinity",
        xaxis_title=correlation_type,
        yaxis_title="Score",
        template="simple_white",
        showlegend=False,
    )
    return figure
def fake_predict_and_correlate(
    spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]
) -> tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Correlate pre-computed structure-prediction scores with binding affinity.

    Despite the name, no prediction is performed here: the scores must already
    be columns of ``spr_data_with_scores``.

    Args:
        spr_data_with_scores: Table with ``"KD (nM)"``, the ``main_cols`` and
            the ``score_cols`` columns.
        score_cols: Score columns to correlate; the first one is used for the
            initial regression plot.
        main_cols: Leading columns to show before the scores in the table.

    Returns:
        A tuple ``(table, ranking_plot, regression_plot)``:
        the ``main_cols`` + ``score_cols`` columns rounded to 2 decimals, a
        Spearman-vs-raw-KD correlation ranking chart, and a linear-axis
        regression plot of the first score against KD.

    Raises:
        ValueError: If ``score_cols`` is empty.
    """
    if not score_cols:
        raise ValueError("score_cols must contain at least one score column")

    corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
    corr_ranking_plot = plot_correlation_ranking(corr_data, "Spearman", kd_col="KD (nM)")
    # dict.fromkeys de-duplicates while keeping order, so a column listed in
    # both main_cols and score_cols is not shown twice.
    cols_to_show = list(dict.fromkeys([*main_cols, *score_cols]))
    corr_plot = make_regression_plot(spr_data_with_scores, score_cols[0], use_log=False)
    return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
def make_regression_plot(
    spr_data_with_scores: pd.DataFrame, score: str, use_log: bool
) -> go.Figure:
    """Plot one score against KD with a least-squares regression line.

    Args:
        spr_data_with_scores: Table containing a ``"KD (nM)"`` column and the
            ``score`` column.
        score: Name of the score column shown on the y-axis.
        use_log: If True, the x-axis is log-scaled and the line is fitted
            against log10(KD), so it is straight on that axis. Otherwise both
            the axis and the fit are linear in KD.

    Returns:
        A figure with one marker per sample and, when at least two distinct
        usable x values remain, the fitted regression line.
    """
    kd = spr_data_with_scores["KD (nM)"]
    # One marker per sample, plotted against the raw KD values.
    scatter = go.Scatter(
        x=kd,
        y=spr_data_with_scores[score],
        name="Samples",
        mode="markers",  # Only show markers/dots, no lines
        hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
        marker=dict(color="#1f77b4"),  # Set color to match default first color
    )
    corr_plot = go.Figure(data=scatter)
    corr_plot.update_layout(
        xaxis_title="KD (nM)",
        yaxis_title=score,
        template="simple_white",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        # Log-scale the x-axis only when requested.
        xaxis_type="log" if use_log else "linear",
    )

    # Keep only points the fit can use: finite KD and score, and KD > 0 when
    # fitting in log space. NaN/inf inputs make np.polyfit fail or return NaN.
    x_vals = kd.to_numpy(dtype=float)
    y_vals = spr_data_with_scores[score].to_numpy(dtype=float)
    usable = np.isfinite(x_vals) & np.isfinite(y_vals)
    if use_log:
        usable &= x_vals > 0
    x_fit = x_vals[usable]
    y_fit = y_vals[usable]
    if use_log:
        x_fit = np.log10(x_fit)

    # A degree-1 fit needs at least two distinct x values.
    if np.unique(x_fit).size < 2:
        logger.warning(
            "Not enough usable points to fit a regression line for %s", score
        )
        return corr_plot

    slope, intercept = np.polyfit(x_fit, y_fit, 1)
    line_x = np.linspace(x_fit.min(), x_fit.max(), 100)
    line_y = slope * line_x + intercept
    if use_log:
        # Map the evenly spaced log10 grid back to KD units for the log axis.
        line_x = 10**line_x

    corr_plot.add_trace(
        go.Scatter(
            x=line_x,
            y=line_y,
            mode="lines",
            name="Regression line",
            line=dict(color="#1f77b4"),  # Set same color as scatter points
        )
    )
    return corr_plot