refactor
Browse files- folding_studio_demo/app.py +3 -81
- folding_studio_demo/correlate.py +90 -0
folding_studio_demo/app.py
CHANGED
@@ -6,11 +6,9 @@ import gradio as gr
|
|
6 |
from folding_studio_data_models import FoldingModel
|
7 |
from gradio_molecule3d import Molecule3D
|
8 |
import pandas as pd
|
9 |
-
import numpy as np
|
10 |
-
from scipy.stats import spearmanr
|
11 |
-
import plotly.graph_objects as go
|
12 |
|
13 |
from folding_studio_demo.predict import predict
|
|
|
14 |
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
@@ -144,59 +142,6 @@ def model_comparison(api_key: str) -> None:
|
|
144 |
outputs=[mol_output, metrics_plot],
|
145 |
)
|
146 |
|
147 |
-
def fake_predict_and_correlate(
    spr_data_with_scores: pd.DataFrame, score_cols: list[str]
) -> tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Correlate precomputed score columns with the measured KD values.

    No structure prediction is run here: the scores are read from columns that
    are already present in ``spr_data_with_scores``.

    Args:
        spr_data_with_scores: Table containing a ``"KD (nM)"`` column and one
            column per entry of ``score_cols``. The table is not modified.
        score_cols: Names of the score columns to correlate against KD.

    Returns:
        A 3-tuple ``(table, ranking_plot, scatter_plot)``:
        the input restricted to the KD column followed by ``score_cols``;
        a horizontal bar chart of the Spearman correlation of each score,
        sorted ascending, with NaN correlations left out; and a scatter figure
        with one marker trace per score plotted against KD.

    Raises:
        KeyError: If ``"KD (nM)"`` or one of ``score_cols`` is not a column.
    """
    kd_col = "KD (nM)"
    kd_values = spr_data_with_scores[kd_col]

    corr_rows = []
    for score_col in score_cols:
        logger.info("Computing correlation between %s and KD (nM)", score_col)
        # Tuple unpacking works with both old and new SciPy result objects.
        statistic, pvalue = spearmanr(kd_values, spr_data_with_scores[score_col])
        corr_rows.append(
            {"score": score_col, "correlation": statistic, "p-value": pvalue}
        )
        logger.info(
            "Correlation between %s and KD (nM): %s", score_col, statistic
        )

    # Explicit columns keep the frame well-formed when score_cols is empty.
    corr_data = pd.DataFrame(corr_rows, columns=["score", "correlation", "p-value"])
    # Scores whose correlation is NaN are left out of the ranking.
    corr_data = corr_data.dropna(subset=["correlation"])
    corr_data = corr_data.sort_values("correlation", ascending=True)

    # Horizontal bar chart ranking the scores by correlation.
    corr_ranking_plot = go.Figure(
        data=[
            go.Bar(
                x=corr_data["correlation"],
                y=corr_data["score"],
                name="correlation",
                orientation="h",
                hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
            )
        ]
    )
    corr_ranking_plot.update_layout(
        title="Correlation with Binding Affinity",
        yaxis_title="Score Type",
        xaxis_title="Spearman Correlation",
        template="simple_white",
        showlegend=False,
    )

    # One marker-only trace per score, plotted against KD.
    scatters = [
        go.Scatter(
            x=kd_values,
            y=spr_data_with_scores[score_col],
            name=f"{kd_col} vs {score_col}",
            mode="markers",
            hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
        )
        for score_col in score_cols
    ]
    corr_plot = go.Figure(data=scatters)

    cols_to_show = [kd_col, *score_cols]
    return spr_data_with_scores[cols_to_show], corr_ranking_plot, corr_plot
|
200 |
|
201 |
def create_correlation_tab():
|
202 |
gr.Markdown("# Upload binding affinity data")
|
@@ -217,36 +162,13 @@ def create_correlation_tab():
|
|
217 |
correlation_ranking_plot = gr.Plot(label="Correlation ranking")
|
218 |
correlation_plot = gr.Plot(label="Correlation with binding affinity")
|
219 |
|
220 |
-
cols = [
|
221 |
-
"confidence_score_boltz",
|
222 |
-
"ptm_boltz",
|
223 |
-
"iptm_boltz",
|
224 |
-
"complex_plddt_boltz",
|
225 |
-
"complex_iplddt_boltz",
|
226 |
-
"complex_pde_boltz",
|
227 |
-
"complex_ipde_boltz",
|
228 |
-
"interchain_pae_monomer",
|
229 |
-
"interface_pae_monomer",
|
230 |
-
"overall_pae_monomer",
|
231 |
-
"interface_plddt_monomer",
|
232 |
-
"average_plddt_monomer",
|
233 |
-
"ptm_monomer",
|
234 |
-
"interface_ptm_monomer",
|
235 |
-
"interchain_pae_multimer",
|
236 |
-
"interface_pae_multimer",
|
237 |
-
"overall_pae_multimer",
|
238 |
-
"interface_plddt_multimer",
|
239 |
-
"average_plddt_multimer",
|
240 |
-
"ptm_multimer",
|
241 |
-
"interface_ptm_multimer"
|
242 |
-
]
|
243 |
csv_file.change(
|
244 |
-
fn=lambda file: spr_data_with_scores.drop(columns=
|
245 |
inputs=csv_file,
|
246 |
outputs=dataframe
|
247 |
)
|
248 |
fake_predict_btn.click(
|
249 |
-
fn=lambda x: fake_predict_and_correlate(spr_data_with_scores,
|
250 |
inputs=None,
|
251 |
outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot]
|
252 |
)
|
|
|
6 |
from folding_studio_data_models import FoldingModel
|
7 |
from gradio_molecule3d import Molecule3D
|
8 |
import pandas as pd
|
|
|
|
|
|
|
9 |
|
10 |
from folding_studio_demo.predict import predict
|
11 |
+
from folding_studio_demo.correlate import fake_predict_and_correlate, COLUMNS
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
|
|
142 |
outputs=[mol_output, metrics_plot],
|
143 |
)
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
def create_correlation_tab():
|
147 |
gr.Markdown("# Upload binding affinity data")
|
|
|
162 |
correlation_ranking_plot = gr.Plot(label="Correlation ranking")
|
163 |
correlation_plot = gr.Plot(label="Correlation with binding affinity")
|
164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
csv_file.change(
|
166 |
+
fn=lambda file: spr_data_with_scores.drop(columns=COLUMNS) if file else None,
|
167 |
inputs=csv_file,
|
168 |
outputs=dataframe
|
169 |
)
|
170 |
fake_predict_btn.click(
|
171 |
+
fn=lambda x: fake_predict_and_correlate(spr_data_with_scores, COLUMNS),
|
172 |
inputs=None,
|
173 |
outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot]
|
174 |
)
|
folding_studio_demo/correlate.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import plotly.graph_objects as go
|
5 |
+
from scipy.stats import spearmanr
|
6 |
+
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
|
9 |
+
# Score columns grouped by the suffix of their names.
_BOLTZ_COLUMNS = [
    "confidence_score_boltz",
    "ptm_boltz",
    "iptm_boltz",
    "complex_plddt_boltz",
    "complex_iplddt_boltz",
    "complex_pde_boltz",
    "complex_ipde_boltz",
]
_MONOMER_COLUMNS = [
    "interchain_pae_monomer",
    "interface_pae_monomer",
    "overall_pae_monomer",
    "interface_plddt_monomer",
    "average_plddt_monomer",
    "ptm_monomer",
    "interface_ptm_monomer",
]
_MULTIMER_COLUMNS = [
    "interchain_pae_multimer",
    "interface_pae_multimer",
    "overall_pae_multimer",
    "interface_plddt_multimer",
    "average_plddt_multimer",
    "ptm_multimer",
    "interface_ptm_multimer",
]

# Every score column name, in a fixed order: boltz, then monomer, then multimer.
COLUMNS = [*_BOLTZ_COLUMNS, *_MONOMER_COLUMNS, *_MULTIMER_COLUMNS]
|
32 |
+
|
33 |
+
def fake_predict_and_correlate(
    spr_data_with_scores: pd.DataFrame, score_cols: list[str]
) -> tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Correlate precomputed score columns with the measured KD values.

    No structure prediction is run here: the scores are read from columns that
    are already present in ``spr_data_with_scores``.

    Args:
        spr_data_with_scores: Table containing a ``"KD (nM)"`` column and one
            column per entry of ``score_cols``. The table is not modified.
        score_cols: Names of the score columns to correlate against KD.

    Returns:
        A 3-tuple ``(table, ranking_plot, scatter_plot)``:
        the input restricted to the KD column followed by ``score_cols``;
        a horizontal bar chart of the Spearman correlation of each score,
        sorted ascending, with NaN correlations left out; and a scatter figure
        with one marker trace per score plotted against KD on a log x-axis.

    Raises:
        KeyError: If ``"KD (nM)"`` or one of ``score_cols`` is not a column.
    """
    kd_col = "KD (nM)"
    kd_values = spr_data_with_scores[kd_col]

    corr_rows = []
    for score_col in score_cols:
        logger.info("Computing correlation between %s and KD (nM)", score_col)
        # Tuple unpacking works with both old and new SciPy result objects.
        statistic, pvalue = spearmanr(kd_values, spr_data_with_scores[score_col])
        corr_rows.append(
            {"score": score_col, "correlation": statistic, "p-value": pvalue}
        )
        logger.info(
            "Correlation between %s and KD (nM): %s", score_col, statistic
        )

    # Explicit columns keep the frame well-formed when score_cols is empty.
    corr_data = pd.DataFrame(corr_rows, columns=["score", "correlation", "p-value"])
    # Scores whose correlation is NaN are left out of the ranking.
    corr_data = corr_data.dropna(subset=["correlation"])
    corr_data = corr_data.sort_values("correlation", ascending=True)

    # Horizontal bar chart ranking the scores by correlation.
    corr_ranking_plot = go.Figure(
        data=[
            go.Bar(
                x=corr_data["correlation"],
                y=corr_data["score"],
                name="correlation",
                orientation="h",
                hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
            )
        ]
    )
    corr_ranking_plot.update_layout(
        title="Correlation with Binding Affinity",
        yaxis_title="Score Type",
        xaxis_title="Spearman Correlation",
        template="simple_white",
        showlegend=False,
    )

    # One marker-only trace per score, plotted against KD.
    scatters = [
        go.Scatter(
            x=kd_values,
            y=spr_data_with_scores[score_col],
            name=f"{kd_col} vs {score_col}",
            mode="markers",
            hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
        )
        for score_col in score_cols
    ]
    corr_plot = go.Figure(data=scatters)
    corr_plot.update_layout(
        xaxis_title="KD (nM)",
        yaxis_title="Score",
        template="simple_white",
        xaxis_type="log",  # KD is shown on a logarithmic axis
    )

    cols_to_show = [kd_col, *score_cols]
    return spr_data_with_scores[cols_to_show], corr_ranking_plot, corr_plot
|