"""Correlate structure-prediction confidence scores with SPR binding affinity.

The module maps raw score column names to display names, provides a short
description for each displayed score, computes Spearman/Pearson correlations
between each score column and the ``KD (nM)`` column (raw and log10), and
builds the Plotly figures that present those results.
"""

import logging
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from scipy.stats import pearsonr, spearmanr

logger = logging.getLogger(__name__)

# Raw score column name -> human-readable display name.
SCORE_COLUMN_NAMES = {
    "confidence_score_boltz": "Boltz Confidence Score",
    "ptm_boltz": "Boltz pTM Score",
    "iptm_boltz": "Boltz ipTM Score",
    "complex_plddt_boltz": "Boltz Complex pLDDT",
    "complex_iplddt_boltz": "Boltz Complex ipLDDT",
    "complex_pde_boltz": "Boltz Complex pDE",
    "complex_ipde_boltz": "Boltz Complex ipDE",
    "interchain_pae_monomer": "AlphaFold2 GapTrick Interchain PAE",
    "interface_pae_monomer": "AlphaFold2 GapTrick Interface PAE",
    "overall_pae_monomer": "AlphaFold2 GapTrick Overall PAE",
    "interface_plddt_monomer": "AlphaFold2 GapTrick Interface pLDDT",
    "average_plddt_monomer": "AlphaFold2 GapTrick Average pLDDT",
    "ptm_monomer": "AlphaFold2 GapTrick pTM Score",
    "interface_ptm_monomer": "AlphaFold2 GapTrick Interface pTM",
    "interchain_pae_multimer": "AlphaFold2 Multimer Interchain PAE",
    "interface_pae_multimer": "AlphaFold2 Multimer Interface PAE",
    "overall_pae_multimer": "AlphaFold2 Multimer Overall PAE",
    "interface_plddt_multimer": "AlphaFold2 Multimer Interface pLDDT",
    "average_plddt_multimer": "AlphaFold2 Multimer Average pLDDT",
    "ptm_multimer": "AlphaFold2 Multimer pTM Score",
    "interface_ptm_multimer": "AlphaFold2 Multimer Interface pTM",
}

# Display names, in the same order as SCORE_COLUMN_NAMES.
SCORE_COLUMNS = list(SCORE_COLUMN_NAMES.values())

# Fallback text returned for a score that has no entry in _SCORE_DESCRIPTIONS.
_NO_DESCRIPTION = "No description available for this score."

# Display name -> one-sentence description shown to the user.
# Every value of SCORE_COLUMN_NAMES is expected to have an entry here.
_SCORE_DESCRIPTIONS = {
    "Boltz Confidence Score": (
        "The Boltz model confidence score provides an overall assessment of prediction quality (0-1, higher is better)."
    ),
    "Boltz pTM Score": (
        "The Boltz model predicted TM-score (pTM) assesses the overall fold accuracy of the predicted structure (0-1, higher is better)."
    ),
    "Boltz ipTM Score": (
        "The Boltz model interface pTM score (ipTM) specifically evaluates the accuracy of interface regions (0-1, higher is better)."
    ),
    "Boltz Complex pLDDT": (
        "The Boltz model Complex pLDDT measures confidence in local structure predictions across the entire complex (0-100, higher is better)."
    ),
    "Boltz Complex ipLDDT": (
        "The Boltz model Complex interface pLDDT (ipLDDT) focuses on confidence in interface region predictions (0-100, higher is better)."
    ),
    "Boltz Complex pDE": (
        "The Boltz model Complex predicted distance error (pDE) estimates the confidence in predicted distances between residues (0-1, higher is better)."
    ),
    "Boltz Complex ipDE": (
        "The Boltz model Complex interface pDE (ipDE) estimates confidence in predicted distances specifically at interfaces (0-1, higher is better)."
    ),
    "AlphaFold2 GapTrick Interchain PAE": (
        "The AlphaFold2 GapTrick model interchain predicted aligned error (PAE) estimates position errors between chains in monomeric predictions (lower is better)."
    ),
    "AlphaFold2 GapTrick Interface PAE": (
        "The AlphaFold2 GapTrick model interface PAE estimates position errors specifically at interfaces in monomeric predictions (lower is better)."
    ),
    "AlphaFold2 GapTrick Overall PAE": (
        "The AlphaFold2 GapTrick model overall PAE estimates position errors across the entire structure in monomeric predictions (lower is better)."
    ),
    "AlphaFold2 GapTrick Interface pLDDT": (
        "The AlphaFold2 GapTrick model interface pLDDT measures confidence in interface region predictions for monomeric models (0-100, higher is better)."
    ),
    "AlphaFold2 GapTrick Average pLDDT": (
        "The AlphaFold2 GapTrick model average pLDDT provides the mean confidence across all residues in monomeric predictions (0-100, higher is better)."
    ),
    "AlphaFold2 GapTrick pTM Score": (
        "The AlphaFold2 GapTrick model pTM score assesses overall fold accuracy in monomeric predictions (0-1, higher is better)."
    ),
    "AlphaFold2 GapTrick Interface pTM": (
        "The AlphaFold2 GapTrick model interface pTM specifically evaluates accuracy of interface regions in monomeric predictions (0-1, higher is better)."
    ),
    # This entry was missing although the score is listed in SCORE_COLUMN_NAMES;
    # the wording mirrors the GapTrick interchain PAE description.
    "AlphaFold2 Multimer Interchain PAE": (
        "The AlphaFold2 Multimer model interchain predicted aligned error (PAE) estimates position errors between chains in multimeric predictions (lower is better)."
    ),
    "AlphaFold2 Multimer Interface PAE": (
        "The AlphaFold2 Multimer model interface PAE estimates position errors specifically at interfaces in multimeric predictions (lower is better)."
    ),
    "AlphaFold2 Multimer Overall PAE": (
        "The AlphaFold2 Multimer model overall PAE estimates position errors across the entire structure in multimeric predictions (lower is better)."
    ),
    "AlphaFold2 Multimer Interface pLDDT": (
        "The AlphaFold2 Multimer model interface pLDDT measures confidence in interface region predictions for multimeric models (0-100, higher is better)."
    ),
    "AlphaFold2 Multimer Average pLDDT": (
        "The AlphaFold2 Multimer model average pLDDT provides the mean confidence across all residues in multimeric predictions (0-100, higher is better)."
    ),
    "AlphaFold2 Multimer pTM Score": (
        "The AlphaFold2 Multimer model pTM score assesses overall fold accuracy in multimeric predictions (0-1, higher is better)."
    ),
    "AlphaFold2 Multimer Interface pTM": (
        "The AlphaFold2 Multimer model interface pTM specifically evaluates accuracy of interface regions in multimeric predictions (0-1, higher is better)."
    ),
}

