jfaustin commited on
Commit
90a13ac
·
1 Parent(s): 03588ca
folding_studio_demo/app.py CHANGED
@@ -6,11 +6,9 @@ import gradio as gr
6
  from folding_studio_data_models import FoldingModel
7
  from gradio_molecule3d import Molecule3D
8
  import pandas as pd
9
- import numpy as np
10
- from scipy.stats import spearmanr
11
- import plotly.graph_objects as go
12
 
13
  from folding_studio_demo.predict import predict
 
14
 
15
  logger = logging.getLogger(__name__)
16
 
@@ -144,59 +142,6 @@ def model_comparison(api_key: str) -> None:
144
  outputs=[mol_output, metrics_plot],
145
  )
146
 
147
- def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: list[str]) -> tuple[pd.DataFrame, go.Figure]:
148
- """Fake predict structures of all complexes and correlate the results."""
149
- corr_data = []
150
- spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores["KD (nM)"])
151
- kd_col = "KD (nM)"
152
- for score_col in score_cols:
153
- logger.info(f"Computing correlation between {score_col} and KD (nM)")
154
- res = spearmanr(spr_data_with_scores[kd_col], spr_data_with_scores[score_col])
155
- corr_data.append({"score": score_col, "correlation": res.statistic, "p-value": res.pvalue})
156
- logger.info(f"Correlation between {score_col} and KD (nM): {res.statistic}")
157
-
158
- corr_data = pd.DataFrame(corr_data)
159
- # Find the lines in corr_data with NaN values and remove them
160
- corr_data = corr_data[corr_data["correlation"].notna()]
161
- # Sort correlation data by correlation value
162
- corr_data = corr_data.sort_values('correlation', ascending=True)
163
-
164
- # Create bar plot of correlations
165
- corr_ranking_plot = go.Figure(data=[
166
- go.Bar(
167
- x=corr_data["correlation"],
168
- y=corr_data["score"],
169
- name="correlation",
170
- orientation='h',
171
- hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>"
172
- )
173
- ])
174
- corr_ranking_plot.update_layout(
175
- title="Correlation with Binding Affinity",
176
- yaxis_title="Score Type",
177
- xaxis_title="Spearman Correlation",
178
- template="simple_white",
179
- showlegend=False
180
- )
181
-
182
- # corr_plot is a scatter plot of the correlation between the binding affinity and each of the scores
183
- scatters = []
184
- for score_col in score_cols:
185
- scatters.append(
186
- go.Scatter(
187
- x=spr_data_with_scores[kd_col],
188
- y=spr_data_with_scores[score_col],
189
- name=f"{kd_col} vs {score_col}",
190
- mode='markers', # Only show markers/dots, no lines
191
- hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>"
192
- )
193
- )
194
- corr_plot = go.Figure(data=scatters)
195
-
196
- cols_to_show = [kd_col]
197
- cols_to_show.extend(score_cols)
198
-
199
- return spr_data_with_scores[cols_to_show], corr_ranking_plot, corr_plot
200
 
201
  def create_correlation_tab():
202
  gr.Markdown("# Upload binding affinity data")
@@ -217,36 +162,13 @@ def create_correlation_tab():
217
  correlation_ranking_plot = gr.Plot(label="Correlation ranking")
218
  correlation_plot = gr.Plot(label="Correlation with binding affinity")
219
 
220
- cols = [
221
- "confidence_score_boltz",
222
- "ptm_boltz",
223
- "iptm_boltz",
224
- "complex_plddt_boltz",
225
- "complex_iplddt_boltz",
226
- "complex_pde_boltz",
227
- "complex_ipde_boltz",
228
- "interchain_pae_monomer",
229
- "interface_pae_monomer",
230
- "overall_pae_monomer",
231
- "interface_plddt_monomer",
232
- "average_plddt_monomer",
233
- "ptm_monomer",
234
- "interface_ptm_monomer",
235
- "interchain_pae_multimer",
236
- "interface_pae_multimer",
237
- "overall_pae_multimer",
238
- "interface_plddt_multimer",
239
- "average_plddt_multimer",
240
- "ptm_multimer",
241
- "interface_ptm_multimer"
242
- ]
243
  csv_file.change(
244
- fn=lambda file: spr_data_with_scores.drop(columns=cols) if file else None,
245
  inputs=csv_file,
246
  outputs=dataframe
247
  )
248
  fake_predict_btn.click(
249
- fn=lambda x: fake_predict_and_correlate(spr_data_with_scores, cols),
250
  inputs=None,
251
  outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot]
252
  )
 
6
  from folding_studio_data_models import FoldingModel
