jfaustin's picture
Improve text in experiment tab and frame as ab discovery
240124f
raw
history blame
10.4 kB
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from scipy.stats import pearsonr, spearmanr
logger = logging.getLogger(__name__)

# Raw score-column keys mapped to the human-readable labels used for display.
# The entries are grouped by the model named in each label; the key suffix
# (_boltz, _monomer, _multimer) follows the same grouping.
_BOLTZ_SCORE_NAMES = {
    "confidence_score_boltz": "Boltz Confidence Score",
    "ptm_boltz": "Boltz pTM Score",
    "iptm_boltz": "Boltz ipTM Score",
    "complex_plddt_boltz": "Boltz Complex pLDDT",
    "complex_iplddt_boltz": "Boltz Complex ipLDDT",
    "complex_pde_boltz": "Boltz Complex pDE",
    "complex_ipde_boltz": "Boltz Complex ipDE",
}

_AF2_GAPTRICK_SCORE_NAMES = {
    "interchain_pae_monomer": "AlphaFold2 GapTrick Interchain PAE",
    "interface_pae_monomer": "AlphaFold2 GapTrick Interface PAE",
    "overall_pae_monomer": "AlphaFold2 GapTrick Overall PAE",
    "interface_plddt_monomer": "AlphaFold2 GapTrick Interface pLDDT",
    "average_plddt_monomer": "AlphaFold2 GapTrick Average pLDDT",
    "ptm_monomer": "AlphaFold2 GapTrick pTM Score",
    "interface_ptm_monomer": "AlphaFold2 GapTrick Interface pTM",
}

_AF2_MULTIMER_SCORE_NAMES = {
    "interchain_pae_multimer": "AlphaFold2 Multimer Interchain PAE",
    "interface_pae_multimer": "AlphaFold2 Multimer Interface PAE",
    "overall_pae_multimer": "AlphaFold2 Multimer Overall PAE",
    "interface_plddt_multimer": "AlphaFold2 Multimer Interface pLDDT",
    "average_plddt_multimer": "AlphaFold2 Multimer Average pLDDT",
    "ptm_multimer": "AlphaFold2 Multimer pTM Score",
    "interface_ptm_multimer": "AlphaFold2 Multimer Interface pTM",
}

# Merge order (Boltz, GapTrick, Multimer) fixes the order of SCORE_COLUMNS.
SCORE_COLUMN_NAMES = {
    **_BOLTZ_SCORE_NAMES,
    **_AF2_GAPTRICK_SCORE_NAMES,
    **_AF2_MULTIMER_SCORE_NAMES,
}

# Display labels only, in the same order as the mapping above.
SCORE_COLUMNS = [*SCORE_COLUMN_NAMES.values()]
# Display label -> one-sentence explanation of the score.
# Built once at import time instead of on every lookup.
# NOTE(review): pDE / ipDE are described as error estimates, yet their text says
# "0-1, higher is better" - confirm range and direction against the Boltz docs.
_SCORE_DESCRIPTIONS = {
    "Boltz Confidence Score": "The Boltz model confidence score provides an overall assessment of prediction quality (0-1, higher is better).",
    "Boltz pTM Score": "The Boltz model predicted TM-score (pTM) assesses the overall fold accuracy of the predicted structure (0-1, higher is better).",
    "Boltz ipTM Score": "The Boltz model interface pTM score (ipTM) specifically evaluates the accuracy of interface regions (0-1, higher is better).",
    "Boltz Complex pLDDT": "The Boltz model Complex pLDDT measures confidence in local structure predictions across the entire complex (0-100, higher is better).",
    "Boltz Complex ipLDDT": "The Boltz model Complex interface pLDDT (ipLDDT) focuses on confidence in interface region predictions (0-100, higher is better).",
    "Boltz Complex pDE": "The Boltz model Complex predicted distance error (pDE) estimates the confidence in predicted distances between residues (0-1, higher is better).",
    "Boltz Complex ipDE": "The Boltz model Complex interface pDE (ipDE) estimates confidence in predicted distances specifically at interfaces (0-1, higher is better).",
    "AlphaFold2 GapTrick Interchain PAE": "The AlphaFold2 GapTrick model interchain predicted aligned error (PAE) estimates position errors between chains in monomeric predictions (lower is better).",
    "AlphaFold2 GapTrick Interface PAE": "The AlphaFold2 GapTrick model interface PAE estimates position errors specifically at interfaces in monomeric predictions (lower is better).",
    "AlphaFold2 GapTrick Overall PAE": "The AlphaFold2 GapTrick model overall PAE estimates position errors across the entire structure in monomeric predictions (lower is better).",
    "AlphaFold2 GapTrick Interface pLDDT": "The AlphaFold2 GapTrick model interface pLDDT measures confidence in interface region predictions for monomeric models (0-100, higher is better).",
    "AlphaFold2 GapTrick Average pLDDT": "The AlphaFold2 GapTrick model average pLDDT provides the mean confidence across all residues in monomeric predictions (0-100, higher is better).",
    "AlphaFold2 GapTrick pTM Score": "The AlphaFold2 GapTrick model pTM score assesses overall fold accuracy in monomeric predictions (0-1, higher is better).",
    "AlphaFold2 GapTrick Interface pTM": "The AlphaFold2 GapTrick model interface pTM specifically evaluates accuracy of interface regions in monomeric predictions (0-1, higher is better).",
    # This entry was missing, so the label fell through to the fallback text.
    "AlphaFold2 Multimer Interchain PAE": "The AlphaFold2 Multimer model interchain predicted aligned error (PAE) estimates position errors between chains in multimeric predictions (lower is better).",
    "AlphaFold2 Multimer Interface PAE": "The AlphaFold2 Multimer model interface PAE estimates position errors specifically at interfaces in multimeric predictions (lower is better).",
    "AlphaFold2 Multimer Overall PAE": "The AlphaFold2 Multimer model overall PAE estimates position errors across the entire structure in multimeric predictions (lower is better).",
    "AlphaFold2 Multimer Interface pLDDT": "The AlphaFold2 Multimer model interface pLDDT measures confidence in interface region predictions for multimeric models (0-100, higher is better).",
    "AlphaFold2 Multimer Average pLDDT": "The AlphaFold2 Multimer model average pLDDT provides the mean confidence across all residues in multimeric predictions (0-100, higher is better).",
    "AlphaFold2 Multimer pTM Score": "The AlphaFold2 Multimer model pTM score assesses overall fold accuracy in multimeric predictions (0-1, higher is better).",
    "AlphaFold2 Multimer Interface pTM": "The AlphaFold2 Multimer model interface pTM specifically evaluates accuracy of interface regions in multimeric predictions (0-1, higher is better).",
}


