"""Correlate predicted structure-confidence scores with measured binding affinity (KD).

Provides helpers to compute Spearman / Pearson / R² correlations between each score
column and the ``"KD (nM)"`` column, and to plot the resulting ranking and per-score
regressions with Plotly.
"""
import logging
from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from scipy.stats import linregress, pearsonr, spearmanr

logger = logging.getLogger(__name__)

# Maps raw score column identifiers to their human-readable display names.
SCORE_COLUMN_NAMES = {
    "confidence_score_boltz": "Boltz Confidence Score",
    "ptm_boltz": "Boltz pTM Score",
    "iptm_boltz": "Boltz ipTM Score",
    "complex_plddt_boltz": "Boltz Complex pLDDT",
    "complex_iplddt_boltz": "Boltz Complex ipLDDT",
    "complex_pde_boltz": "Boltz Complex pDE",
    "complex_ipde_boltz": "Boltz Complex ipDE",
    "interchain_pae_monomer": "AlphaFold2 GapTrick Interchain PAE",
    "interface_pae_monomer": "AlphaFold2 GapTrick Interface PAE",
    "overall_pae_monomer": "AlphaFold2 GapTrick Overall PAE",
    "interface_plddt_monomer": "AlphaFold2 GapTrick Interface pLDDT",
    "average_plddt_monomer": "AlphaFold2 GapTrick Average pLDDT",
    "ptm_monomer": "AlphaFold2 GapTrick pTM Score",
    "interface_ptm_monomer": "AlphaFold2 GapTrick Interface pTM",
    "interchain_pae_multimer": "AlphaFold2 Multimer Interchain PAE",
    "interface_pae_multimer": "AlphaFold2 Multimer Interface PAE",
    "overall_pae_multimer": "AlphaFold2 Multimer Overall PAE",
    "interface_plddt_multimer": "AlphaFold2 Multimer Interface pLDDT",
    "average_plddt_multimer": "AlphaFold2 Multimer Average pLDDT",
    "ptm_multimer": "AlphaFold2 Multimer pTM Score",
    "interface_ptm_multimer": "AlphaFold2 Multimer Interface pTM",
}

# Display names of all known scores, in the order declared above.
SCORE_COLUMNS = list(SCORE_COLUMN_NAMES.values())

