jfaustin's picture
improve corr plot
0af20ef
raw
history blame
4.15 kB
import logging
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from scipy.stats import spearmanr
logger = logging.getLogger(__name__)

# Names of the per-complex score columns, grouped by their suffix.
SCORE_COLUMNS = [
    # "_boltz" columns
    "confidence_score_boltz",
    "ptm_boltz",
    "iptm_boltz",
    "complex_plddt_boltz",
    "complex_iplddt_boltz",
    "complex_pde_boltz",
    "complex_ipde_boltz",
    # "_monomer" columns
    "interchain_pae_monomer",
    "interface_pae_monomer",
    "overall_pae_monomer",
    "interface_plddt_monomer",
    "average_plddt_monomer",
    "ptm_monomer",
    "interface_ptm_monomer",
    # "_multimer" columns (same metrics, same order as the "_monomer" group)
    "interchain_pae_multimer",
    "interface_pae_multimer",
    "overall_pae_multimer",
    "interface_plddt_multimer",
    "average_plddt_multimer",
    "ptm_multimer",
    "interface_ptm_multimer",
]
def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]) -> tuple[pd.DataFrame, go.Figure]:
    """Correlate precomputed score columns with KD and plot the ranking.

    Despite the name, no structure prediction happens here: every column in
    ``score_cols`` must already exist in ``spr_data_with_scores``. Each one is
    rank-correlated (Spearman) with the ``"KD (nM)"`` column, and the
    correlations are drawn as a horizontal bar chart.

    Args:
        spr_data_with_scores: Table with a ``"KD (nM)"`` column and every
            column named in ``main_cols`` and ``score_cols``. Mutated in
            place: a ``"log_kd"`` column (``log10`` of KD) is added.
        score_cols: Score columns to correlate against KD. May be empty.
        main_cols: Columns shown before the score columns in the returned table.

    Returns:
        ``(table, figure)``. ``table`` is the ``main_cols + score_cols``
        columns of ``spr_data_with_scores``, rounded to 2 decimals.
        ``figure`` is a bar chart of the Spearman correlations, sorted
        ascending, omitting scores whose correlation is NaN.
    """
    kd_col = "KD (nM)"
    # Written into the caller's frame; not used by the correlation below.
    # NOTE(review): kept in case callers read "log_kd" afterwards -- confirm.
    spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores[kd_col])

    corr_rows = []
    for score_col in score_cols:
        logger.info("Computing correlation between %s and %s", score_col, kd_col)
        res = spearmanr(spr_data_with_scores[kd_col], spr_data_with_scores[score_col])
        corr_rows.append({"score": score_col, "correlation": res.statistic, "p-value": res.pvalue})
        logger.info("Correlation between %s and %s: %s", score_col, kd_col, res.statistic)

    # Explicit columns so an empty score_cols still yields a frame with a
    # "correlation" column; without them the lookup below raises KeyError.
    corr_data = pd.DataFrame(corr_rows, columns=["score", "correlation", "p-value"])
    # Drop scores spearmanr could not rank (NaN result, e.g. constant input,
    # or NaN values under its default nan_policy).
    corr_data = corr_data[corr_data["correlation"].notna()]
    # Ascending order: on a horizontal bar chart the last category is drawn at
    # the top, so the highest correlation ends up on top.
    corr_data = corr_data.sort_values("correlation", ascending=True)

    corr_ranking_plot = go.Figure(data=[
        go.Bar(
            x=corr_data["correlation"],
            y=corr_data["score"],
            name="correlation",
            orientation="h",
            hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
        )
    ])
    corr_ranking_plot.update_layout(
        title="Correlation with Binding Affinity",
        yaxis_title="Score Type",
        xaxis_title="Spearman Correlation",
        template="simple_white",
        showlegend=False,
    )

    cols_to_show = [*main_cols, *score_cols]
    return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot
def select_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str) -> go.Figure:
    """Plot KD against one score column, with a least-squares trend line.

    Args:
        spr_data_with_scores: Table with a ``"KD (nM)"`` column and a
            ``score`` column.
        score: Name of the score column to put on the y axis.

    Returns:
        A Plotly figure containing:
            * a marker trace of every row;
            * a degree-1 ``np.polyfit`` line fitted to the rows where both KD
              and the score are finite. The line is only added when those rows
              contain at least two distinct KD values.
    """
    kd_col = "KD (nM)"
    trace_color = '#1f77b4'  # Plotly's default first color; shared by points and line
    kd = spr_data_with_scores[kd_col]
    score_values = spr_data_with_scores[score]

    scatter = go.Scatter(
        x=kd,
        y=score_values,
        name=f"KD (nM) vs {score}",
        mode='markers',  # points only; the fit is added as a separate trace
        hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
        marker=dict(color=trace_color),
    )
    corr_plot = go.Figure(data=scatter)
    corr_plot.update_layout(
        xaxis_title="KD (nM)",
        yaxis_title=score,
        template="simple_white",
        # Horizontal legend anchored just above the top-right of the plot area.
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        # Add xaxis_type="log" here to show KD on a logarithmic axis.
    )

    # Fit only on rows where both values are finite: NaN/inf can make
    # np.polyfit raise or return NaN, and builtin min()/max() over a Series
    # containing NaN can return NaN.
    x = kd.to_numpy(dtype=float)
    y = score_values.to_numpy(dtype=float)
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]
    # A line needs at least two distinct KD values. With fewer, polyfit is
    # rank-deficient (or raises on empty input), so skip the trend line.
    if np.unique(x).size < 2:
        return corr_plot

    slope, intercept = np.polyfit(x, y, 1)
    line_x = np.linspace(x.min(), x.max(), 100)
    corr_plot.add_trace(go.Scatter(
        x=line_x,
        y=slope * line_x + intercept,
        mode='lines',
        name="Correlation",
        line=dict(color=trace_color),
    ))
    return corr_plot