# Column order of the correlation table; fixing it up front keeps the table
# well-formed even when no score columns are supplied.
_CORR_COLUMNS = ["correlation_type", "kd_col", "score", "correlation", "p-value"]


def get_score_description(score: str) -> str:
    """Return the description for a score display name.

    Args:
        score: A display name, i.e. one of the values of ``SCORE_COLUMN_NAMES``.

    Returns:
        The matching description, or a generic "no description" sentence when
        ``score`` is not a known display name.
    """
    return _SCORE_DESCRIPTIONS.get(score, _NO_DESCRIPTION)


def compute_correlation_data(
    spr_data_with_scores: pd.DataFrame, score_cols: list[str]
) -> pd.DataFrame:
    """Compute correlations between binding affinity and each score column.

    For both the raw ``KD (nM)`` values and their log10 (reported under the
    name ``log_kd``), Spearman and Pearson correlations are computed against
    every column in ``score_cols``.

    The result is cached in ``corr_data.csv`` in the current working
    directory. If that file already exists it is loaded and returned as-is,
    WITHOUT looking at the arguments; delete the file to force a recompute.

    Args:
        spr_data_with_scores: Table containing a ``KD (nM)`` column and the
            columns named in ``score_cols``. It is not modified.
        score_cols: Names of the score columns to correlate against.

    Returns:
        A DataFrame with columns ``correlation_type``, ``kd_col``, ``score``,
        ``correlation`` and ``p-value``, with NaN correlations removed and
        rows sorted by ascending correlation.
    """
    corr_data_file = Path("corr_data.csv")
    if corr_data_file.exists():
        logger.info("Loading correlation data from %s", corr_data_file)
        return pd.read_csv(corr_data_file)

    # Compute the log10 affinity locally rather than writing a "log_kd"
    # column into the caller's DataFrame: that side effect only happened on a
    # cache miss, so callers could never rely on it.
    kd_values = spr_data_with_scores["KD (nM)"]
    kd_series = {"KD (nM)": kd_values, "log_kd": np.log10(kd_values)}
    corr_funcs = {"Spearman": spearmanr, "Pearson": pearsonr}

    rows = []
    for kd_col, kd in kd_series.items():
        for correlation_type, corr_func in corr_funcs.items():
            for score_col in score_cols:
                logger.info(
                    "Computing %s correlation between %s and %s",
                    correlation_type,
                    score_col,
                    kd_col,
                )
                logger.debug("Correlation function: %s", corr_func)
                res = corr_func(kd, spr_data_with_scores[score_col])
                rows.append(
                    {
                        "correlation_type": correlation_type,
                        "kd_col": kd_col,
                        "score": score_col,
                        "correlation": res.statistic,
                        "p-value": res.pvalue,
                    }
                )

    # Explicit columns so that an empty ``score_cols`` yields an empty but
    # well-formed table instead of a KeyError on "correlation" below.
    corr_data = pd.DataFrame(rows, columns=_CORR_COLUMNS)
    # Drop pairs for which the correlation is undefined (NaN).
    corr_data = corr_data[corr_data["correlation"].notna()]
    corr_data = corr_data.sort_values("correlation", ascending=True)
    corr_data.to_csv(corr_data_file, index=False)
    return corr_data