# Human-readable explanation for each display name in SCORE_COLUMNS.
# Built once at import time rather than on every get_score_description() call.
# NOTE(review): the pDE/ipDE entries say "higher is better" for an error metric —
# confirm the direction against the Boltz output documentation.
_SCORE_DESCRIPTIONS = {
    "Boltz Confidence Score": "The Boltz model confidence score provides an overall assessment of prediction quality (0-1, higher is better).",
    "Boltz pTM Score": "The Boltz model predicted TM-score (pTM) assesses the overall fold accuracy of the predicted structure (0-1, higher is better).",
    "Boltz ipTM Score": "The Boltz model interface pTM score (ipTM) specifically evaluates the accuracy of interface regions (0-1, higher is better).",
    "Boltz Complex pLDDT": "The Boltz model Complex pLDDT measures confidence in local structure predictions across the entire complex (0-100, higher is better).",
    "Boltz Complex ipLDDT": "The Boltz model Complex interface pLDDT (ipLDDT) focuses on confidence in interface region predictions (0-100, higher is better).",
    "Boltz Complex pDE": "The Boltz model Complex predicted distance error (pDE) estimates the confidence in predicted distances between residues (0-1, higher is better).",
    "Boltz Complex ipDE": "The Boltz model Complex interface pDE (ipDE) estimates confidence in predicted distances specifically at interfaces (0-1, higher is better).",
    "AlphaFold2 GapTrick Interchain PAE": "The AlphaFold2 GapTrick model interchain predicted aligned error (PAE) estimates position errors between chains in monomeric predictions (lower is better).",
    "AlphaFold2 GapTrick Interface PAE": "The AlphaFold2 GapTrick model interface PAE estimates position errors specifically at interfaces in monomeric predictions (lower is better).",
    "AlphaFold2 GapTrick Overall PAE": "The AlphaFold2 GapTrick model overall PAE estimates position errors across the entire structure in monomeric predictions (lower is better).",
    "AlphaFold2 GapTrick Interface pLDDT": "The AlphaFold2 GapTrick model interface pLDDT measures confidence in interface region predictions for monomeric models (0-100, higher is better).",
    "AlphaFold2 GapTrick Average pLDDT": "The AlphaFold2 GapTrick model average pLDDT provides the mean confidence across all residues in monomeric predictions (0-100, higher is better).",
    "AlphaFold2 GapTrick pTM Score": "The AlphaFold2 GapTrick model pTM score assesses overall fold accuracy in monomeric predictions (0-1, higher is better).",
    "AlphaFold2 GapTrick Interface pTM": "The AlphaFold2 GapTrick model interface pTM specifically evaluates accuracy of interface regions in monomeric predictions (0-1, higher is better).",
    # Previously a duplicate "AlphaFold2 GapTrick Interchain PAE" key, which silently
    # overwrote the GapTrick entry above and left the multimer score undocumented.
    "AlphaFold2 Multimer Interchain PAE": "The AlphaFold2 Multimer model interchain PAE estimates position errors between chains in multimeric predictions (lower is better).",
    "AlphaFold2 Multimer Interface PAE": "The AlphaFold2 Multimer model interface PAE estimates position errors specifically at interfaces in multimeric predictions (lower is better).",
    "AlphaFold2 Multimer Overall PAE": "The AlphaFold2 Multimer model overall PAE estimates position errors across the entire structure in multimeric predictions (lower is better).",
    "AlphaFold2 Multimer Interface pLDDT": "The AlphaFold2 Multimer model interface pLDDT measures confidence in interface region predictions for multimeric models (0-100, higher is better).",
    "AlphaFold2 Multimer Average pLDDT": "The AlphaFold2 Multimer model average pLDDT provides the mean confidence across all residues in multimeric predictions (0-100, higher is better).",
    "AlphaFold2 Multimer pTM Score": "The AlphaFold2 Multimer model pTM score assesses overall fold accuracy in multimeric predictions (0-1, higher is better).",
    "AlphaFold2 Multimer Interface pTM": "The AlphaFold2 Multimer model interface pTM specifically evaluates accuracy of interface regions in multimeric predictions (0-1, higher is better).",
}

# Correlation statistic name -> scipy function computing it against KD.
_CORRELATION_FUNCS = {
    "Spearman": spearmanr,
    "Pearson": pearsonr,
    "R²": linregress,
}

_KD_COL = "KD (nM)"
_CORR_COLUMNS = ["correlation_type", "score", "correlation", "p-value"]


def get_score_description(score: str) -> str:
    """Return a human-readable description for a score display name.

    Args:
        score: A display name such as ``"Boltz pTM Score"`` (see SCORE_COLUMNS).

    Returns:
        The description text, or a generic fallback message for unknown names.
    """
    return _SCORE_DESCRIPTIONS.get(score, "No description available for this score.")


def _finite_pair(df: pd.DataFrame, x_col: str, y_col: str) -> tuple[np.ndarray, np.ndarray]:
    """Return float arrays of df[x_col] and df[y_col] restricted to rows where both are finite."""
    x = df[x_col].to_numpy(dtype=float, na_value=np.nan)
    y = df[y_col].to_numpy(dtype=float, na_value=np.nan)
    mask = np.isfinite(x) & np.isfinite(y)
    return x[mask], y[mask]