7
  from gradio_molecule3d import Molecule3D
8
  import pandas as pd
 
 
 
9
 
10
  from folding_studio_demo.predict import predict
11
+ from folding_studio_demo.correlate import fake_predict_and_correlate, COLUMNS
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
142
  outputs=[mol_output, metrics_plot],
143
  )
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  def create_correlation_tab():
147
  gr.Markdown("# Upload binding affinity data")
 
162
  correlation_ranking_plot = gr.Plot(label="Correlation ranking")
163
  correlation_plot = gr.Plot(label="Correlation with binding affinity")
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  csv_file.change(
166
+ fn=lambda file: spr_data_with_scores.drop(columns=COLUMNS) if file else None,
167
  inputs=csv_file,
168
  outputs=dataframe
169
  )
170
  fake_predict_btn.click(
171
+ fn=lambda x: fake_predict_and_correlate(spr_data_with_scores, COLUMNS),
172
  inputs=None,
173
  outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot]
174
  )
folding_studio_demo/correlate.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import pandas as pd
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ from scipy.stats import spearmanr
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ COLUMNS = [
10
+ "confidence_score_boltz",
11
+ "ptm_boltz",
12
+ "iptm_boltz",
13
+ "complex_plddt_boltz",
14
+ "complex_iplddt_boltz",
15
+ "complex_pde_boltz",
16
+ "complex_ipde_boltz",
17
+ "interchain_pae_monomer",
18
+ "interface_pae_monomer",
19
+ "overall_pae_monomer",
20
+ "interface_plddt_monomer",
21
+ "average_plddt_monomer",
22
+ "ptm_monomer",
23
+ "interface_ptm_monomer",
24
+ "interchain_pae_multimer",
25
+ "interface_pae_multimer",
26
+ "overall_pae_multimer",
27
+ "interface_plddt_multimer",
28
+ "average_plddt_multimer",
29
+ "ptm_multimer",
30
+ "interface_ptm_multimer"
31
+ ]
32
+
33
+ def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: list[str]) -> tuple[pd.DataFrame, go.Figure]:
34
+ """Fake predict structures of all complexes and correlate the results."""
35
+ corr_data = []
36
+ spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores["KD (nM)"])
37
+ kd_col = "KD (nM)"
38
+ for score_col in score_cols:
39
+ logger.info(f"Computing correlation between {score_col} and KD (nM)")
40
+ res = spearmanr(spr_data_with_scores[kd_col], spr_data_with_scores[score_col])
41
+ corr_data.append({"score": score_col, "correlation": res.statistic, "p-value": res.pvalue})
42
+ logger.info(f"Correlation between {score_col} and KD (nM): {res.statistic}")
43
+
44
+ corr_data = pd.DataFrame(corr_data)
45
+ # Find the lines in corr_data with NaN values and remove them
46
+ corr_data = corr_data[corr_data["correlation"].notna()]
47
+ # Sort correlation data by correlation value
48
+ corr_data = corr_data.sort_values('correlation', ascending=True)
49
+
50
+ # Create bar plot of correlations
51
+ corr_ranking_plot = go.Figure(data=[
52
+ go.Bar(
53
+ x=corr_data["correlation"],
54
+ y=corr_data["score"],
55
+ name="correlation",
56
+ orientation='h',
57
+ hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>"
58
+ )
59
+ ])
60
+ corr_ranking_plot.update_layout(
61
+ title="Correlation with Binding Affinity",
62
+ yaxis_title="Score Type",
63
+ xaxis_title="Spearman Correlation",
64
+ template="simple_white",
65
+ showlegend=False
66
+ )
67
+
68
+ # corr_plot is a scatter plot of the correlation between the binding affinity and each of the scores
69
+ scatters = []
70
+ for score_col in score_cols:
71
+ scatters.append(
72
+ go.Scatter(
73
+ x=spr_data_with_scores[kd_col],
74
+ y=spr_data_with_scores[score_col],
75
+ name=f"{kd_col} vs {score_col}",
76
+ mode='markers', # Only show markers/dots, no lines
77
+ hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>"
78
+ )
79
+ )
80
+ corr_plot = go.Figure(data=scatters)
81
+ corr_plot.update_layout(
82
+ xaxis_title="KD (nM)",
83
+ yaxis_title="Score",
84
+ template="simple_white",
85
+ xaxis_type="log" # Set x-axis to logarithmic scale
86
+ )
87
+ cols_to_show = [kd_col]
88
+ cols_to_show.extend(score_cols)
89
+
90
+ return spr_data_with_scores[cols_to_show], corr_ranking_plot, corr_plot