def get_score_description(score: str) -> str:
    """Return the one-sentence explanation for a score's display label.

    Args:
        score: Human-readable score label, e.g. "Boltz pTM Score".

    Returns:
        The matching description, or a generic "no description" sentence when
        the label is not in the table.
    """
    return _SCORE_DESCRIPTIONS.get(score, "No description available for this score.")
def compute_correlation_data(
    spr_data_with_scores: pd.DataFrame, score_cols: list[str]
) -> pd.DataFrame:
    """Correlate every score column with the binding affinity.

    Spearman and Pearson correlations are computed against both the raw
    "KD (nM)" column and its base-10 logarithm (reported as "log_kd").

    The result is cached in "corr_data.csv" in the current working directory.
    If that file already exists it is returned as-is, regardless of the
    arguments passed in; delete the file to force a recomputation.

    Args:
        spr_data_with_scores: Frame holding a "KD (nM)" column plus every
            column named in ``score_cols``. It is not modified.
        score_cols: Names of the score columns to correlate with the affinity.

    Returns:
        A frame with the columns "correlation_type", "kd_col", "score",
        "correlation" and "p-value", without NaN correlations, sorted by
        ascending correlation.
    """
    result_columns = ["correlation_type", "kd_col", "score", "correlation", "p-value"]

    corr_data_file = Path("corr_data.csv")
    if corr_data_file.exists():
        logger.info("Loading correlation data from %s", corr_data_file)
        return pd.read_csv(corr_data_file)

    # Keep the log-transformed affinity in a local mapping rather than writing
    # a "log_kd" column into the caller's frame.
    kd_values = spr_data_with_scores["KD (nM)"]
    affinity_by_name = {"KD (nM)": kd_values, "log_kd": np.log10(kd_values)}
    corr_funcs = {"Spearman": spearmanr, "Pearson": pearsonr}

    rows = []
    for kd_col, affinity in affinity_by_name.items():
        for correlation_type, corr_func in corr_funcs.items():
            for score_col in score_cols:
                logger.info(
                    "Computing %s correlation between %s and %s",
                    correlation_type,
                    score_col,
                    kd_col,
                )
                res = corr_func(affinity, spr_data_with_scores[score_col])
                logger.info("Correlation function: %s", corr_func)
                rows.append(
                    {
                        "correlation_type": correlation_type,
                        "kd_col": kd_col,
                        "score": score_col,
                        "correlation": res.statistic,
                        "p-value": res.pvalue,
                    }
                )

    # Naming the columns keeps the frame well-formed even when no rows exist.
    corr_data = pd.DataFrame(rows, columns=result_columns)
    if not rows:
        # Nothing was requested: return the empty table without caching it,
        # otherwise every later call would load the empty file.
        return corr_data

    # Drop correlations that could not be computed, then rank the rest.
    corr_data = corr_data[corr_data["correlation"].notna()]
    corr_data = corr_data.sort_values("correlation", ascending=True)
    corr_data.to_csv(corr_data_file, index=False)
    return corr_data