def compute_correlation_data(
    spr_data_with_scores: pd.DataFrame,
    score_cols: list[str],
    cache_path: Optional[Union[str, Path]] = "corr_data.csv",
) -> pd.DataFrame:
    """Correlate each score column with ``"KD (nM)"`` using Spearman, Pearson and R².

    Missing or non-finite values are dropped pairwise for each score. Previously a
    single NaN made the whole score's correlation NaN, and the score was then discarded.

    Args:
        spr_data_with_scores: Table containing a ``"KD (nM)"`` column and every
            column listed in ``score_cols``. A ``"log_kd"`` column is added in place.
        score_cols: Names of the score columns to correlate against KD.
        cache_path: CSV file used as a cache. If it exists, its contents are returned
            without recomputation; otherwise the results are written to it. Pass
            ``None`` to disable caching. The cache is not keyed on the inputs, so
            delete it when the data or ``score_cols`` change.

    Returns:
        DataFrame with columns ``correlation_type``, ``score``, ``correlation`` and
        ``p-value``. Rows with an undefined (NaN) correlation are removed, and the
        rest are sorted by ``correlation`` ascending. For ``"R²"`` the value is
        ``rvalue**2`` from :func:`scipy.stats.linregress`.
    """
    corr_data_file = Path(cache_path) if cache_path is not None else None
    if corr_data_file is not None and corr_data_file.exists():
        logger.info("Loading correlation data from %s", corr_data_file)
        return pd.read_csv(corr_data_file)

    # NOTE(review): log_kd is not used by the correlations below, which use raw KD.
    # It is kept because callers may rely on the column being added in place.
    spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores[_KD_COL])

    records = []
    for correlation_type, corr_func in _CORRELATION_FUNCS.items():
        for score_col in score_cols:
            logger.info(
                "Computing %s correlation between %s and KD (nM)", correlation_type, score_col
            )
            kd_values, score_values = _finite_pair(spr_data_with_scores, _KD_COL, score_col)
            # scipy's pearsonr / linregress need at least two observations.
            if kd_values.size < 2:
                logger.warning(
                    "Skipping %s correlation for %s: fewer than 2 finite (KD, score) pairs",
                    correlation_type,
                    score_col,
                )
                continue
            logger.debug("Correlation function: %s", corr_func)
            res = corr_func(kd_values, score_values)
            # linregress returns a LinregressResult (rvalue); the others return .statistic.
            correlation_value = res.rvalue**2 if correlation_type == "R²" else res.statistic
            records.append({
                "correlation_type": correlation_type,
                "score": score_col,
                "correlation": correlation_value,
                "p-value": res.pvalue,
            })
            logger.info(
                "Correlation %s between %s and KD (nM): %s",
                correlation_type,
                score_col,
                correlation_value,
            )

    # Explicit columns keep the frame well-formed even when no rows were produced.
    corr_data = pd.DataFrame(records, columns=_CORR_COLUMNS)
    # Drop correlations that are still undefined (e.g. constant input columns).
    corr_data = corr_data[corr_data["correlation"].notna()]
    corr_data = corr_data.sort_values("correlation", ascending=True)

    if corr_data_file is not None:
        corr_data.to_csv(corr_data_file, index=False)
    return corr_data


def plot_correlation_ranking(corr_data: pd.DataFrame, correlation_type: str) -> go.Figure:
    """Draw a horizontal bar chart of per-score correlations for one correlation type.

    Args:
        corr_data: Output of :func:`compute_correlation_data`.
        correlation_type: One of the ``correlation_type`` values, e.g. ``"Spearman"``.

    Returns:
        A Plotly figure with one bar per score. Bars are drawn in the row order of
        ``corr_data``.
    """
    data = corr_data[corr_data["correlation_type"] == correlation_type]
    corr_ranking_plot = go.Figure(data=[
        go.Bar(
            x=data["correlation"],
            y=data["score"],
            name=correlation_type,
            text=data["correlation"],
            # Show bar labels with the same precision as the hover text.
            texttemplate="%{x:.3f}",
            orientation="h",
            hovertemplate="Score: %{y}<br>Correlation: %{x:.3f}<extra></extra>",
        )
    ])
    corr_ranking_plot.update_layout(
        title="Correlation with Binding Affinity",
        yaxis_title="Score",
        xaxis_title=correlation_type,
        template="simple_white",
        showlegend=False,
    )
    return corr_ranking_plot