def plot_correlation_ranking(
    corr_data: pd.DataFrame, correlation_type: str, kd_col: str
) -> go.Figure:
    """Build a horizontal bar chart ranking scores by correlation.

    Args:
        corr_data: Correlation table as returned by
            ``compute_correlation_data``.
        correlation_type: Value of the ``correlation_type`` column to keep
            (``"Spearman"`` or ``"Pearson"``).
        kd_col: Value of the ``kd_col`` column to keep (``"KD (nM)"`` or
            ``"log_kd"``).

    Returns:
        A figure with one bar per score; bars follow the row order of
        ``corr_data``.
    """
    subset = corr_data[
        (corr_data["correlation_type"] == correlation_type)
        & (corr_data["kd_col"] == kd_col)
    ]
    corr_ranking_plot = go.Figure(
        data=[
            go.Bar(
                x=subset["correlation"],
                y=subset["score"],
                name=correlation_type,
                orientation="h",
                # Plotly hover text uses <br> for line breaks.
                hovertemplate="Score: %{y}<br>Correlation: %{x:.3f}<br>",
            )
        ]
    )
    corr_ranking_plot.update_layout(
        title="Correlation with Binding Affinity",
        yaxis_title="Score",
        xaxis_title=correlation_type,
        template="simple_white",
        showlegend=False,
    )
    return corr_ranking_plot


def fake_predict_and_correlate(
    spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]
) -> tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Fake predict structures of all complexes and correlate the results.

    No prediction is run here: the scores already present in
    ``spr_data_with_scores`` are correlated with the binding affinity.

    Args:
        spr_data_with_scores: Table with ``KD (nM)``, the ``main_cols`` and
            the ``score_cols``.
        score_cols: Score column names; the first one is used for the initial
            regression plot.
        main_cols: Non-score columns to include in the returned table.

    Returns:
        A 3-tuple of (table restricted to ``main_cols + score_cols`` and
        rounded to 2 decimals, Spearman ranking figure for ``KD (nM)``,
        regression figure for the first score on a linear axis).

    Raises:
        IndexError: If ``score_cols`` is empty.
    """
    corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
    corr_ranking_plot = plot_correlation_ranking(
        corr_data, "Spearman", kd_col="KD (nM)"
    )
    cols_to_show = [*main_cols, *score_cols]
    corr_plot = make_regression_plot(spr_data_with_scores, score_cols[0], use_log=False)
    return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot


def make_regression_plot(
    spr_data_with_scores: pd.DataFrame, score: str, use_log: bool
) -> go.Figure:
    """Build a scatter plot of one score against KD with a fitted line.

    Args:
        spr_data_with_scores: Table with a ``KD (nM)`` column and the
            ``score`` column.
        score: Name of the score column plotted on the y-axis.
        use_log: When true, the x-axis is logarithmic and the line is fitted
            against log10(KD); otherwise both are linear.

    Returns:
        A figure with the sample markers and, when at least two usable points
        with distinct x values exist, a least-squares regression line.
    """
    kd_values = spr_data_with_scores["KD (nM)"]
    score_values = spr_data_with_scores[score]

    scatter = go.Scatter(
        x=kd_values,
        y=score_values,
        name="Samples",
        mode="markers",  # markers only, no connecting lines
        # Plotly hover text uses <br> for line breaks.
        hovertemplate="Score: %{y}<br>KD: %{x:.2f}<br>",
        marker=dict(color="#1f77b4"),
    )
    corr_plot = go.Figure(data=scatter)
    corr_plot.update_layout(
        xaxis_title="KD (nM)",
        yaxis_title=score,
        template="simple_white",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        xaxis_type="log" if use_log else "linear",
    )

    # Fit in the same space the x-axis is displayed in.
    x_fit = np.asarray(kd_values, dtype=float)
    y_fit = np.asarray(score_values, dtype=float)
    if use_log:
        # Non-positive KD values become -inf/NaN here and are filtered below.
        with np.errstate(divide="ignore", invalid="ignore"):
            x_fit = np.log10(x_fit)

    # np.polyfit cannot handle NaN/inf, so fit on complete finite pairs only.
    usable = np.isfinite(x_fit) & np.isfinite(y_fit)
    x_fit, y_fit = x_fit[usable], y_fit[usable]

    # A line needs at least two points with distinct x values.
    if x_fit.size < 2 or x_fit.min() == x_fit.max():
        logger.warning(
            "Not enough valid data points to fit a regression line for %s", score
        )
        return corr_plot

    slope, intercept = np.polyfit(x_fit, y_fit, 1)
    line_x = np.linspace(x_fit.min(), x_fit.max(), 100)
    line_y = slope * line_x + intercept
    if use_log:
        # Convert back from log space so the line matches the KD axis values.
        line_x = 10**line_x

    corr_plot.add_trace(
        go.Scatter(
            x=line_x,
            y=line_y,
            mode="lines",
            name="Regression line",
            line=dict(color="#1f77b4"),  # same colour as the sample markers
        )
    )
    return corr_plot