Add more correlation metrics to correlation tab

#8
by jfaustin - opened
folding_studio_demo/app.py CHANGED
@@ -9,8 +9,12 @@ from gradio_molecule3d import Molecule3D
9
 
10
  from folding_studio_demo.correlate import (
11
  SCORE_COLUMNS,
 
12
  fake_predict_and_correlate,
13
- make_correlation_plot,
 
 
 
14
  )
15
  from folding_studio_demo.predict import predict, predict_comparison
16
 
@@ -107,7 +111,6 @@ def simple_prediction(api_key: str) -> None:
107
  elem_classes="gradient-button",
108
  elem_id="predict-btn",
109
  variant="primary",
110
- # css=f".gradio-container #predict-btn {{background: linear-gradient(90deg, {BLUE}, {PURPLE});}}",
111
  )
112
 
113
  with gr.Row():
@@ -145,7 +148,6 @@ def model_comparison(api_key: str) -> None:
145
  elem_classes=["gradient-button"],
146
  elem_id="compare-models-btn",
147
  variant="primary",
148
- # css=f".gradio-container #compare-models-btn {{background: linear-gradient(90deg, {BLUE}, {PURPLE});}}"
149
  )
150
 
151
  with gr.Row():
@@ -181,6 +183,7 @@ def create_correlation_tab():
181
  of binding strength.
182
  """)
183
  spr_data_with_scores = pd.read_csv("spr_af_scores_mapped.csv")
 
184
  prettified_columns = {
185
  "antibody_name": "Antibody Name",
186
  "KD (nM)": "KD (nM)",
@@ -209,12 +212,19 @@ def create_correlation_tab():
209
  "Predict structures of all complexes",
210
  elem_classes="gradient-button",
211
  variant="primary",
212
- # css=f".gradio-container #fake-predict-btn {{background: linear-gradient(90deg, {BLUE}, {PURPLE});}}",
213
  )
214
  with gr.Row():
215
  prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
216
  with gr.Row():
217
- correlation_ranking_plot = gr.Plot(label="Correlation ranking")
 
 
 
 
 
 
 
 
218
  with gr.Row():
219
  with gr.Column():
220
  with gr.Row():
@@ -225,6 +235,13 @@ def create_correlation_tab():
225
  # Add checkbox for log scale and update plot when either input changes
226
  with gr.Row():
227
  log_scale = gr.Checkbox(label="Display x-axis on logarithmic scale", value=False)
 
 
 
 
 
 
 
228
  with gr.Column():
229
  correlation_plot = gr.Plot(label="Correlation with binding affinity")
230
 
@@ -232,21 +249,33 @@ def create_correlation_tab():
232
  fn=lambda x: fake_predict_and_correlate(
233
  spr_data_with_scores, SCORE_COLUMNS, ["Antibody Name", "KD (nM)"]
234
  ),
235
- inputs=None,
236
  outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot],
237
  )
238
 
239
- def update_plot(score, use_log):
240
- return make_correlation_plot(spr_data_with_scores, score, use_log)
241
 
 
 
 
 
 
 
242
  correlation_column.change(
243
- fn=update_plot,
244
  inputs=[correlation_column, log_scale],
245
  outputs=correlation_plot,
246
  )
 
 
 
 
 
 
247
 
248
  log_scale.change(
249
- fn=update_plot,
250
  inputs=[correlation_column, log_scale],
251
  outputs=correlation_plot,
252
  )
 
9
 
10
  from folding_studio_demo.correlate import (
11
  SCORE_COLUMNS,
12
+ SCORE_COLUMN_NAMES,
13
  fake_predict_and_correlate,
14
+ make_regression_plot,
15
+ compute_correlation_data,
16
+ plot_correlation_ranking,
17
+ get_score_description
18
  )
19
  from folding_studio_demo.predict import predict, predict_comparison
20
 
 
111
  elem_classes="gradient-button",
112
  elem_id="predict-btn",
113
  variant="primary",
 
114
  )
115
 
116
  with gr.Row():
 
148
  elem_classes=["gradient-button"],
149
  elem_id="compare-models-btn",
150
  variant="primary",
 
151
  )
152
 
153
  with gr.Row():
 
183
  of binding strength.
184
  """)
185
  spr_data_with_scores = pd.read_csv("spr_af_scores_mapped.csv")