def fake_predict_and_correlate(
    spr_data_with_scores: pd.DataFrame,
    score_cols: list[str],
    main_cols: list[str],
) -> tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Fake predict structures of all complexes and correlate the results.

    No prediction is run: the scores already present in ``spr_data_with_scores``
    are correlated with KD.

    Args:
        spr_data_with_scores: Table containing ``"KD (nM)"``, ``main_cols`` and ``score_cols``.
        score_cols: Score columns to correlate; the first one is used for the regression plot.
        main_cols: Leading columns to include in the returned table.

    Returns:
        A 3-tuple of:
            * the table restricted to ``main_cols + score_cols`` and rounded to 2 decimals;
            * the Spearman correlation ranking figure;
            * the regression plot for ``score_cols[0]``.

    Raises:
        ValueError: If ``score_cols`` is empty.
    """
    if not score_cols:
        raise ValueError("score_cols must contain at least one score column")

    corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
    corr_ranking_plot = plot_correlation_ranking(corr_data, "Spearman")
    cols_to_show = [*main_cols, *score_cols]
    corr_plot = make_regression_plot(spr_data_with_scores, score_cols[0], use_log=False)
    return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot


def make_regression_plot(spr_data_with_scores: pd.DataFrame, score: str, use_log: bool) -> go.Figure:
    """Scatter ``score`` against ``"KD (nM)"`` and overlay a least-squares line.

    The line is a degree-1 ``np.polyfit`` of score on raw KD, fitted only on rows where
    both values are finite. It is omitted when fewer than two such rows exist or all KD
    values are identical.

    Args:
        spr_data_with_scores: Table containing ``"KD (nM)"`` and ``score``.
        score: Name of the score column to plot on the y-axis.
        use_log: Draw the KD axis on a logarithmic scale. The fit stays linear in KD;
            the line is sampled geometrically so that it renders smoothly on the log axis.

    Returns:
        The Plotly figure.
    """
    kd = spr_data_with_scores[_KD_COL]
    scatter = go.Scatter(
        x=kd,
        y=spr_data_with_scores[score],
        name="Samples",
        mode="markers",  # markers only, no connecting lines
        hovertemplate="Score: %{y}<br>KD: %{x:.2f}<extra></extra>",
        marker=dict(color="#1f77b4"),  # Plotly's default first trace color
    )
    corr_plot = go.Figure(data=scatter)
    corr_plot.update_layout(
        xaxis_title="KD (nM)",
        yaxis_title=score,
        template="simple_white",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        xaxis_type="log" if use_log else "linear",
    )

    # np.polyfit cannot handle NaN/inf, so fit on finite pairs only.
    fit_x, fit_y = _finite_pair(spr_data_with_scores, _KD_COL, score)
    if fit_x.size < 2 or np.ptp(fit_x) == 0:
        logger.warning("Not enough distinct finite KD values to fit a regression line for %s", score)
        return corr_plot

    slope, intercept = np.polyfit(fit_x, fit_y, 1)
    x_min, x_max = fit_x.min(), fit_x.max()
    # On a log axis, evenly spaced linear samples would leave the low-KD range with
    # almost no points, so sample geometrically when the range is strictly positive.
    if use_log and x_min > 0:
        corr_line_x = np.geomspace(x_min, x_max, 100)
    else:
        corr_line_x = np.linspace(x_min, x_max, 100)
    corr_line_y = slope * corr_line_x + intercept

    corr_plot.add_trace(go.Scatter(
        x=corr_line_x,
        y=corr_line_y,
        mode="lines",
        name="Regression line",
        line=dict(color="#1f77b4"),  # same color as the scatter points
    ))
    return corr_plot