def plot_correlation_ranking(
    corr_data: pd.DataFrame, correlation_type: str, kd_col: str
) -> go.Figure:
    """Draw a horizontal bar chart of the correlation of each score.

    Args:
        corr_data: Frame with "correlation_type", "kd_col", "score" and
            "correlation" columns.
        correlation_type: Value of "correlation_type" to keep; it is also used
            as the trace name and the x-axis title.
        kd_col: Value of "kd_col" to keep.

    Returns:
        A figure with one bar per remaining row, in the row order of
        ``corr_data``.
    """
    # Keep only the rows for the requested correlation method and affinity column.
    matches_type = corr_data["correlation_type"] == correlation_type
    matches_kd = corr_data["kd_col"] == kd_col
    selected = corr_data[matches_type & matches_kd]

    bars = go.Bar(
        x=selected["correlation"],
        y=selected["score"],
        name=correlation_type,
        orientation="h",
        hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
    )

    layout_options = {
        "title": "Correlation with Binding Affinity",
        "yaxis_title": "Score",
        "xaxis_title": correlation_type,
        "template": "simple_white",
        "showlegend": False,
    }

    figure = go.Figure(data=[bars])
    figure.update_layout(**layout_options)
    return figure
def fake_predict_and_correlate(
    spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]
) -> tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Fake predict structures of all complexes and correlate the results.

    No prediction is run in this function: it correlates the score columns
    already present in ``spr_data_with_scores`` with the binding affinity and
    builds the initial plots.

    Args:
        spr_data_with_scores: Frame holding the affinity and score columns.
        score_cols: Score columns to correlate and display. Must be non-empty:
            the first entry is used for the initial regression plot.
        main_cols: Columns displayed ahead of the score columns.

    Returns:
        A 3-tuple of the display table (``main_cols`` followed by
        ``score_cols``, rounded to 2 decimals), the Spearman ranking bar chart
        against "KD (nM)", and the linear-scale regression plot for the first
        score.
    """
    corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
    corr_ranking_plot = plot_correlation_ranking(corr_data, "Spearman", kd_col="KD (nM)")
    corr_plot = make_regression_plot(spr_data_with_scores, score_cols[0], use_log=False)

    # Build a fresh list so the caller's main_cols is left untouched.
    cols_to_show = [*main_cols, *score_cols]
    return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
def make_regression_plot(
    spr_data_with_scores: pd.DataFrame, score: str, use_log: bool
) -> go.Figure:
    """Scatter one score against KD and overlay a least-squares regression line.

    Args:
        spr_data_with_scores: Frame holding a "KD (nM)" column and the
            ``score`` column.
        score: Name of the column plotted on the y-axis.
        use_log: When true, the x-axis is logarithmic and the line is fitted
            against log10(KD) instead of KD.

    Returns:
        A figure with a "Samples" marker trace and, when at least two rows have
        finite x and y values, a "Regression line" trace.
    """
    kd_values = spr_data_with_scores["KD (nM)"]
    score_values = spr_data_with_scores[score]

    # Scatter plot of the binding affinity against the selected score.
    scatter = go.Scatter(
        x=kd_values,
        y=score_values,
        name="Samples",
        mode="markers",  # Only show markers/dots, no lines
        hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
        marker=dict(color="#1f77b4"),  # Set color to match default first color
    )
    corr_plot = go.Figure(data=scatter)
    corr_plot.update_layout(
        xaxis_title="KD (nM)",
        yaxis_title=score,
        template="simple_white",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        xaxis_type="log" if use_log else "linear",
    )

    # Fit in the same space the x-axis is drawn in (log10 of KD when use_log).
    x_all = np.asarray(kd_values, dtype=float)
    y_all = np.asarray(score_values, dtype=float)
    if use_log:
        # Non-positive KD values become -inf/NaN here and are filtered below.
        with np.errstate(divide="ignore", invalid="ignore"):
            x_all = np.log10(x_all)

    # np.polyfit cannot handle NaN/inf, so fit only rows where both values are finite.
    usable = np.isfinite(x_all) & np.isfinite(y_all)
    x_fit = x_all[usable]
    y_fit = y_all[usable]
    if x_fit.size < 2:
        # A line needs at least two points; return the scatter alone.
        return corr_plot

    slope, intercept = np.polyfit(x_fit, y_fit, 1)
    corr_line_x = np.linspace(x_fit.min(), x_fit.max(), 100)
    corr_line_y = slope * corr_line_x + intercept
    if use_log:
        # Convert back from log space so the line matches the KD axis values.
        corr_line_x = 10**corr_line_x

    corr_plot.add_trace(
        go.Scatter(
            x=corr_line_x,
            y=corr_line_y,
            mode="lines",
            name="Regression line",
            line=dict(color="#1f77b4"),  # Set same color as scatter points
        )
    )
    return corr_plot