186
+ spr_data_with_scores = spr_data_with_scores.rename(columns=SCORE_COLUMN_NAMES)
187
  prettified_columns = {
188
  "antibody_name": "Antibody Name",
189
  "KD (nM)": "KD (nM)",
 
212
  "Predict structures of all complexes",
213
  elem_classes="gradient-button",
214
  variant="primary",
 
215
  )
216
  with gr.Row():
217
  prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
218
  with gr.Row():
219
+ with gr.Row():
220
+ correlation_type = gr.Radio(
221
+ choices=["Spearman", "Pearson", "R²"],
222
+ value="Spearman",
223
+ label="Correlation Type",
224
+ interactive=True
225
+ )
226
+ with gr.Row():
227
+ correlation_ranking_plot = gr.Plot(label="Correlation ranking")
228
  with gr.Row():
229
  with gr.Column():
230
  with gr.Row():
 
235
  # Add checkbox for log scale and update plot when either input changes
236
  with gr.Row():
237
  log_scale = gr.Checkbox(label="Display x-axis on logarithmic scale", value=False)
238
+ with gr.Row():
239
+ score_description = gr.Markdown(get_score_description(correlation_column.value))
240
+ correlation_column.change(
241
+ fn=lambda x: get_score_description(x),
242
+ inputs=correlation_column,
243
+ outputs=score_description
244
+ )
245
  with gr.Column():
246
  correlation_plot = gr.Plot(label="Correlation with binding affinity")
247
 
 
249
  fn=lambda x: fake_predict_and_correlate(
250
  spr_data_with_scores, SCORE_COLUMNS, ["Antibody Name", "KD (nM)"]
251
  ),
252
+ inputs=[correlation_type],
253
  outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot],
254
  )
255
 
256
+ def update_regression_plot(score, use_log):
257
+ return make_regression_plot(spr_data_with_scores, score, use_log)
258
 
259
+ def update_correlation_plot(correlation_type):
260
+ logger.info(f"Updating correlation plot for {correlation_type}")
261
+ corr_data = compute_correlation_data(spr_data_with_scores, SCORE_COLUMNS)
262
+ logger.info(f"Correlation data: {corr_data}")
263
+ return plot_correlation_ranking(corr_data, correlation_type)
264
+
265
  correlation_column.change(
266
+ fn=update_regression_plot,
267
  inputs=[correlation_column, log_scale],
268
  outputs=correlation_plot,
269
  )
270
+
271
+ correlation_type.change(
272
+ fn=update_correlation_plot,
273
+ inputs=[correlation_type],
274
+ outputs=correlation_ranking_plot,
275
+ )
276
 
277
  log_scale.change(
278
+ fn=update_regression_plot,
279
  inputs=[correlation_column, log_scale],
280
  outputs=correlation_plot,
281
  )
folding_studio_demo/correlate.py CHANGED
@@ -1,45 +1,90 @@
1
  import logging
2
  import pandas as pd
 
3
  import numpy as np
4
  import plotly.graph_objects as go
5
- from scipy.stats import spearmanr
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
- SCORE_COLUMNS = [
10
- "confidence_score_boltz",
11
- "ptm_boltz",
12
- "iptm_boltz",
13
- "complex_plddt_boltz",
14
- "complex_iplddt_boltz",
15
- "complex_pde_boltz",
16
- "complex_ipde_boltz",
17
- "interchain_pae_monomer",
18
- "interface_pae_monomer",
19
- "overall_pae_monomer",
20
- "interface_plddt_monomer",
21
- "average_plddt_monomer",
22
- "ptm_monomer",
23
- "interface_ptm_monomer",
24
- "interchain_pae_multimer",
25
- "interface_pae_multimer",
26
- "overall_pae_multimer",
27
- "interface_plddt_multimer",
28
- "average_plddt_multimer",
29
- "ptm_multimer",
30
- "interface_ptm_multimer"
31
- ]
32
 
