|
import logging |
|
import pandas as pd |
|
import numpy as np |
|
import plotly.graph_objects as go |
|
from scipy.stats import spearmanr |
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Score column names, grouped by the suffix that identifies their origin.
# The groups are kept as literal strings so every column name stays greppable.
_BOLTZ_SCORE_COLUMNS = [
    "confidence_score_boltz",
    "ptm_boltz",
    "iptm_boltz",
    "complex_plddt_boltz",
    "complex_iplddt_boltz",
    "complex_pde_boltz",
    "complex_ipde_boltz",
]

_MONOMER_SCORE_COLUMNS = [
    "interchain_pae_monomer",
    "interface_pae_monomer",
    "overall_pae_monomer",
    "interface_plddt_monomer",
    "average_plddt_monomer",
    "ptm_monomer",
    "interface_ptm_monomer",
]

_MULTIMER_SCORE_COLUMNS = [
    "interchain_pae_multimer",
    "interface_pae_multimer",
    "overall_pae_multimer",
    "interface_plddt_multimer",
    "average_plddt_multimer",
    "ptm_multimer",
    "interface_ptm_multimer",
]

# Every score column, in display order: "_boltz" first, then "_monomer",
# then "_multimer".
SCORE_COLUMNS = [
    *_BOLTZ_SCORE_COLUMNS,
    *_MONOMER_SCORE_COLUMNS,
    *_MULTIMER_SCORE_COLUMNS,
]
|
|
|
def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]) -> tuple[pd.DataFrame, go.Figure]:
    """Rank score columns by their Spearman correlation with ``"KD (nM)"``.

    No prediction is run here: the score columns must already be present in
    ``spr_data_with_scores``.

    Args:
        spr_data_with_scores: Table containing a ``"KD (nM)"`` column and
            every column named in ``score_cols`` and ``main_cols``.
        score_cols: Columns to correlate against ``"KD (nM)"``.
        main_cols: Columns to show ahead of the score columns in the
            returned table.

    Returns:
        A tuple ``(table, figure)``. ``table`` holds ``main_cols`` followed
        by ``score_cols`` (duplicates removed, first occurrence kept),
        rounded to 2 decimals. ``figure`` is a horizontal bar chart of the
        Spearman correlation per score, sorted ascending.

    Raises:
        KeyError: If ``"KD (nM)"`` or a requested column is missing.

    Side effects:
        Adds a ``"log_kd"`` column (``log10`` of ``"KD (nM)"``) to the
        caller's DataFrame in place.
    """
    kd_col = "KD (nM)"

    # Kept as an in-place write for backward compatibility: callers may list
    # "log_kd" in main_cols. The correlation below uses the raw KD values.
    spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores[kd_col])

    kd_values = spr_data_with_scores[kd_col]
    corr_rows = []
    for score_col in score_cols:
        logger.info("Computing correlation between %s and KD (nM)", score_col)
        score_values = spr_data_with_scores[score_col]

        # Correlate only rows where both values are present; otherwise one
        # missing measurement turns the whole statistic into NaN and the
        # score silently vanishes from the ranking.
        paired = kd_values.notna() & score_values.notna()
        if paired.sum() < 2:
            logger.warning(
                "Skipping %s: fewer than 2 rows with both a score and KD (nM)",
                score_col,
            )
            continue

        res = spearmanr(kd_values[paired], score_values[paired])
        corr_rows.append({"score": score_col, "correlation": res.statistic, "p-value": res.pvalue})
        logger.info("Correlation between %s and KD (nM): %s", score_col, res.statistic)

    # Explicit columns keep the frame well-formed when nothing was correlated
    # (e.g. empty score_cols), so the column lookups below cannot KeyError.
    corr_data = pd.DataFrame(corr_rows, columns=["score", "correlation", "p-value"])
    # Constant inputs still yield NaN correlations; leave them out of the plot.
    corr_data = corr_data[corr_data["correlation"].notna()]
    corr_data = corr_data.sort_values("correlation", ascending=True)

    corr_ranking_plot = go.Figure(data=[
        go.Bar(
            x=corr_data["correlation"],
            y=corr_data["score"],
            name="correlation",
            orientation="h",
            hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
        )
    ])
    corr_ranking_plot.update_layout(
        title="Correlation with Binding Affinity",
        yaxis_title="Score Type",
        xaxis_title="Spearman Correlation",
        template="simple_white",
        showlegend=False,
    )

    # De-duplicate while preserving order so a column listed in both
    # main_cols and score_cols is not emitted twice.
    cols_to_show = list(dict.fromkeys([*main_cols, *score_cols]))

    return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot
|
|
|
def select_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str) -> go.Figure:
    """Build a scatter plot of ``score`` against ``"KD (nM)"`` with a fit line.

    Args:
        spr_data_with_scores: Table containing ``"KD (nM)"`` and ``score``.
        score: Name of the score column to plot on the y axis.

    Returns:
        A figure with one marker trace of the raw values and, when at least
        two distinct finite KD values are available, a degree-1 least-squares
        line (trace name ``"Correlation"``) spanning the observed KD range.

    Raises:
        KeyError: If ``"KD (nM)"`` or ``score`` is not a column.
    """
    kd_col = "KD (nM)"
    kd_values = spr_data_with_scores[kd_col]
    score_values = spr_data_with_scores[score]

    scatter = go.Scatter(
        x=kd_values,
        y=score_values,
        name=f"KD (nM) vs {score}",
        mode="markers",
        hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
        marker=dict(color="#1f77b4"),
    )
    corr_plot = go.Figure(data=scatter)
    corr_plot.update_layout(
        xaxis_title="KD (nM)",
        yaxis_title=score,
        template="simple_white",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
    )

    # Fit only on rows where both values are finite: NaN/inf would poison the
    # least-squares solve and the min/max used for the line's extent.
    fit_x = np.asarray(kd_values, dtype=float)
    fit_y = np.asarray(score_values, dtype=float)
    finite = np.isfinite(fit_x) & np.isfinite(fit_y)
    fit_x, fit_y = fit_x[finite], fit_y[finite]

    # A line needs at least two distinct x positions; otherwise the fit is
    # rank-deficient (or raises on empty input), so return the scatter alone.
    if np.unique(fit_x).size < 2:
        logger.warning(
            "Not enough distinct finite KD (nM) values to fit a line for %s",
            score,
        )
        return corr_plot

    slope, intercept = np.polyfit(fit_x, fit_y, 1)
    corr_line_x = np.linspace(fit_x.min(), fit_x.max(), 100)
    corr_line_y = slope * corr_line_x + intercept

    corr_plot.add_trace(go.Scatter(
        x=corr_line_x,
        y=corr_line_y,
        mode="lines",
        name="Correlation",
        line=dict(color="#1f77b4"),
    ))
    return corr_plot