File size: 4,154 Bytes
90a13ac
 
 
 
 
 
 
 
216492b
90a13ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0af20ef
90a13ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0af20ef
3a103f2
 
 
 
 
 
90a13ac
3a103f2
 
 
 
 
0af20ef
 
3a103f2
 
90a13ac
 
3a103f2
90a13ac
0af20ef
 
 
 
 
 
 
3a103f2
90a13ac
3a103f2
 
 
 
 
 
 
 
 
0af20ef
 
3a103f2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import logging
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from scipy.stats import spearmanr

logger = logging.getLogger(__name__)

# Names of every score column this module knows about, in display order.
SCORE_COLUMNS = [
    # Columns suffixed "_boltz".
    "confidence_score_boltz",
    "ptm_boltz",
    "iptm_boltz",
    "complex_plddt_boltz",
    "complex_iplddt_boltz",
    "complex_pde_boltz",
    "complex_ipde_boltz",
    # Columns suffixed "_monomer".
    "interchain_pae_monomer",
    "interface_pae_monomer",
    "overall_pae_monomer",
    "interface_plddt_monomer",
    "average_plddt_monomer",
    "ptm_monomer",
    "interface_ptm_monomer",
    # Columns suffixed "_multimer" (same metrics as the "_monomer" group).
    "interchain_pae_multimer",
    "interface_pae_multimer",
    "overall_pae_multimer",
    "interface_plddt_multimer",
    "average_plddt_multimer",
    "ptm_multimer",
    "interface_ptm_multimer",
]

def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]) -> tuple[pd.DataFrame, go.Figure]:
    """Rank score columns by their Spearman correlation with the measured KD.

    No structure prediction is run here: the scores are read from columns
    that must already exist in ``spr_data_with_scores``.

    Args:
        spr_data_with_scores: Table containing a ``"KD (nM)"`` column plus
            every column named in ``score_cols`` and ``main_cols``. A
            ``"log_kd"`` column (log10 of ``"KD (nM)"``) is added to this
            frame in place.
        score_cols: Score columns to correlate against ``"KD (nM)"``. May be
            empty, in which case the returned chart has no bars.
        main_cols: Columns placed before the score columns in the returned
            table. This list is not modified.

    Returns:
        A ``(table, figure)`` tuple. ``table`` is the ``main_cols`` followed
        by the ``score_cols`` of the input, rounded to two decimals.
        ``figure`` is a horizontal bar chart of the Spearman correlation per
        score, sorted ascending, with scores whose correlation is NaN left
        out.
    """
    kd_col = "KD (nM)"
    # Side effect kept for backward compatibility: callers may reference
    # "log_kd" afterwards (for instance through main_cols). The correlation
    # below is rank-based, so it is computed on the raw KD values.
    spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores[kd_col])

    corr_rows = []
    for score_col in score_cols:
        logger.info("Computing correlation between %s and KD (nM)", score_col)
        res = spearmanr(spr_data_with_scores[kd_col], spr_data_with_scores[score_col])
        corr_rows.append({"score": score_col, "correlation": res.statistic, "p-value": res.pvalue})
        logger.info("Correlation between %s and KD (nM): %s", score_col, res.statistic)

    # Explicit columns keep the frame well-formed when score_cols is empty;
    # without them the "correlation" lookup below raises KeyError.
    corr_data = pd.DataFrame(corr_rows, columns=["score", "correlation", "p-value"])
    # Drop scores for which no correlation could be computed (NaN result).
    corr_data = corr_data[corr_data["correlation"].notna()]
    corr_data = corr_data.sort_values("correlation", ascending=True)

    # Horizontal bar chart: one bar per score, ordered by correlation.
    corr_ranking_plot = go.Figure(data=[
        go.Bar(
            x=corr_data["correlation"],
            y=corr_data["score"],
            name="correlation",
            orientation="h",
            hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
        )
    ])
    corr_ranking_plot.update_layout(
        title="Correlation with Binding Affinity",
        yaxis_title="Score Type",
        xaxis_title="Spearman Correlation",
        template="simple_white",
        showlegend=False,
    )

    # Build a new list so the caller's main_cols is left untouched.
    cols_to_show = [*main_cols, *score_cols]

    return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot

def select_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str) -> go.Figure:
    """Build a scatter plot of one score against KD with a linear trend line.

    Args:
        spr_data_with_scores: Table containing a ``"KD (nM)"`` column and the
            column named by ``score``.
        score: Name of the score column plotted on the y axis.

    Returns:
        A figure with one marker trace (``"KD (nM)"`` on x, ``score`` on y)
        and, when at least two rows have finite values in both columns, a
        second trace showing the degree-1 least-squares fit through those
        rows.
    """
    kd_col = "KD (nM)"
    series_color = "#1f77b4"  # shared by the markers and the trend line

    scatter = go.Scatter(
        x=spr_data_with_scores[kd_col],
        y=spr_data_with_scores[score],
        name=f"KD (nM) vs {score}",
        mode="markers",  # dots only, no connecting lines
        hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
        marker=dict(color=series_color),
    )
    corr_plot = go.Figure(data=scatter)
    corr_plot.update_layout(
        xaxis_title="KD (nM)",
        yaxis_title=score,
        template="simple_white",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        # NOTE: the x axis is linear; add xaxis_type="log" here for a log scale.
    )

    # Fit only on rows where both values are finite: NaN or inf entries would
    # otherwise poison the least-squares fit and the min/max of the x range.
    kd_values = np.asarray(spr_data_with_scores[kd_col], dtype=float)
    score_values = np.asarray(spr_data_with_scores[score], dtype=float)
    usable = np.isfinite(kd_values) & np.isfinite(score_values)
    if np.count_nonzero(usable) < 2:
        # A straight line needs at least two points; return the scatter alone.
        return corr_plot

    kd_fit = kd_values[usable]
    slope, intercept = np.polyfit(kd_fit, score_values[usable], 1)
    line_x = np.linspace(kd_fit.min(), kd_fit.max(), 100)
    line_y = slope * line_x + intercept

    corr_plot.add_trace(go.Scatter(
        x=line_x,
        y=line_y,
        mode="lines",
        name="Correlation",
        line=dict(color=series_color),
    ))
    return corr_plot