33
- def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]) -> tuple[pd.DataFrame, go.Figure]:
34
- """Fake predict structures of all complexes and correlate the results."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  corr_data = []
36
  spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores["KD (nM)"])
37
  kd_col = "KD (nM)"
38
- for score_col in score_cols:
39
- logger.info(f"Computing correlation between {score_col} and KD (nM)")
40
- res = spearmanr(spr_data_with_scores[kd_col], spr_data_with_scores[score_col])
41
- corr_data.append({"score": score_col, "correlation": res.statistic, "p-value": res.pvalue})
42
- logger.info(f"Correlation between {score_col} and KD (nM): {res.statistic}")
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  corr_data = pd.DataFrame(corr_data)
45
  # Find the lines in corr_data with NaN values and remove them
@@ -47,34 +92,48 @@ def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: l
47
  # Sort correlation data by correlation value
48
  corr_data = corr_data.sort_values('correlation', ascending=True)
49
 
 
 
 
 
 
50
  # Create bar plot of correlations
 
51
  corr_ranking_plot = go.Figure(data=[
52
  go.Bar(
53
- x=corr_data["correlation"],
54
- y=corr_data["score"],
55
- name="correlation",
 
56
  orientation='h',
57
  hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>"
58
  )
59
  ])
60
  corr_ranking_plot.update_layout(
61
  title="Correlation with Binding Affinity",
62
- yaxis_title="Score Type",
63
- xaxis_title="Spearman Correlation",
64
  template="simple_white",
65
  showlegend=False
66
  )
 
 
 
 
 
 
 
67
 
68
  cols_to_show = main_cols[:]
69
  cols_to_show.extend(score_cols)
70
 
71
- corr_plot = make_correlation_plot(spr_data_with_scores, score_cols[0], use_log=False)
72
 
73
  return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
74
 
75
- def make_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str, use_log: bool) -> go.Figure:
76
- """Select the correlation plot to display."""
77
- # corr_plot is a scatter plot of the correlation between the binding affinity and each of the scores
78
  scatter = go.Scatter(
79
  x=spr_data_with_scores["KD (nM)"],
80
  y=spr_data_with_scores[score],
@@ -97,11 +156,11 @@ def make_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str, use_lo
97
  ),
98
  xaxis_type="log" if use_log else "linear" # Set x-axis to logarithmic scale
99
  )
100
- # compute the correlation line
101
  corr_line = np.polyfit(spr_data_with_scores["KD (nM)"], spr_data_with_scores[score], 1)
102
  corr_line_x = np.linspace(min(spr_data_with_scores["KD (nM)"]), max(spr_data_with_scores["KD (nM)"]), 100)
103
  corr_line_y = corr_line[0] * corr_line_x + corr_line[1]
104
- # add the correlation line to the plot
105
  corr_plot.add_trace(go.Scatter(
106
  x=corr_line_x,
107
  y=corr_line_y,
 
1
  import logging
2
  import pandas as pd
3
+ from pathlib import Path
4
  import numpy as np
5
  import plotly.graph_objects as go
6
+ from scipy.stats import spearmanr, pearsonr, linregress
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
+ SCORE_COLUMN_NAMES = {
11
+ "confidence_score_boltz": "Boltz Confidence Score",
12
+ "ptm_boltz": "Boltz pTM Score",
13
+ "iptm_boltz": "Boltz ipTM Score",
14
+ "complex_plddt_boltz": "Boltz Complex pLDDT",
15
+ "complex_iplddt_boltz": "Boltz Complex ipLDDT",
16
+ "complex_pde_boltz": "Boltz Complex pDE",
17
+ "complex_ipde_boltz": "Boltz Complex ipDE",
18
+ "interchain_pae_monomer": "AlphaFold2 GapTrick Interchain PAE",
19
+ "interface_pae_monomer": "AlphaFold2 GapTrick Interface PAE",
20
+ "overall_pae_monomer": "AlphaFold2 GapTrick Overall PAE",
21
+ "interface_plddt_monomer": "AlphaFold2 GapTrick Interface pLDDT",
22
+ "average_plddt_monomer": "AlphaFold2 GapTrick Average pLDDT",
23
+ "ptm_monomer": "AlphaFold2 GapTrick pTM Score",
24
+ "interface_ptm_monomer": "AlphaFold2 GapTrick Interface pTM",
25
+ "interchain_pae_multimer": "AlphaFold2 Multimer Interchain PAE",
26
+ "interface_pae_multimer": "AlphaFold2 Multimer Interface PAE",
27
+ "overall_pae_multimer": "AlphaFold2 Multimer Overall PAE",
28
+ "interface_plddt_multimer": "AlphaFold2 Multimer Interface pLDDT",
29
+ "average_plddt_multimer": "AlphaFold2 Multimer Average pLDDT",
30
+ "ptm_multimer": "AlphaFold2 Multimer pTM Score",
31
+ "interface_ptm_multimer": "AlphaFold2 Multimer Interface pTM"
32
+ }
33
 
34
+ SCORE_COLUMNS = list(SCORE_COLUMN_NAMES.values())
35
+
36
+ def get_score_description(score: str) -> str:
37
+ descriptions = {
38
+ "Boltz Confidence Score": "The Boltz model confidence score provides an overall assessment of prediction quality (0-1, higher is better).",
39
+ "Boltz pTM Score": "The Boltz model predicted TM-score (pTM) assesses the overall fold accuracy of the predicted structure (0-1, higher is better).",
40
+ "Boltz ipTM Score": "The Boltz model interface pTM score (ipTM) specifically evaluates the accuracy of interface regions (0-1, higher is better).",
41
+ "Boltz Complex pLDDT": "The Boltz model Complex pLDDT measures confidence in local structure predictions across the entire complex (0-100, higher is better).",
42
+ "Boltz Complex ipLDDT": "The Boltz model Complex interface pLDDT (ipLDDT) focuses on confidence in interface region predictions (0-100, higher is better).",
43
+ "Boltz Complex pDE": "The Boltz model Complex predicted distance error (pDE) estimates the confidence in predicted distances between residues (0-1, higher is better).",
44
+ "Boltz Complex ipDE": "The Boltz model Complex interface pDE (ipDE) estimates confidence in predicted distances specifically at interfaces (0-1, higher is better).",
45
+ "AlphaFold2 GapTrick Interchain PAE": "The AlphaFold2 GapTrick model interchain predicted aligned error (PAE) estimates position errors between chains in monomeric predictions (lower is better).",
46
+ "AlphaFold2 GapTrick Interface PAE": "The AlphaFold2 GapTrick model interface PAE estimates position errors specifically at interfaces in monomeric predictions (lower is better).",
47
+ "AlphaFold2 GapTrick Overall PAE": "The AlphaFold2 GapTrick model overall PAE estimates position errors across the entire structure in monomeric predictions (lower is better).",
48
+ "AlphaFold2 GapTrick Interface pLDDT": "The AlphaFold2 GapTrick model interface pLDDT measures confidence in interface region predictions for monomeric models (0-100, higher is better).",
49
+ "AlphaFold2 GapTrick Average pLDDT": "The AlphaFold2 GapTrick model average pLDDT provides the mean confidence across all residues in monomeric predictions (0-100, higher is better).",
50
+ "AlphaFold2 GapTrick pTM Score": "The AlphaFold2 GapTrick model pTM score assesses overall fold accuracy in monomeric predictions (0-1, higher is better).",
51
+ "AlphaFold2 GapTrick Interface pTM": "The AlphaFold2 GapTrick model interface pTM specifically evaluates accuracy of interface regions in monomeric predictions (0-1, higher is better).",
52
+ "AlphaFold2 GapTrick Interchain PAE": "The AlphaFold2 GapTrick model interchain PAE estimates position errors between chains in multimeric predictions (lower is better).",
53
+ "AlphaFold2 Multimer Interface PAE": "The AlphaFold2 Multimer model interface PAE estimates position errors specifically at interfaces in multimeric predictions (lower is better).",
54
+ "AlphaFold2 Multimer Overall PAE": "The AlphaFold2 Multimer model overall PAE estimates position errors across the entire structure in multimeric predictions (lower is better).",
55
+ "AlphaFold2 Multimer Interface pLDDT": "The AlphaFold2 Multimer model interface pLDDT measures confidence in interface region predictions for multimeric models (0-100, higher is better).",
56
+ "AlphaFold2 Multimer Average pLDDT": "The AlphaFold2 Multimer model average pLDDT provides the mean confidence across all residues in multimeric predictions (0-100, higher is better).",
57
+ "AlphaFold2 Multimer pTM Score": "The AlphaFold2 Multimer model pTM score assesses overall fold accuracy in multimeric predictions (0-1, higher is better).",
58
+ "AlphaFold2 Multimer Interface pTM": "The AlphaFold2 Multimer model interface pTM specifically evaluates accuracy of interface regions in multimeric predictions (0-1, higher is better)."
59
+ }
60
+ return descriptions.get(score, "No description available for this score.")
61
+
62
+ def compute_correlation_data(spr_data_with_scores: pd.DataFrame, score_cols: list[str]) -> pd.DataFrame:
63
+ corr_data_file = Path("corr_data.csv")
64
+ if corr_data_file.exists():
65
+ logger.info(f"Loading correlation data from {corr_data_file}")
66
+ return pd.read_csv(corr_data_file)
67
+
68
  corr_data = []
69
  spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores["KD (nM)"])
70
  kd_col = "KD (nM)"
71
+ corr_funcs = {}
72
+ corr_funcs["Spearman"] = spearmanr
73
+ corr_funcs["Pearson"] = pearsonr
74
+ corr_funcs[""] = linregress
75
+ for correlation_type, corr_func in corr_funcs.items():
76
+ for score_col in score_cols:
77
+ logger.info(f"Computing {correlation_type} correlation between {score_col} and KD (nM)")
78
+ res = corr_func(spr_data_with_scores[kd_col], spr_data_with_scores[score_col])
79
+ logger.info(f"Correlation function: {corr_func}")
80
+ correlation_value = res.rvalue**2 if correlation_type == "R²" else res.statistic
81
+ corr_data.append({
82
+ "correlation_type": correlation_type,
83
+ "score": score_col,
84
+ "correlation": correlation_value,
85
+ "p-value": res.pvalue
86
+ })
87
+ logger.info(f"Correlation {correlation_type} between {score_col} and KD (nM): {correlation_value}")
88
 
89
  corr_data = pd.DataFrame(corr_data)
90
  # Find the lines in corr_data with NaN values and remove them
 
92
  # Sort correlation data by correlation value
93
  corr_data = corr_data.sort_values('correlation', ascending=True)
94
 
95
+ corr_data.to_csv("corr_data.csv", index=False)
96
+
97
+ return corr_data
98
+
99
+ def plot_correlation_ranking(corr_data: pd.DataFrame, correlation_type: str) -> go.Figure:
100
  # Create bar plot of correlations
101
+ data = corr_data[corr_data["correlation_type"] == correlation_type]
102
  corr_ranking_plot = go.Figure(data=[
103
  go.Bar(
104
+ x=data["correlation"],
105
+ y=data["score"],
106
+ name=correlation_type,
107
+ text=data["correlation"],
108
  orientation='h',
109
  hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>"
110
  )
111
  ])
112
  corr_ranking_plot.update_layout(
113
  title="Correlation with Binding Affinity",
114
+ yaxis_title="Score",
115
+ xaxis_title=correlation_type,
116
  template="simple_white",
117
  showlegend=False
118
  )
119
+ return corr_ranking_plot
120
+
121
+ def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]) -> tuple[pd.DataFrame, go.Figure]:
122
+ """Fake predict structures of all complexes and correlate the results."""
123
+
124
+ corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
125
+ corr_ranking_plot = plot_correlation_ranking(corr_data, "Spearman")
126
 
127
  cols_to_show = main_cols[:]
128
  cols_to_show.extend(score_cols)
129
 
130
+ corr_plot = make_regression_plot(spr_data_with_scores, score_cols[0], use_log=False)
131
 
132
  return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
133
 
134
+ def make_regression_plot(spr_data_with_scores: pd.DataFrame, score: str, use_log: bool) -> go.Figure:
135
+ """Select the regression plot to display."""
136
+ # corr_plot is a scatter plot of the regression between the binding affinity and each of the scores
137
  scatter = go.Scatter(
138
  x=spr_data_with_scores["KD (nM)"],
139
  y=spr_data_with_scores[score],
 
156
  ),
157
  xaxis_type="log" if use_log else "linear" # Set x-axis to logarithmic scale
158
  )
159
+ # compute the regression line
160
  corr_line = np.polyfit(spr_data_with_scores["KD (nM)"], spr_data_with_scores[score], 1)
161
  corr_line_x = np.linspace(min(spr_data_with_scores["KD (nM)"]), max(spr_data_with_scores["KD (nM)"]), 100)
162
  corr_line_y = corr_line[0] * corr_line_x + corr_line[1]
163
+ # add the regression line to the plot
164
  corr_plot.add_trace(go.Scatter(
165
  x=corr_line_x,
166
  y=corr_line_y,