Merge remote-tracking branch 'origin/main' into pr/12
Browse files- folding_studio_demo/app.py +128 -77
- folding_studio_demo/correlate.py +85 -58
- folding_studio_demo/model_fasta_validators.py +9 -9
- folding_studio_demo/models.py +146 -7
- folding_studio_demo/predict.py +210 -94
folding_studio_demo/app.py
CHANGED
@@ -39,8 +39,8 @@ MOLECULE_REPS = [
|
|
39 |
|
40 |
|
41 |
MODEL_CHOICES = [
|
42 |
-
|
43 |
-
|
44 |
# ("SoloSeq", FoldingModel.SOLOSEQ),
|
45 |
("Boltz-1", FoldingModel.BOLTZ),
|
46 |
("Chai-1", FoldingModel.CHAI),
|
@@ -49,6 +49,15 @@ MODEL_CHOICES = [
|
|
49 |
|
50 |
DEFAULT_SEQ = "MALWMRLLPLLALLALWGPDPAAA"
|
51 |
MODEL_EXAMPLES = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
FoldingModel.BOLTZ: [
|
53 |
["Monomer", f">A|protein\n{DEFAULT_SEQ}"],
|
54 |
["Multimer", f">A|protein\n{DEFAULT_SEQ}\n>B|protein\n{DEFAULT_SEQ}"],
|
@@ -70,27 +79,31 @@ def sequence_input(dropdown: gr.Dropdown | None = None) -> gr.Textbox:
|
|
70 |
Returns:
|
71 |
gr.Textbox: Sequence input component
|
72 |
"""
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
-
examples = gr.Examples(
|
81 |
-
examples=MODEL_EXAMPLES[FoldingModel.BOLTZ],
|
82 |
-
inputs=[dummy, sequence],
|
83 |
-
)
|
84 |
if dropdown is not None:
|
85 |
dropdown.change(
|
86 |
fn=lambda x: gr.Dataset(samples=MODEL_EXAMPLES[x]),
|
87 |
inputs=[dropdown],
|
88 |
outputs=[examples.dataset],
|
89 |
)
|
90 |
-
file_input = gr.File(
|
91 |
-
label="Upload a FASTA file",
|
92 |
-
file_types=[".fasta", ".fa"],
|
93 |
-
)
|
94 |
|
95 |
def _process_file(file: gr.File | None) -> gr.Textbox:
|
96 |
if file is None:
|
@@ -115,7 +128,7 @@ def simple_prediction(api_key: str) -> None:
|
|
115 |
"""
|
116 |
gr.Markdown(
|
117 |
"""
|
118 |
-
|
119 |
|
120 |
It will be run in the background and the results will be displayed in the output section.
|
121 |
The output will contain the protein structure and the pLDDT plot.
|
@@ -157,7 +170,19 @@ def model_comparison(api_key: str) -> None:
|
|
157 |
Args:
|
158 |
api_key (str): Folding Studio API key
|
159 |
"""
|
|
|
|
|
|
|
160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
with gr.Row():
|
162 |
models = gr.CheckboxGroup(
|
163 |
label="Model",
|
@@ -176,6 +201,9 @@ def model_comparison(api_key: str) -> None:
|
|
176 |
variant="primary",
|
177 |
)
|
178 |
with gr.Row():
|
|
|
|
|
|
|
179 |
chai_predictions = gr.CheckboxGroup(label="Chai", visible=False)
|
180 |
protenix_predictions = gr.CheckboxGroup(label="Protenix", visible=False)
|
181 |
boltz_predictions = gr.CheckboxGroup(label="Boltz", visible=False)
|
@@ -186,28 +214,50 @@ def model_comparison(api_key: str) -> None:
|
|
186 |
metrics_plot = gr.Plot(label="pLDDT")
|
187 |
|
188 |
# Store the initial predictions
|
189 |
-
|
190 |
-
plddt_fig = gr.State()
|
191 |
|
192 |
predict_btn.click(
|
193 |
fn=predict_comparison,
|
194 |
inputs=[sequence, api_key, models],
|
195 |
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
chai_predictions,
|
197 |
boltz_predictions,
|
198 |
protenix_predictions,
|
199 |
-
aligned_paths,
|
200 |
-
plddt_fig,
|
201 |
],
|
|
|
202 |
)
|
203 |
|
204 |
# Handle checkbox changes
|
205 |
-
for checkbox in [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
checkbox.change(
|
207 |
fn=filter_predictions,
|
208 |
inputs=[
|
209 |
-
|
210 |
-
|
|
|
|
|
211 |
chai_predictions,
|
212 |
boltz_predictions,
|
213 |
protenix_predictions,
|
@@ -242,63 +292,64 @@ def create_correlation_tab():
|
|
242 |
"antigen_sequence": "Antigen Sequence",
|
243 |
}
|
244 |
spr_data_with_scores = spr_data_with_scores.rename(columns=prettified_columns)
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
)
|
258 |
|
259 |
gr.Markdown("# Prediction and correlation")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
with gr.Row():
|
261 |
-
|
262 |
-
"
|
263 |
-
|
264 |
-
|
|
|
|
|
265 |
)
|
|
|
266 |
with gr.Row():
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
choices=
|
272 |
-
|
273 |
-
|
274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
)
|
276 |
-
with gr.Row():
|
277 |
-
correlation_ranking_plot = gr.Plot(label="Correlation ranking")
|
278 |
-
with gr.Row():
|
279 |
-
with gr.Column():
|
280 |
-
with gr.Row():
|
281 |
-
# User can select the columns to display in the correlation plot
|
282 |
-
correlation_column = gr.Dropdown(
|
283 |
-
label="Score data to display",
|
284 |
-
choices=SCORE_COLUMNS,
|
285 |
-
multiselect=False,
|
286 |
-
value=SCORE_COLUMNS[0],
|
287 |
-
)
|
288 |
-
# Add checkbox for log scale and update plot when either input changes
|
289 |
-
with gr.Row():
|
290 |
-
log_scale = gr.Checkbox(
|
291 |
-
label="Display x-axis on logarithmic scale", value=False
|
292 |
-
)
|
293 |
-
with gr.Row():
|
294 |
-
score_description = gr.Markdown(
|
295 |
-
get_score_description(correlation_column.value)
|
296 |
-
)
|
297 |
-
correlation_column.change(
|
298 |
-
fn=lambda x: get_score_description(x),
|
299 |
-
inputs=correlation_column,
|
300 |
-
outputs=score_description,
|
301 |
-
)
|
302 |
with gr.Column():
|
303 |
regression_plot = gr.Plot(label="Correlation with binding affinity")
|
304 |
|
@@ -333,7 +384,7 @@ def create_correlation_tab():
|
|
333 |
|
334 |
log_scale.change(
|
335 |
fn=update_regression_plot,
|
336 |
-
inputs=[correlation_column, log_scale],
|
337 |
outputs=regression_plot,
|
338 |
)
|
339 |
|
@@ -360,7 +411,7 @@ def __main__():
|
|
360 |
)
|
361 |
api_key = gr.Textbox(label="Folding Studio API Key", type="password")
|
362 |
gr.Markdown("## Demo Usage")
|
363 |
-
with gr.Tab("🚀
|
364 |
simple_prediction(api_key)
|
365 |
with gr.Tab("📊 Model Comparison"):
|
366 |
model_comparison(api_key)
|
|
|
39 |
|
40 |
|
41 |
MODEL_CHOICES = [
|
42 |
+
("AlphaFold2", FoldingModel.AF2),
|
43 |
+
("OpenFold", FoldingModel.OPENFOLD),
|
44 |
# ("SoloSeq", FoldingModel.SOLOSEQ),
|
45 |
("Boltz-1", FoldingModel.BOLTZ),
|
46 |
("Chai-1", FoldingModel.CHAI),
|
|
|
49 |
|
50 |
DEFAULT_SEQ = "MALWMRLLPLLALLALWGPDPAAA"
|
51 |
MODEL_EXAMPLES = {
|
52 |
+
FoldingModel.AF2: [
|
53 |
+
["Monomer", f">A\n{DEFAULT_SEQ}"],
|
54 |
+
["Multimer", f">A\n{DEFAULT_SEQ}\n>B\n{DEFAULT_SEQ}"],
|
55 |
+
],
|
56 |
+
FoldingModel.OPENFOLD: [
|
57 |
+
["Monomer", f">A\n{DEFAULT_SEQ}"],
|
58 |
+
["Multimer", f">A\n{DEFAULT_SEQ}\n>B\n{DEFAULT_SEQ}"],
|
59 |
+
],
|
60 |
+
FoldingModel.SOLOSEQ: [["Monomer", f">A\n{DEFAULT_SEQ}"]],
|
61 |
FoldingModel.BOLTZ: [
|
62 |
["Monomer", f">A|protein\n{DEFAULT_SEQ}"],
|
63 |
["Multimer", f">A|protein\n{DEFAULT_SEQ}\n>B|protein\n{DEFAULT_SEQ}"],
|
|
|
79 |
Returns:
|
80 |
gr.Textbox: Sequence input component
|
81 |
"""
|
82 |
+
with gr.Row(equal_height=True):
|
83 |
+
with gr.Column():
|
84 |
+
sequence = gr.Textbox(
|
85 |
+
label="Protein Sequence",
|
86 |
+
lines=2,
|
87 |
+
placeholder="Enter a protein sequence or upload a FASTA file",
|
88 |
+
)
|
89 |
+
dummy = gr.Textbox(label="Complex type", visible=False)
|
90 |
+
|
91 |
+
examples = gr.Examples(
|
92 |
+
examples=MODEL_EXAMPLES[FoldingModel.BOLTZ],
|
93 |
+
inputs=[dummy, sequence],
|
94 |
+
)
|
95 |
+
file_input = gr.File(
|
96 |
+
label="Upload a FASTA file",
|
97 |
+
file_types=[".fasta", ".fa"],
|
98 |
+
scale=0,
|
99 |
+
)
|
100 |
|
|
|
|
|
|
|
|
|
101 |
if dropdown is not None:
|
102 |
dropdown.change(
|
103 |
fn=lambda x: gr.Dataset(samples=MODEL_EXAMPLES[x]),
|
104 |
inputs=[dropdown],
|
105 |
outputs=[examples.dataset],
|
106 |
)
|
|
|
|
|
|
|
|
|
107 |
|
108 |
def _process_file(file: gr.File | None) -> gr.Textbox:
|
109 |
if file is None:
|
|
|
128 |
"""
|
129 |
gr.Markdown(
|
130 |
"""
|
131 |
+
## Predict a Protein Structure
|
132 |
|
133 |
It will be run in the background and the results will be displayed in the output section.
|
134 |
The output will contain the protein structure and the pLDDT plot.
|
|
|
170 |
Args:
|
171 |
api_key (str): Folding Studio API key
|
172 |
"""
|
173 |
+
gr.Markdown(
|
174 |
+
"""
|
175 |
+
## Compare Folding Models
|
176 |
|
177 |
+
Select multiple models to compare their predictions on your protein sequence.
|
178 |
+
You can either enter the sequence directly or upload a FASTA file.
|
179 |
+
|
180 |
+
The selected models will run in parallel and generate:
|
181 |
+
- 3D structures of your protein that you can visualize and compare
|
182 |
+
- pLDDT confidence scores plotted for each residue
|
183 |
+
|
184 |
+
"""
|
185 |
+
)
|
186 |
with gr.Row():
|
187 |
models = gr.CheckboxGroup(
|
188 |
label="Model",
|
|
|
201 |
variant="primary",
|
202 |
)
|
203 |
with gr.Row():
|
204 |
+
af2_predictions = gr.CheckboxGroup(label="AlphaFold2", visible=False)
|
205 |
+
openfold_predictions = gr.CheckboxGroup(label="OpenFold", visible=False)
|
206 |
+
solo_predictions = gr.CheckboxGroup(label="SoloSeq", visible=False)
|
207 |
chai_predictions = gr.CheckboxGroup(label="Chai", visible=False)
|
208 |
protenix_predictions = gr.CheckboxGroup(label="Protenix", visible=False)
|
209 |
boltz_predictions = gr.CheckboxGroup(label="Boltz", visible=False)
|
|
|
214 |
metrics_plot = gr.Plot(label="pLDDT")
|
215 |
|
216 |
# Store the initial predictions
|
217 |
+
prediction_outputs = gr.State()
|
|
|
218 |
|
219 |
predict_btn.click(
|
220 |
fn=predict_comparison,
|
221 |
inputs=[sequence, api_key, models],
|
222 |
outputs=[
|
223 |
+
prediction_outputs,
|
224 |
+
af2_predictions,
|
225 |
+
openfold_predictions,
|
226 |
+
solo_predictions,
|
227 |
+
chai_predictions,
|
228 |
+
boltz_predictions,
|
229 |
+
protenix_predictions,
|
230 |
+
],
|
231 |
+
).then(
|
232 |
+
fn=filter_predictions,
|
233 |
+
inputs=[
|
234 |
+
prediction_outputs,
|
235 |
+
af2_predictions,
|
236 |
+
openfold_predictions,
|
237 |
+
solo_predictions,
|
238 |
chai_predictions,
|
239 |
boltz_predictions,
|
240 |
protenix_predictions,
|
|
|
|
|
241 |
],
|
242 |
+
outputs=[mol_outputs, metrics_plot],
|
243 |
)
|
244 |
|
245 |
# Handle checkbox changes
|
246 |
+
for checkbox in [
|
247 |
+
af2_predictions,
|
248 |
+
openfold_predictions,
|
249 |
+
solo_predictions,
|
250 |
+
chai_predictions,
|
251 |
+
boltz_predictions,
|
252 |
+
protenix_predictions,
|
253 |
+
]:
|
254 |
checkbox.change(
|
255 |
fn=filter_predictions,
|
256 |
inputs=[
|
257 |
+
prediction_outputs,
|
258 |
+
af2_predictions,
|
259 |
+
openfold_predictions,
|
260 |
+
solo_predictions,
|
261 |
chai_predictions,
|
262 |
boltz_predictions,
|
263 |
protenix_predictions,
|
|
|
292 |
"antigen_sequence": "Antigen Sequence",
|
293 |
}
|
294 |
spr_data_with_scores = spr_data_with_scores.rename(columns=prettified_columns)
|
295 |
+
columns = [
|
296 |
+
"Antibody Name",
|
297 |
+
"KD (nM)",
|
298 |
+
"Antibody VH Sequence",
|
299 |
+
"Antibody VL Sequence",
|
300 |
+
"Antigen Sequence",
|
301 |
+
]
|
302 |
+
# Display dataframe with floating point values rounded to 2 decimal places
|
303 |
+
spr_data = gr.DataFrame(
|
304 |
+
value=spr_data_with_scores[columns].round(2),
|
305 |
+
label="Experimental Antibody-Antigen Binding Affinity Data",
|
306 |
+
)
|
|
|
307 |
|
308 |
gr.Markdown("# Prediction and correlation")
|
309 |
+
|
310 |
+
fake_predict_btn = gr.Button(
|
311 |
+
"Predict structures of all complexes",
|
312 |
+
elem_classes="gradient-button",
|
313 |
+
variant="primary",
|
314 |
+
)
|
315 |
+
prediction_dataframe = gr.Dataframe(
|
316 |
+
label="Predicted Structures Data", visible=False
|
317 |
+
)
|
318 |
+
prediction_dataframe.change(
|
319 |
+
fn=lambda x: gr.Dataframe(x, visible=True),
|
320 |
+
inputs=[prediction_dataframe],
|
321 |
+
outputs=[prediction_dataframe],
|
322 |
+
)
|
323 |
with gr.Row():
|
324 |
+
correlation_type = gr.Radio(
|
325 |
+
choices=["Spearman", "Pearson", "R²"],
|
326 |
+
value="Spearman",
|
327 |
+
label="Correlation Type",
|
328 |
+
interactive=True,
|
329 |
+
scale=0,
|
330 |
)
|
331 |
+
correlation_ranking_plot = gr.Plot(label="Correlation ranking")
|
332 |
with gr.Row():
|
333 |
+
with gr.Column(scale=0):
|
334 |
+
# User can select the columns to display in the correlation plot
|
335 |
+
correlation_column = gr.Dropdown(
|
336 |
+
label="Score data to display",
|
337 |
+
choices=SCORE_COLUMNS,
|
338 |
+
multiselect=False,
|
339 |
+
value=SCORE_COLUMNS[0],
|
340 |
+
)
|
341 |
+
# Add checkbox for log scale and update plot when either input changes
|
342 |
+
log_scale = gr.Checkbox(
|
343 |
+
label="Display x-axis on logarithmic scale", value=False
|
344 |
+
)
|
345 |
+
score_description = gr.Markdown(
|
346 |
+
get_score_description(correlation_column.value)
|
347 |
+
)
|
348 |
+
correlation_column.change(
|
349 |
+
fn=lambda x: get_score_description(x),
|
350 |
+
inputs=correlation_column,
|
351 |
+
outputs=score_description,
|
352 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
with gr.Column():
|
354 |
regression_plot = gr.Plot(label="Correlation with binding affinity")
|
355 |
|
|
|
384 |
|
385 |
log_scale.change(
|
386 |
fn=update_regression_plot,
|
387 |
+
inputs=[correlation_column, log_scale],
|
388 |
outputs=regression_plot,
|
389 |
)
|
390 |
|
|
|
411 |
)
|
412 |
api_key = gr.Textbox(label="Folding Studio API Key", type="password")
|
413 |
gr.Markdown("## Demo Usage")
|
414 |
+
with gr.Tab("🚀 Basic Folding"):
|
415 |
simple_prediction(api_key)
|
416 |
with gr.Tab("📊 Model Comparison"):
|
417 |
model_comparison(api_key)
|
folding_studio_demo/correlate.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import logging
|
2 |
-
import pandas as pd
|
3 |
from pathlib import Path
|
|
|
4 |
import numpy as np
|
|
|
5 |
import plotly.graph_objects as go
|
6 |
-
from scipy.stats import
|
7 |
|
8 |
logger = logging.getLogger(__name__)
|
9 |
|
@@ -16,7 +17,7 @@ SCORE_COLUMN_NAMES = {
|
|
16 |
"complex_pde_boltz": "Boltz Complex pDE",
|
17 |
"complex_ipde_boltz": "Boltz Complex ipDE",
|
18 |
"interchain_pae_monomer": "AlphaFold2 GapTrick Interchain PAE",
|
19 |
-
"interface_pae_monomer": "AlphaFold2 GapTrick Interface PAE",
|
20 |
"overall_pae_monomer": "AlphaFold2 GapTrick Overall PAE",
|
21 |
"interface_plddt_monomer": "AlphaFold2 GapTrick Interface pLDDT",
|
22 |
"average_plddt_monomer": "AlphaFold2 GapTrick Average pLDDT",
|
@@ -24,15 +25,16 @@ SCORE_COLUMN_NAMES = {
|
|
24 |
"interface_ptm_monomer": "AlphaFold2 GapTrick Interface pTM",
|
25 |
"interchain_pae_multimer": "AlphaFold2 Multimer Interchain PAE",
|
26 |
"interface_pae_multimer": "AlphaFold2 Multimer Interface PAE",
|
27 |
-
"overall_pae_multimer": "AlphaFold2 Multimer Overall PAE",
|
28 |
"interface_plddt_multimer": "AlphaFold2 Multimer Interface pLDDT",
|
29 |
"average_plddt_multimer": "AlphaFold2 Multimer Average pLDDT",
|
30 |
"ptm_multimer": "AlphaFold2 Multimer pTM Score",
|
31 |
-
"interface_ptm_multimer": "AlphaFold2 Multimer Interface pTM"
|
32 |
}
|
33 |
|
34 |
SCORE_COLUMNS = list(SCORE_COLUMN_NAMES.values())
|
35 |
|
|
|
36 |
def get_score_description(score: str) -> str:
|
37 |
descriptions = {
|
38 |
"Boltz Confidence Score": "The Boltz model confidence score provides an overall assessment of prediction quality (0-1, higher is better).",
|
@@ -49,22 +51,24 @@ def get_score_description(score: str) -> str:
|
|
49 |
"AlphaFold2 GapTrick Average pLDDT": "The AlphaFold2 GapTrick model average pLDDT provides the mean confidence across all residues in monomeric predictions (0-100, higher is better).",
|
50 |
"AlphaFold2 GapTrick pTM Score": "The AlphaFold2 GapTrick model pTM score assesses overall fold accuracy in monomeric predictions (0-1, higher is better).",
|
51 |
"AlphaFold2 GapTrick Interface pTM": "The AlphaFold2 GapTrick model interface pTM specifically evaluates accuracy of interface regions in monomeric predictions (0-1, higher is better).",
|
52 |
-
"AlphaFold2
|
53 |
-
"AlphaFold2 Multimer Interface PAE": "The AlphaFold2 Multimer model interface PAE estimates position errors specifically at interfaces in multimeric predictions (lower is better).",
|
54 |
"AlphaFold2 Multimer Overall PAE": "The AlphaFold2 Multimer model overall PAE estimates position errors across the entire structure in multimeric predictions (lower is better).",
|
55 |
"AlphaFold2 Multimer Interface pLDDT": "The AlphaFold2 Multimer model interface pLDDT measures confidence in interface region predictions for multimeric models (0-100, higher is better).",
|
56 |
"AlphaFold2 Multimer Average pLDDT": "The AlphaFold2 Multimer model average pLDDT provides the mean confidence across all residues in multimeric predictions (0-100, higher is better).",
|
57 |
"AlphaFold2 Multimer pTM Score": "The AlphaFold2 Multimer model pTM score assesses overall fold accuracy in multimeric predictions (0-1, higher is better).",
|
58 |
-
"AlphaFold2 Multimer Interface pTM": "The AlphaFold2 Multimer model interface pTM specifically evaluates accuracy of interface regions in multimeric predictions (0-1, higher is better)."
|
59 |
}
|
60 |
return descriptions.get(score, "No description available for this score.")
|
61 |
|
62 |
-
|
|
|
|
|
|
|
63 |
corr_data_file = Path("corr_data.csv")
|
64 |
if corr_data_file.exists():
|
65 |
logger.info(f"Loading correlation data from {corr_data_file}")
|
66 |
return pd.read_csv(corr_data_file)
|
67 |
-
|
68 |
corr_data = []
|
69 |
spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores["KD (nM)"])
|
70 |
kd_col = "KD (nM)"
|
@@ -74,53 +78,71 @@ def compute_correlation_data(spr_data_with_scores: pd.DataFrame, score_cols: lis
|
|
74 |
corr_funcs["R²"] = linregress
|
75 |
for correlation_type, corr_func in corr_funcs.items():
|
76 |
for score_col in score_cols:
|
77 |
-
logger.info(
|
78 |
-
|
|
|
|
|
|
|
|
|
79 |
logger.info(f"Correlation function: {corr_func}")
|
80 |
-
correlation_value =
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
corr_data = pd.DataFrame(corr_data)
|
90 |
# Find the lines in corr_data with NaN values and remove them
|
91 |
corr_data = corr_data[corr_data["correlation"].notna()]
|
92 |
# Sort correlation data by correlation value
|
93 |
-
corr_data = corr_data.sort_values(
|
94 |
-
|
95 |
corr_data.to_csv("corr_data.csv", index=False)
|
96 |
-
|
97 |
return corr_data
|
98 |
|
99 |
-
|
|
|
|
|
|
|
100 |
# Create bar plot of correlations
|
101 |
data = corr_data[corr_data["correlation_type"] == correlation_type]
|
102 |
-
corr_ranking_plot = go.Figure(
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
112 |
corr_ranking_plot.update_layout(
|
113 |
title="Correlation with Binding Affinity",
|
114 |
yaxis_title="Score",
|
115 |
xaxis_title=correlation_type,
|
116 |
template="simple_white",
|
117 |
-
showlegend=False
|
118 |
)
|
119 |
return corr_ranking_plot
|
120 |
|
121 |
-
|
|
|
|
|
|
|
122 |
"""Fake predict structures of all complexes and correlate the results."""
|
123 |
-
|
124 |
corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
|
125 |
corr_ranking_plot = plot_correlation_ranking(corr_data, "Spearman")
|
126 |
|
@@ -131,17 +153,20 @@ def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: l
|
|
131 |
|
132 |
return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
|
133 |
|
134 |
-
|
|
|
|
|
|
|
135 |
"""Select the regression plot to display."""
|
136 |
# corr_plot is a scatter plot of the regression between the binding affinity and each of the scores
|
137 |
-
scatter =
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
corr_plot = go.Figure(data=scatter)
|
146 |
corr_plot.update_layout(
|
147 |
xaxis_title="KD (nM)",
|
@@ -154,7 +179,7 @@ def make_regression_plot(spr_data_with_scores: pd.DataFrame, score: str, use_log
|
|
154 |
xanchor="right",
|
155 |
x=1,
|
156 |
),
|
157 |
-
xaxis_type="log" if use_log else "linear" # Set x-axis to logarithmic scale
|
158 |
)
|
159 |
# compute the regression line
|
160 |
if use_log:
|
@@ -162,23 +187,25 @@ def make_regression_plot(spr_data_with_scores: pd.DataFrame, score: str, use_log
|
|
162 |
x_vals = np.log10(spr_data_with_scores["KD (nM)"])
|
163 |
else:
|
164 |
x_vals = spr_data_with_scores["KD (nM)"]
|
165 |
-
|
166 |
# Fit line to data
|
167 |
corr_line = np.polyfit(x_vals, spr_data_with_scores[score], 1)
|
168 |
-
|
169 |
# Generate x points for line
|
170 |
corr_line_x = np.linspace(min(x_vals), max(x_vals), 100)
|
171 |
corr_line_y = corr_line[0] * corr_line_x + corr_line[1]
|
172 |
-
|
173 |
# Convert back from log space if needed
|
174 |
if use_log:
|
175 |
corr_line_x = 10**corr_line_x
|
176 |
# add the regression line to the plot
|
177 |
-
corr_plot.add_trace(
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
|
|
|
|
|
1 |
import logging
|
|
|
2 |
from pathlib import Path
|
3 |
+
|
4 |
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
import plotly.graph_objects as go
|
7 |
+
from scipy.stats import linregress, pearsonr, spearmanr
|
8 |
|
9 |
logger = logging.getLogger(__name__)
|
10 |
|
|
|
17 |
"complex_pde_boltz": "Boltz Complex pDE",
|
18 |
"complex_ipde_boltz": "Boltz Complex ipDE",
|
19 |
"interchain_pae_monomer": "AlphaFold2 GapTrick Interchain PAE",
|
20 |
+
"interface_pae_monomer": "AlphaFold2 GapTrick Interface PAE",
|
21 |
"overall_pae_monomer": "AlphaFold2 GapTrick Overall PAE",
|
22 |
"interface_plddt_monomer": "AlphaFold2 GapTrick Interface pLDDT",
|
23 |
"average_plddt_monomer": "AlphaFold2 GapTrick Average pLDDT",
|
|
|
25 |
"interface_ptm_monomer": "AlphaFold2 GapTrick Interface pTM",
|
26 |
"interchain_pae_multimer": "AlphaFold2 Multimer Interchain PAE",
|
27 |
"interface_pae_multimer": "AlphaFold2 Multimer Interface PAE",
|
28 |
+
"overall_pae_multimer": "AlphaFold2 Multimer Overall PAE",
|
29 |
"interface_plddt_multimer": "AlphaFold2 Multimer Interface pLDDT",
|
30 |
"average_plddt_multimer": "AlphaFold2 Multimer Average pLDDT",
|
31 |
"ptm_multimer": "AlphaFold2 Multimer pTM Score",
|
32 |
+
"interface_ptm_multimer": "AlphaFold2 Multimer Interface pTM",
|
33 |
}
|
34 |
|
35 |
SCORE_COLUMNS = list(SCORE_COLUMN_NAMES.values())
|
36 |
|
37 |
+
|
38 |
def get_score_description(score: str) -> str:
|
39 |
descriptions = {
|
40 |
"Boltz Confidence Score": "The Boltz model confidence score provides an overall assessment of prediction quality (0-1, higher is better).",
|
|
|
51 |
"AlphaFold2 GapTrick Average pLDDT": "The AlphaFold2 GapTrick model average pLDDT provides the mean confidence across all residues in monomeric predictions (0-100, higher is better).",
|
52 |
"AlphaFold2 GapTrick pTM Score": "The AlphaFold2 GapTrick model pTM score assesses overall fold accuracy in monomeric predictions (0-1, higher is better).",
|
53 |
"AlphaFold2 GapTrick Interface pTM": "The AlphaFold2 GapTrick model interface pTM specifically evaluates accuracy of interface regions in monomeric predictions (0-1, higher is better).",
|
54 |
+
"AlphaFold2 Multimer Interface PAE": "The AlphaFold2 Multimer model interface PAE estimates position errors specifically at interfaces in multimeric predictions (lower is better).",
|
|
|
55 |
"AlphaFold2 Multimer Overall PAE": "The AlphaFold2 Multimer model overall PAE estimates position errors across the entire structure in multimeric predictions (lower is better).",
|
56 |
"AlphaFold2 Multimer Interface pLDDT": "The AlphaFold2 Multimer model interface pLDDT measures confidence in interface region predictions for multimeric models (0-100, higher is better).",
|
57 |
"AlphaFold2 Multimer Average pLDDT": "The AlphaFold2 Multimer model average pLDDT provides the mean confidence across all residues in multimeric predictions (0-100, higher is better).",
|
58 |
"AlphaFold2 Multimer pTM Score": "The AlphaFold2 Multimer model pTM score assesses overall fold accuracy in multimeric predictions (0-1, higher is better).",
|
59 |
+
"AlphaFold2 Multimer Interface pTM": "The AlphaFold2 Multimer model interface pTM specifically evaluates accuracy of interface regions in multimeric predictions (0-1, higher is better).",
|
60 |
}
|
61 |
return descriptions.get(score, "No description available for this score.")
|
62 |
|
63 |
+
|
64 |
+
def compute_correlation_data(
|
65 |
+
spr_data_with_scores: pd.DataFrame, score_cols: list[str]
|
66 |
+
) -> pd.DataFrame:
|
67 |
corr_data_file = Path("corr_data.csv")
|
68 |
if corr_data_file.exists():
|
69 |
logger.info(f"Loading correlation data from {corr_data_file}")
|
70 |
return pd.read_csv(corr_data_file)
|
71 |
+
|
72 |
corr_data = []
|
73 |
spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores["KD (nM)"])
|
74 |
kd_col = "KD (nM)"
|
|
|
78 |
corr_funcs["R²"] = linregress
|
79 |
for correlation_type, corr_func in corr_funcs.items():
|
80 |
for score_col in score_cols:
|
81 |
+
logger.info(
|
82 |
+
f"Computing {correlation_type} correlation between {score_col} and KD (nM)"
|
83 |
+
)
|
84 |
+
res = corr_func(
|
85 |
+
spr_data_with_scores[kd_col], spr_data_with_scores[score_col]
|
86 |
+
)
|
87 |
logger.info(f"Correlation function: {corr_func}")
|
88 |
+
correlation_value = (
|
89 |
+
res.rvalue**2 if correlation_type == "R²" else res.statistic
|
90 |
+
)
|
91 |
+
corr_data.append(
|
92 |
+
{
|
93 |
+
"correlation_type": correlation_type,
|
94 |
+
"score": score_col,
|
95 |
+
"correlation": correlation_value,
|
96 |
+
"p-value": res.pvalue,
|
97 |
+
}
|
98 |
+
)
|
99 |
+
logger.info(
|
100 |
+
f"Correlation {correlation_type} between {score_col} and KD (nM): {correlation_value}"
|
101 |
+
)
|
102 |
|
103 |
corr_data = pd.DataFrame(corr_data)
|
104 |
# Find the lines in corr_data with NaN values and remove them
|
105 |
corr_data = corr_data[corr_data["correlation"].notna()]
|
106 |
# Sort correlation data by correlation value
|
107 |
+
corr_data = corr_data.sort_values("correlation", ascending=True)
|
108 |
+
|
109 |
corr_data.to_csv("corr_data.csv", index=False)
|
110 |
+
|
111 |
return corr_data
|
112 |
|
113 |
+
|
114 |
+
def plot_correlation_ranking(
|
115 |
+
corr_data: pd.DataFrame, correlation_type: str
|
116 |
+
) -> go.Figure:
|
117 |
# Create bar plot of correlations
|
118 |
data = corr_data[corr_data["correlation_type"] == correlation_type]
|
119 |
+
corr_ranking_plot = go.Figure(
|
120 |
+
data=[
|
121 |
+
go.Bar(
|
122 |
+
x=data["correlation"],
|
123 |
+
y=data["score"],
|
124 |
+
name=correlation_type,
|
125 |
+
text=data["correlation"],
|
126 |
+
orientation="h",
|
127 |
+
hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
|
128 |
+
)
|
129 |
+
]
|
130 |
+
)
|
131 |
corr_ranking_plot.update_layout(
|
132 |
title="Correlation with Binding Affinity",
|
133 |
yaxis_title="Score",
|
134 |
xaxis_title=correlation_type,
|
135 |
template="simple_white",
|
136 |
+
showlegend=False,
|
137 |
)
|
138 |
return corr_ranking_plot
|
139 |
|
140 |
+
|
141 |
+
def fake_predict_and_correlate(
|
142 |
+
spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]
|
143 |
+
) -> tuple[pd.DataFrame, go.Figure]:
|
144 |
"""Fake predict structures of all complexes and correlate the results."""
|
145 |
+
|
146 |
corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
|
147 |
corr_ranking_plot = plot_correlation_ranking(corr_data, "Spearman")
|
148 |
|
|
|
153 |
|
154 |
return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
|
155 |
|
156 |
+
|
157 |
+
def make_regression_plot(
|
158 |
+
spr_data_with_scores: pd.DataFrame, score: str, use_log: bool
|
159 |
+
) -> go.Figure:
|
160 |
"""Select the regression plot to display."""
|
161 |
# corr_plot is a scatter plot of the regression between the binding affinity and each of the scores
|
162 |
+
scatter = go.Scatter(
|
163 |
+
x=spr_data_with_scores["KD (nM)"],
|
164 |
+
y=spr_data_with_scores[score],
|
165 |
+
name=f"Samples",
|
166 |
+
mode="markers", # Only show markers/dots, no lines
|
167 |
+
hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
|
168 |
+
marker=dict(color="#1f77b4"), # Set color to match default first color
|
169 |
+
)
|
170 |
corr_plot = go.Figure(data=scatter)
|
171 |
corr_plot.update_layout(
|
172 |
xaxis_title="KD (nM)",
|
|
|
179 |
xanchor="right",
|
180 |
x=1,
|
181 |
),
|
182 |
+
xaxis_type="log" if use_log else "linear", # Set x-axis to logarithmic scale
|
183 |
)
|
184 |
# compute the regression line
|
185 |
if use_log:
|
|
|
187 |
x_vals = np.log10(spr_data_with_scores["KD (nM)"])
|
188 |
else:
|
189 |
x_vals = spr_data_with_scores["KD (nM)"]
|
190 |
+
|
191 |
# Fit line to data
|
192 |
corr_line = np.polyfit(x_vals, spr_data_with_scores[score], 1)
|
193 |
+
|
194 |
# Generate x points for line
|
195 |
corr_line_x = np.linspace(min(x_vals), max(x_vals), 100)
|
196 |
corr_line_y = corr_line[0] * corr_line_x + corr_line[1]
|
197 |
+
|
198 |
# Convert back from log space if needed
|
199 |
if use_log:
|
200 |
corr_line_x = 10**corr_line_x
|
201 |
# add the regression line to the plot
|
202 |
+
corr_plot.add_trace(
|
203 |
+
go.Scatter(
|
204 |
+
x=corr_line_x,
|
205 |
+
y=corr_line_y,
|
206 |
+
mode="lines",
|
207 |
+
name=f"Regression line",
|
208 |
+
line=dict(color="#1f77b4"), # Set same color as scatter points
|
209 |
+
)
|
210 |
+
)
|
211 |
+
return corr_plot
|
folding_studio_demo/model_fasta_validators.py
CHANGED
@@ -248,15 +248,15 @@ class ChaiFastaValidator(BaseFastaValidator):
|
|
248 |
)
|
249 |
seen_names.add(name)
|
250 |
# validate sequence format
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
|
261 |
return True, None
|
262 |
|
|
|
248 |
)
|
249 |
seen_names.add(name)
|
250 |
# validate sequence format
|
251 |
+
sequence = str(record.seq).strip()
|
252 |
+
if (
|
253 |
+
entity_type in {EntityType.PEPTIDE, EntityType.PROTEIN}
|
254 |
+
and not get_entity_type(sequence) == entity_type
|
255 |
+
):
|
256 |
+
return (
|
257 |
+
False,
|
258 |
+
f"CHAI Validation Error: Sequence type mismatch. Expected '{entity_type}' but found '{get_entity_type(sequence)}'",
|
259 |
+
)
|
260 |
|
261 |
return True, None
|
262 |
|
folding_studio_demo/models.py
CHANGED
@@ -1,17 +1,26 @@
|
|
1 |
"""Models for the Folding Studio API."""
|
2 |
|
|
|
3 |
import logging
|
4 |
import os
|
|
|
|
|
|
|
5 |
from pathlib import Path
|
6 |
from typing import Any
|
7 |
|
8 |
import gradio as gr
|
9 |
import numpy as np
|
|
|
10 |
from folding_studio.client import Client
|
|
|
|
|
11 |
from folding_studio.query import Query
|
12 |
from folding_studio.query.boltz import BoltzQuery
|
13 |
from folding_studio.query.chai import ChaiQuery
|
14 |
from folding_studio.query.protenix import ProtenixQuery
|
|
|
|
|
15 |
|
16 |
from folding_studio_demo.model_fasta_validators import (
|
17 |
BaseFastaValidator,
|
@@ -20,15 +29,29 @@ from folding_studio_demo.model_fasta_validators import (
|
|
20 |
ProtenixFastaValidator,
|
21 |
)
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
logger = logging.getLogger(__name__)
|
24 |
|
25 |
|
26 |
class AF3Model:
|
27 |
-
|
28 |
-
|
29 |
-
):
|
30 |
self.api_key = api_key
|
31 |
-
self.model_name = model_name
|
32 |
self.query = query
|
33 |
self.validator = validator
|
34 |
|
@@ -116,8 +139,10 @@ class AF3Model:
|
|
116 |
|
117 |
|
118 |
class ChaiModel(AF3Model):
|
|
|
|
|
119 |
def __init__(self, api_key: str):
|
120 |
-
super().__init__(api_key,
|
121 |
|
122 |
def call(
|
123 |
self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
|
@@ -158,8 +183,10 @@ class ChaiModel(AF3Model):
|
|
158 |
|
159 |
|
160 |
class ProtenixModel(AF3Model):
|
|
|
|
|
161 |
def __init__(self, api_key: str):
|
162 |
-
super().__init__(api_key,
|
163 |
|
164 |
def call(
|
165 |
self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
|
@@ -179,8 +206,10 @@ class ProtenixModel(AF3Model):
|
|
179 |
|
180 |
|
181 |
class BoltzModel(AF3Model):
|
|
|
|
|
182 |
def __init__(self, api_key: str):
|
183 |
-
super().__init__(api_key,
|
184 |
|
185 |
def call(
|
186 |
self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
|
@@ -205,3 +234,113 @@ class BoltzModel(AF3Model):
|
|
205 |
}
|
206 |
for cif_path in prediction_paths
|
207 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
"""Models for the Folding Studio API."""
|
2 |
|
3 |
+
import json
|
4 |
import logging
|
5 |
import os
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
from io import StringIO
|
9 |
from pathlib import Path
|
10 |
from typing import Any
|
11 |
|
12 |
import gradio as gr
|
13 |
import numpy as np
|
14 |
+
from folding_studio import single_job_prediction
|
15 |
from folding_studio.client import Client
|
16 |
+
from folding_studio.commands.experiment import results as get_results
|
17 |
+
from folding_studio.commands.experiment import status as get_status
|
18 |
from folding_studio.query import Query
|
19 |
from folding_studio.query.boltz import BoltzQuery
|
20 |
from folding_studio.query.chai import ChaiQuery
|
21 |
from folding_studio.query.protenix import ProtenixQuery
|
22 |
+
from folding_studio_data_models import AF2Parameters, OpenFoldParameters
|
23 |
+
from folding_studio_data_models.parameters.base import BaseFoldingParameters
|
24 |
|
25 |
from folding_studio_demo.model_fasta_validators import (
|
26 |
BaseFastaValidator,
|
|
|
29 |
ProtenixFastaValidator,
|
30 |
)
|
31 |
|
32 |
+
|
33 |
+
class Capturing(list):
    """Context manager that collects text written to ``sys.stdout`` as lines.

    The instance is a list: on exit, every line printed inside the ``with``
    block is appended to it, so the captured output can be read with normal
    list indexing once the block has finished.

    Note that ``sys.stdout`` is replaced for the whole process while the block
    is active, so text printed by other threads during that time is captured
    too.
    """

    def __enter__(self):
        """Redirect ``sys.stdout`` to an in-memory buffer and return ``self``."""
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        return self

    def __exit__(self, *args):
        """Store the captured lines and put the original ``sys.stdout`` back.

        Exceptions raised inside the ``with`` block are not suppressed.
        """
        try:
            self.extend(self._stringio.getvalue().splitlines())
        finally:
            # Restore the real stream unconditionally so a failure while
            # collecting the text can never leave stdout redirected.
            sys.stdout = self._stdout
            del self._stringio  # free up some memory
|
45 |
+
|
46 |
+
|
47 |
logger = logging.getLogger(__name__)
|
48 |
|
49 |
|
50 |
class AF3Model:
|
51 |
+
model_name = None
|
52 |
+
|
53 |
+
def __init__(self, api_key: str, query: Query, validator: BaseFastaValidator):
|
54 |
self.api_key = api_key
|
|
|
55 |
self.query = query
|
56 |
self.validator = validator
|
57 |
|
|
|
139 |
|
140 |
|
141 |
class ChaiModel(AF3Model):
|
142 |
+
model_name = "Chai"
|
143 |
+
|
144 |
def __init__(self, api_key: str):
|
145 |
+
super().__init__(api_key, ChaiQuery, ChaiFastaValidator())
|
146 |
|
147 |
def call(
|
148 |
self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
|
|
|
183 |
|
184 |
|
185 |
class ProtenixModel(AF3Model):
|
186 |
+
model_name = "Protenix"
|
187 |
+
|
188 |
def __init__(self, api_key: str):
|
189 |
+
super().__init__(api_key, ProtenixQuery, ProtenixFastaValidator())
|
190 |
|
191 |
def call(
|
192 |
self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
|
|
|
206 |
|
207 |
|
208 |
class BoltzModel(AF3Model):
|
209 |
+
model_name = "Boltz"
|
210 |
+
|
211 |
def __init__(self, api_key: str):
|
212 |
+
super().__init__(api_key, BoltzQuery, BoltzFastaValidator())
|
213 |
|
214 |
def call(
|
215 |
self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
|
|
|
234 |
}
|
235 |
for cif_path in prediction_paths
|
236 |
}
|
237 |
+
|
238 |
+
|
239 |
+
class OldModel:
|
240 |
+
model_name = None
|
241 |
+
|
242 |
+
def __init__(self, api_key: str):
|
243 |
+
self.api_key = api_key
|
244 |
+
|
245 |
+
def call(
|
246 |
+
self,
|
247 |
+
seq_file: Path | str,
|
248 |
+
output_dir: Path,
|
249 |
+
parameters: BaseFoldingParameters,
|
250 |
+
*args,
|
251 |
+
**kwargs,
|
252 |
+
) -> None:
|
253 |
+
"""Predict protein structure from amino acid sequence using AF2 model.
|
254 |
+
|
255 |
+
Args:
|
256 |
+
seq_file (Path | str): Path to FASTA file containing amino acid sequence
|
257 |
+
output_dir (Path): Path to output directory
|
258 |
+
"""
|
259 |
+
output = single_job_prediction(
|
260 |
+
fasta_file=seq_file,
|
261 |
+
parameters=parameters,
|
262 |
+
)
|
263 |
+
experiment_id = output["message"]["experiment_id"]
|
264 |
+
done = False
|
265 |
+
while not done:
|
266 |
+
with Capturing() as output:
|
267 |
+
get_status(experiment_id)
|
268 |
+
status = output[0]
|
269 |
+
logger.info(f"Experiment {experiment_id} status: {status}")
|
270 |
+
if status == "Done":
|
271 |
+
done = True
|
272 |
+
logger.info("Downloading results")
|
273 |
+
get_results(
|
274 |
+
experiment_id,
|
275 |
+
force=True,
|
276 |
+
unzip=True,
|
277 |
+
output=output_dir / "results.zip",
|
278 |
+
)
|
279 |
+
logger.info("Results downloaded to %s", output_dir)
|
280 |
+
else:
|
281 |
+
logger.info("Sleeping for 10 seconds")
|
282 |
+
time.sleep(10)
|
283 |
+
|
284 |
+
def format_fasta(self, seq_file: Path | str) -> None:
|
285 |
+
"""Format sequence to FASTA format.
|
286 |
+
|
287 |
+
Args:
|
288 |
+
seq_file (Path | str): Path to FASTA file
|
289 |
+
"""
|
290 |
+
return
|
291 |
+
|
292 |
+
def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
|
293 |
+
"""Get the path to the prediction.
|
294 |
+
|
295 |
+
Args:
|
296 |
+
output_dir (Path): Path to output directory
|
297 |
+
|
298 |
+
Returns:
|
299 |
+
dict[int, dict[str, Any]]: Dictionary mapping model indices to their prediction paths and metrics
|
300 |
+
"""
|
301 |
+
prediction_paths = list(
|
302 |
+
(output_dir / "results").rglob("relaxed_model_[0-9]_ptm_pred_0.pdb")
|
303 |
+
)
|
304 |
+
metrics_path = output_dir / "results" / "metrics_per_model.json"
|
305 |
+
if not metrics_path.exists():
|
306 |
+
return {}
|
307 |
+
with open(metrics_path, "r") as f:
|
308 |
+
metrics = json.load(f)
|
309 |
+
output = {
|
310 |
+
int(pred_path.stem.split("_")[2]): {
|
311 |
+
"prediction_path": pred_path,
|
312 |
+
"metrics": metrics[f"model_{int(pred_path.stem.split('_')[2])}_ptm"],
|
313 |
+
}
|
314 |
+
for pred_path in prediction_paths
|
315 |
+
}
|
316 |
+
return output
|
317 |
+
|
318 |
+
def has_prediction(self, output_dir: Path) -> bool:
|
319 |
+
"""Check if prediction exists in output directory."""
|
320 |
+
return len(self.predictions(output_dir)) > 0
|
321 |
+
|
322 |
+
def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
|
323 |
+
"""Check if the file description is correct.
|
324 |
+
|
325 |
+
Args:
|
326 |
+
seq_file (Path | str): Path to FASTA file
|
327 |
+
|
328 |
+
Returns:
|
329 |
+
tuple[bool, str | None]: Tuple containing a boolean indicating if the format is correct and an error message if not
|
330 |
+
"""
|
331 |
+
|
332 |
+
return True, None
|
333 |
+
|
334 |
+
|
335 |
+
class AF2Model(OldModel):
    """``OldModel`` variant that always runs with default ``AF2Parameters``."""

    model_name = "AlphaFold2"

    def call(self, seq_file: Path | str, output_dir: Path, *args, **kwargs) -> None:
        """Run the base-class prediction with a fresh ``AF2Parameters`` object.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
            *args: Forwarded unchanged to ``OldModel.call``
            **kwargs: Forwarded unchanged to ``OldModel.call``
        """
        af2_parameters = AF2Parameters()
        super().call(seq_file, output_dir, af2_parameters, *args, **kwargs)
|
340 |
+
|
341 |
+
|
342 |
+
class OpenFoldModel(OldModel):
    """``OldModel`` variant that always runs with default ``OpenFoldParameters``."""

    model_name = "OpenFold"

    def call(self, seq_file: Path | str, output_dir: Path, *args, **kwargs) -> None:
        """Run the base-class prediction with a fresh ``OpenFoldParameters`` object.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
            *args: Forwarded unchanged to ``OldModel.call``
            **kwargs: Forwarded unchanged to ``OldModel.call``
        """
        openfold_parameters = OpenFoldParameters()
        super().call(seq_file, output_dir, openfold_parameters, *args, **kwargs)
|
folding_studio_demo/predict.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
"""Predict protein structure using Folding Studio."""
|
2 |
|
|
|
3 |
import hashlib
|
4 |
import logging
|
5 |
from io import StringIO
|
6 |
from pathlib import Path
|
|
|
7 |
|
8 |
import gradio as gr
|
9 |
import numpy as np
|
@@ -12,7 +14,13 @@ from Bio import SeqIO
|
|
12 |
from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
|
13 |
from folding_studio_data_models import FoldingModel
|
14 |
|
15 |
-
from folding_studio_demo.models import
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
logger = logging.getLogger(__name__)
|
18 |
|
@@ -85,20 +93,22 @@ def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
|
|
85 |
def create_plddt_figure(
|
86 |
plddt_vals: list[list[float]],
|
87 |
model_name: str,
|
|
|
88 |
residue_codes: list[list[str]] = None,
|
89 |
) -> go.Figure:
|
90 |
"""Create a plot of metrics."""
|
91 |
plddt_traces = []
|
92 |
-
|
|
|
93 |
# Create hover text with residue codes if available
|
94 |
if residue_codes and i < len(residue_codes):
|
95 |
hover_text = [
|
96 |
-
f"<i>pLDDT</i>: {plddt:.2f}<br><i>Residue:</i> {code} {idx}"
|
97 |
for idx, (plddt, code) in enumerate(zip(plddt_val, residue_codes[i]))
|
98 |
]
|
99 |
else:
|
100 |
hover_text = [
|
101 |
-
f"<i>pLDDT</i>: {plddt:.2f}<br><i>Residue index:</i> {idx}"
|
102 |
for idx, plddt in enumerate(plddt_val)
|
103 |
]
|
104 |
|
@@ -108,7 +118,7 @@ def create_plddt_figure(
|
|
108 |
y=plddt_val,
|
109 |
hovertemplate="%{text}<extra></extra>",
|
110 |
text=hover_text,
|
111 |
-
name=f"{model_name} {
|
112 |
visible=True,
|
113 |
)
|
114 |
)
|
@@ -150,8 +160,19 @@ def _write_fasta_file(
|
|
150 |
return seq_id, seq_file
|
151 |
|
152 |
|
153 |
-
def
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
# Lists to store pLDDT values and residue codes
|
157 |
plddt_values = []
|
@@ -206,6 +227,10 @@ def predict(
|
|
206 |
model = ChaiModel(api_key)
|
207 |
elif model_type == FoldingModel.PROTENIX:
|
208 |
model = ProtenixModel(api_key)
|
|
|
|
|
|
|
|
|
209 |
else:
|
210 |
raise ValueError(f"Model {model_type} not supported")
|
211 |
|
@@ -235,22 +260,36 @@ def predict(
|
|
235 |
progress(
|
236 |
0.4 + (0.4 * i / total_predictions), desc=f"Converting model {model_idx}..."
|
237 |
)
|
238 |
-
|
239 |
-
logger.info(f"
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
|
|
|
|
247 |
model_plddt_vals.append(plddt_vals)
|
248 |
model_residue_codes.append(residue_codes)
|
249 |
|
250 |
progress(0.8, desc="Generating plots...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
plddt_fig = create_plddt_figure(
|
252 |
plddt_vals=model_plddt_vals,
|
253 |
model_name=model.model_name,
|
|
|
254 |
residue_codes=model_residue_codes,
|
255 |
)
|
256 |
|
@@ -258,11 +297,13 @@ def predict(
|
|
258 |
return pdb_paths, plddt_fig
|
259 |
|
260 |
|
261 |
-
def align_structures(
|
|
|
|
|
262 |
"""Align multiple PDB structures to the first structure.
|
263 |
|
264 |
Args:
|
265 |
-
|
266 |
|
267 |
Returns:
|
268 |
list[str]: List of paths to aligned PDB files
|
@@ -271,39 +312,47 @@ def align_structures(pdb_paths: list[str]) -> list[str]:
|
|
271 |
parser = PDBParser()
|
272 |
io = PDBIO()
|
273 |
|
274 |
-
#
|
275 |
-
|
|
|
|
|
|
|
|
|
|
|
276 |
ref_atoms = [atom for atom in ref_structure.get_atoms() if atom.get_name() == "CA"]
|
277 |
|
278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
|
280 |
-
|
281 |
-
|
282 |
-
# Parse the structure to align
|
283 |
-
structure = parser.get_structure(f"model_{i}", pdb_path)
|
284 |
-
atoms = [atom for atom in structure.get_atoms() if atom.get_name() == "CA"]
|
285 |
|
286 |
-
|
287 |
-
|
288 |
|
289 |
-
|
290 |
-
|
291 |
|
292 |
-
|
293 |
-
|
|
|
|
|
294 |
|
295 |
-
|
296 |
-
aligned_path = str(Path(pdb_path).parent / f"aligned_{Path(pdb_path).name}")
|
297 |
-
io.set_structure(structure)
|
298 |
-
io.save(aligned_path)
|
299 |
-
aligned_paths.append(aligned_path)
|
300 |
|
301 |
-
return
|
302 |
|
303 |
|
304 |
def filter_predictions(
|
305 |
-
|
306 |
-
|
|
|
|
|
307 |
chai_selected: list[int],
|
308 |
boltz_selected: list[int],
|
309 |
protenix_selected: list[int],
|
@@ -316,7 +365,7 @@ def filter_predictions(
|
|
316 |
chai_selected (list[int]): Selected Chai model indices
|
317 |
boltz_selected (list[int]): Selected Boltz model indices
|
318 |
protenix_selected (list[int]): Selected Protenix model indices
|
319 |
-
model_predictions (dict[FoldingModel,
|
320 |
|
321 |
Returns:
|
322 |
tuple[list[str], go.Figure]: Filtered PDB paths and updated pLDDT plot
|
@@ -325,26 +374,30 @@ def filter_predictions(
|
|
325 |
filtered_fig = go.Figure()
|
326 |
|
327 |
# Keep track of which traces to show
|
328 |
-
|
329 |
|
330 |
# Helper function to check if a trace should be visible
|
331 |
-
def should_show_trace(
|
332 |
-
model_name
|
333 |
-
|
334 |
-
|
335 |
-
|
|
|
|
|
|
|
336 |
return True
|
337 |
-
if model_name ==
|
338 |
return True
|
339 |
-
if model_name ==
|
340 |
return True
|
341 |
return False
|
342 |
|
343 |
# Filter traces and paths
|
344 |
-
for
|
345 |
-
|
346 |
-
|
347 |
-
|
|
|
348 |
|
349 |
# Update layout
|
350 |
filtered_fig.update_layout(
|
@@ -355,21 +408,58 @@ def filter_predictions(
|
|
355 |
template="simple_white",
|
356 |
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
|
357 |
)
|
|
|
358 |
|
359 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
360 |
|
361 |
|
362 |
def predict_comparison(
|
363 |
sequence: str, api_key: str, model_types: list[FoldingModel], progress=gr.Progress()
|
364 |
) -> tuple[
|
365 |
-
|
366 |
-
|
|
|
|
|
367 |
gr.CheckboxGroup,
|
368 |
gr.CheckboxGroup,
|
369 |
gr.CheckboxGroup,
|
370 |
-
list[str],
|
371 |
-
go.Figure,
|
372 |
-
dict,
|
373 |
]:
|
374 |
"""Predict protein structure from amino acid sequence using multiple models.
|
375 |
|
@@ -381,68 +471,94 @@ def predict_comparison(
|
|
381 |
|
382 |
Returns:
|
383 |
tuple containing:
|
384 |
-
-
|
385 |
-
-
|
|
|
|
|
386 |
- gr.CheckboxGroup: Chai predictions checkbox group
|
387 |
- gr.CheckboxGroup: Boltz predictions checkbox group
|
388 |
- gr.CheckboxGroup: Protenix predictions checkbox group
|
389 |
-
- list[str]: Original PDB paths
|
390 |
-
- go.Figure: Original pLDDT plot
|
391 |
-
- dict: Model predictions mapping
|
392 |
"""
|
393 |
if not api_key:
|
394 |
raise gr.Error("Missing API key, please enter a valid API key")
|
395 |
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
total_models = len(model_types)
|
400 |
model_predictions = {}
|
401 |
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
|
411 |
progress(0.9, desc="Aligning structures...")
|
412 |
-
|
413 |
-
|
414 |
-
plddt_fig.update_layout(
|
415 |
-
title="pLDDT",
|
416 |
-
xaxis_title="Residue index",
|
417 |
-
yaxis_title="pLDDT",
|
418 |
-
height=500,
|
419 |
-
template="simple_white",
|
420 |
-
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
|
421 |
-
)
|
422 |
|
423 |
progress(1.0, desc="Done!")
|
424 |
|
425 |
# Create checkbox groups for each model type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
chai_predictions = gr.CheckboxGroup(
|
427 |
visible=model_predictions.get(FoldingModel.CHAI) is not None,
|
428 |
-
choices=model_predictions.get(FoldingModel.CHAI,
|
429 |
-
value=model_predictions.get(FoldingModel.CHAI,
|
430 |
)
|
431 |
boltz_predictions = gr.CheckboxGroup(
|
432 |
visible=model_predictions.get(FoldingModel.BOLTZ) is not None,
|
433 |
-
choices=model_predictions.get(FoldingModel.BOLTZ,
|
434 |
-
value=model_predictions.get(FoldingModel.BOLTZ,
|
435 |
)
|
436 |
protenix_predictions = gr.CheckboxGroup(
|
437 |
visible=model_predictions.get(FoldingModel.PROTENIX) is not None,
|
438 |
-
choices=model_predictions.get(FoldingModel.PROTENIX,
|
439 |
-
value=model_predictions.get(FoldingModel.PROTENIX,
|
440 |
)
|
441 |
|
442 |
return (
|
|
|
|
|
|
|
|
|
443 |
chai_predictions,
|
444 |
boltz_predictions,
|
445 |
protenix_predictions,
|
446 |
-
aligned_paths,
|
447 |
-
plddt_fig,
|
448 |
)
|
|
|
1 |
"""Predict protein structure using Folding Studio."""
|
2 |
|
3 |
+
import concurrent.futures
|
4 |
import hashlib
|
5 |
import logging
|
6 |
from io import StringIO
|
7 |
from pathlib import Path
|
8 |
+
from typing import Any
|
9 |
|
10 |
import gradio as gr
|
11 |
import numpy as np
|
|
|
14 |
from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
|
15 |
from folding_studio_data_models import FoldingModel
|
16 |
|
17 |
+
from folding_studio_demo.models import (
|
18 |
+
AF2Model,
|
19 |
+
BoltzModel,
|
20 |
+
ChaiModel,
|
21 |
+
OpenFoldModel,
|
22 |
+
ProtenixModel,
|
23 |
+
)
|
24 |
|
25 |
logger = logging.getLogger(__name__)
|
26 |
|
|
|
93 |
def create_plddt_figure(
|
94 |
plddt_vals: list[list[float]],
|
95 |
model_name: str,
|
96 |
+
indexes: list[int],
|
97 |
residue_codes: list[list[str]] = None,
|
98 |
) -> go.Figure:
|
99 |
"""Create a plot of metrics."""
|
100 |
plddt_traces = []
|
101 |
+
|
102 |
+
for i, (plddt_val, index) in enumerate(zip(plddt_vals, indexes)):
|
103 |
# Create hover text with residue codes if available
|
104 |
if residue_codes and i < len(residue_codes):
|
105 |
hover_text = [
|
106 |
+
f"<i>{model_name} {index}</i><br><i>pLDDT</i>: {plddt:.2f}<br><i>Residue:</i> {code} {idx}"
|
107 |
for idx, (plddt, code) in enumerate(zip(plddt_val, residue_codes[i]))
|
108 |
]
|
109 |
else:
|
110 |
hover_text = [
|
111 |
+
f"<i>{model_name} {index}</i><br><i>pLDDT</i>: {plddt:.2f}<br><i>Residue index:</i> {idx}"
|
112 |
for idx, plddt in enumerate(plddt_val)
|
113 |
]
|
114 |
|
|
|
118 |
y=plddt_val,
|
119 |
hovertemplate="%{text}<extra></extra>",
|
120 |
text=hover_text,
|
121 |
+
name=f"{model_name} {index}",
|
122 |
visible=True,
|
123 |
)
|
124 |
)
|
|
|
160 |
return seq_id, seq_file
|
161 |
|
162 |
|
163 |
+
def extract_plddt_from_structure(structure_path: str) -> tuple[list[float], list[str]]:
|
164 |
+
"""Extract pLDDT values and residue codes from a structure file.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
structure_path (Path): Path to structure file
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
tuple[list[float], list[str]]: Tuple containing lists of pLDDT values and residue codes
|
171 |
+
"""
|
172 |
+
if Path(structure_path).suffix == ".cif":
|
173 |
+
structure = MMCIFParser().get_structure("structure", structure_path)
|
174 |
+
else:
|
175 |
+
structure = PDBParser().get_structure("structure", structure_path)
|
176 |
|
177 |
# Lists to store pLDDT values and residue codes
|
178 |
plddt_values = []
|
|
|
227 |
model = ChaiModel(api_key)
|
228 |
elif model_type == FoldingModel.PROTENIX:
|
229 |
model = ProtenixModel(api_key)
|
230 |
+
elif model_type == FoldingModel.AF2:
|
231 |
+
model = AF2Model(api_key)
|
232 |
+
elif model_type == FoldingModel.OPENFOLD:
|
233 |
+
model = OpenFoldModel(api_key)
|
234 |
else:
|
235 |
raise ValueError(f"Model {model_type} not supported")
|
236 |
|
|
|
260 |
progress(
|
261 |
0.4 + (0.4 * i / total_predictions), desc=f"Converting model {model_idx}..."
|
262 |
)
|
263 |
+
prediction_path = prediction["prediction_path"]
|
264 |
+
logger.info(f"Prediction file: {prediction_path}")
|
265 |
+
if Path(prediction_path).suffix == ".cif":
|
266 |
+
converted_pdb_path = str(
|
267 |
+
output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
|
268 |
+
)
|
269 |
+
convert_cif_to_pdb(str(prediction_path), str(converted_pdb_path))
|
270 |
+
pdb_paths.append(converted_pdb_path)
|
271 |
+
else:
|
272 |
+
pdb_paths.append(str(prediction_path))
|
273 |
+
plddt_vals, residue_codes = extract_plddt_from_structure(prediction_path)
|
274 |
model_plddt_vals.append(plddt_vals)
|
275 |
model_residue_codes.append(residue_codes)
|
276 |
|
277 |
progress(0.8, desc="Generating plots...")
|
278 |
+
indexes = []
|
279 |
+
for pdb_path in pdb_paths:
|
280 |
+
if model_type in [
|
281 |
+
FoldingModel.AF2,
|
282 |
+
FoldingModel.OPENFOLD,
|
283 |
+
FoldingModel.SOLOSEQ,
|
284 |
+
]:
|
285 |
+
indexes.append(int(Path(pdb_path).stem.split("_")[2]))
|
286 |
+
else:
|
287 |
+
indexes.append(int(Path(pdb_path).stem[-1]))
|
288 |
+
|
289 |
plddt_fig = create_plddt_figure(
|
290 |
plddt_vals=model_plddt_vals,
|
291 |
model_name=model.model_name,
|
292 |
+
indexes=indexes,
|
293 |
residue_codes=model_residue_codes,
|
294 |
)
|
295 |
|
|
|
297 |
return pdb_paths, plddt_fig
|
298 |
|
299 |
|
300 |
+
def align_structures(
|
301 |
+
model_predictions: dict[FoldingModel, dict[int, dict[str, Any]]],
|
302 |
+
) -> list[str]:
|
303 |
"""Align multiple PDB structures to the first structure.
|
304 |
|
305 |
Args:
|
306 |
+
model_predictions (dict[FoldingModel, dict[int, dict[str, Any]]]): Dictionary mapping models to their prediction indices
|
307 |
|
308 |
Returns:
|
309 |
list[str]: List of paths to aligned PDB files
|
|
|
312 |
parser = PDBParser()
|
313 |
io = PDBIO()
|
314 |
|
315 |
+
# Get the first structure as reference
|
316 |
+
first_model = next(iter(model_predictions.keys()))
|
317 |
+
first_pred = next(iter(model_predictions[first_model].values()))
|
318 |
+
ref_pdb_path = first_pred["pdb_path"]
|
319 |
+
|
320 |
+
# Parse reference structure and get CA atoms
|
321 |
+
ref_structure = parser.get_structure("reference", ref_pdb_path)
|
322 |
ref_atoms = [atom for atom in ref_structure.get_atoms() if atom.get_name() == "CA"]
|
323 |
|
324 |
+
for model_type in model_predictions.keys():
|
325 |
+
for index, prediction in model_predictions[model_type].items():
|
326 |
+
pdb_path = prediction["pdb_path"]
|
327 |
+
|
328 |
+
# Parse the structure to align
|
329 |
+
structure = parser.get_structure(f"{model_type}_{index}", pdb_path)
|
330 |
+
atoms = [atom for atom in structure.get_atoms() if atom.get_name() == "CA"]
|
331 |
|
332 |
+
# Create superimposer
|
333 |
+
sup = Superimposer()
|
|
|
|
|
|
|
334 |
|
335 |
+
# Set the reference and moving atoms
|
336 |
+
sup.set_atoms(ref_atoms, atoms)
|
337 |
|
338 |
+
# Apply the transformation to all atoms in the structure
|
339 |
+
sup.apply(structure.get_atoms())
|
340 |
|
341 |
+
# Save the aligned structure
|
342 |
+
aligned_path = str(Path(pdb_path).parent / f"aligned_{Path(pdb_path).name}")
|
343 |
+
io.set_structure(structure)
|
344 |
+
io.save(aligned_path)
|
345 |
|
346 |
+
model_predictions[model_type][index]["pdb_path"] = aligned_path
|
|
|
|
|
|
|
|
|
347 |
|
348 |
+
return model_predictions
|
349 |
|
350 |
|
351 |
def filter_predictions(
|
352 |
+
model_predictions: dict[FoldingModel, dict[int, dict[str, Any]]],
|
353 |
+
af2_selected: list[int],
|
354 |
+
openfold_selected: list[int],
|
355 |
+
solo_selected: list[int],
|
356 |
chai_selected: list[int],
|
357 |
boltz_selected: list[int],
|
358 |
protenix_selected: list[int],
|
|
|
365 |
chai_selected (list[int]): Selected Chai model indices
|
366 |
boltz_selected (list[int]): Selected Boltz model indices
|
367 |
protenix_selected (list[int]): Selected Protenix model indices
|
368 |
+
model_predictions (dict[FoldingModel, dict[int, dict[str, Any]]]): Dictionary mapping models to their prediction indices
|
369 |
|
370 |
Returns:
|
371 |
tuple[list[str], go.Figure]: Filtered PDB paths and updated pLDDT plot
|
|
|
374 |
filtered_fig = go.Figure()
|
375 |
|
376 |
# Keep track of which traces to show
|
377 |
+
filtered_paths = []
|
378 |
|
379 |
# Helper function to check if a trace should be visible
|
380 |
+
def should_show_trace(model_name, pred_index: int) -> bool:
|
381 |
+
if model_name == FoldingModel.CHAI and pred_index in chai_selected:
|
382 |
+
return True
|
383 |
+
if model_name == FoldingModel.BOLTZ and pred_index in boltz_selected:
|
384 |
+
return True
|
385 |
+
if model_name == FoldingModel.PROTENIX and pred_index in protenix_selected:
|
386 |
+
return True
|
387 |
+
if model_name == FoldingModel.AF2 and pred_index in af2_selected:
|
388 |
return True
|
389 |
+
if model_name == FoldingModel.OPENFOLD and pred_index in openfold_selected:
|
390 |
return True
|
391 |
+
if model_name == FoldingModel.SOLOSEQ and pred_index in solo_selected:
|
392 |
return True
|
393 |
return False
|
394 |
|
395 |
# Filter traces and paths
|
396 |
+
for model_type in model_predictions.keys():
|
397 |
+
for index, prediction in model_predictions[model_type].items():
|
398 |
+
if should_show_trace(model_type, index):
|
399 |
+
filtered_fig.add_trace(prediction["plddt_trace"])
|
400 |
+
filtered_paths.append(prediction["pdb_path"])
|
401 |
|
402 |
# Update layout
|
403 |
filtered_fig.update_layout(
|
|
|
408 |
template="simple_white",
|
409 |
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
|
410 |
)
|
411 |
+
return filtered_paths, filtered_fig
|
412 |
|
413 |
+
|
414 |
+
def run_prediction(
    sequence: str,
    api_key: str,
    model_type: FoldingModel,
    format_fasta: bool = False,
) -> dict[int, dict[str, Any]]:
    """Run a single prediction and index its outputs by prediction number.

    Args:
        sequence (str): Amino acid sequence to predict structure for
        api_key (str): Folding API key
        model_type (FoldingModel): Folding model to use
        format_fasta (bool): Whether to format the FASTA file

    Returns:
        dict[int, dict[str, Any]]: Dictionary mapping each prediction index to
            a dict with the keys ``"pdb_path"`` (path of the structure file)
            and ``"plddt_trace"`` (the matching trace of the pLDDT figure)
    """
    model_pdb_paths, model_plddt_traces = predict(
        sequence, api_key, model_type, format_fasta=format_fasta
    )

    # Pair every path with its trace BEFORE sorting. Sorting the paths on
    # their own and zipping afterwards would attach traces to the wrong
    # structures whenever predict() returns them in a non-sorted order.
    paired = sorted(
        zip(model_pdb_paths, model_plddt_traces.data),
        key=lambda pair: pair[0],
    )

    # These models name their files like relaxed_model_<index>_ptm_pred_0,
    # so the index is the third "_"-separated field of the stem.
    index_in_third_field = (
        FoldingModel.AF2,
        FoldingModel.OPENFOLD,
        FoldingModel.SOLOSEQ,
    )

    model_predictions = {}
    for pdb_path, plddt_trace in paired:
        stem = Path(pdb_path).stem
        if model_type in index_in_third_field:
            index = int(stem.split("_")[2])
        else:
            # Other models: the index is the last character of the stem
            index = int(stem[-1])
        model_predictions[index] = {"pdb_path": pdb_path, "plddt_trace": plddt_trace}
    return model_predictions
|
451 |
|
452 |
|
453 |
def predict_comparison(
|
454 |
sequence: str, api_key: str, model_types: list[FoldingModel], progress=gr.Progress()
|
455 |
) -> tuple[
|
456 |
+
dict[FoldingModel, dict[int, dict[str, Any]]],
|
457 |
+
gr.CheckboxGroup,
|
458 |
+
gr.CheckboxGroup,
|
459 |
+
gr.CheckboxGroup,
|
460 |
gr.CheckboxGroup,
|
461 |
gr.CheckboxGroup,
|
462 |
gr.CheckboxGroup,
|
|
|
|
|
|
|
463 |
]:
|
464 |
"""Predict protein structure from amino acid sequence using multiple models.
|
465 |
|
|
|
471 |
|
472 |
Returns:
|
473 |
tuple containing:
|
474 |
+
- dict[FoldingModel, dict[int, dict[str, Any]]]: Model predictions mapping
|
475 |
+
- gr.CheckboxGroup: AF2 predictions checkbox group
|
476 |
+
- gr.CheckboxGroup: OpenFold predictions checkbox group
|
477 |
+
- gr.CheckboxGroup: SoloSeq predictions checkbox group
|
478 |
- gr.CheckboxGroup: Chai predictions checkbox group
|
479 |
- gr.CheckboxGroup: Boltz predictions checkbox group
|
480 |
- gr.CheckboxGroup: Protenix predictions checkbox group
|
|
|
|
|
|
|
481 |
"""
|
482 |
if not api_key:
|
483 |
raise gr.Error("Missing API key, please enter a valid API key")
|
484 |
|
485 |
+
progress(0, desc="Starting parallel predictions...")
|
486 |
+
|
487 |
+
# Run predictions in parallel
|
|
|
488 |
model_predictions = {}
|
489 |
|
490 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
491 |
+
# Create a future for each model prediction
|
492 |
+
future_to_model = {
|
493 |
+
executor.submit(
|
494 |
+
run_prediction, sequence, api_key, model_type, True
|
495 |
+
): model_type
|
496 |
+
for model_type in model_types
|
497 |
+
}
|
498 |
+
|
499 |
+
# Process results as they complete
|
500 |
+
total_models = len(model_types)
|
501 |
+
completed = 0
|
502 |
+
|
503 |
+
for future in concurrent.futures.as_completed(future_to_model):
|
504 |
+
model_type = future_to_model[future]
|
505 |
+
try:
|
506 |
+
model_preds = future.result()
|
507 |
+
model_predictions[model_type] = model_preds
|
508 |
+
|
509 |
+
completed += 1
|
510 |
+
progress(
|
511 |
+
completed / total_models,
|
512 |
+
desc=f"Completed {model_type} prediction...",
|
513 |
+
)
|
514 |
+
except Exception as e:
|
515 |
+
logger.error(f"Prediction failed for {model_type}: {str(e)}")
|
516 |
+
raise gr.Error(f"Prediction failed for {model_type}: {str(e)}")
|
517 |
|
518 |
progress(0.9, desc="Aligning structures...")
|
519 |
+
|
520 |
+
model_predictions = align_structures(model_predictions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
521 |
|
522 |
progress(1.0, desc="Done!")
|
523 |
|
524 |
# Create checkbox groups for each model type
|
525 |
+
af2_predictions = gr.CheckboxGroup(
|
526 |
+
visible=model_predictions.get(FoldingModel.AF2) is not None,
|
527 |
+
choices=list(model_predictions.get(FoldingModel.AF2, {}).keys()),
|
528 |
+
value=list(model_predictions.get(FoldingModel.AF2, {}).keys()),
|
529 |
+
)
|
530 |
+
openfold_predictions = gr.CheckboxGroup(
|
531 |
+
visible=model_predictions.get(FoldingModel.OPENFOLD) is not None,
|
532 |
+
choices=list(model_predictions.get(FoldingModel.OPENFOLD, {}).keys()),
|
533 |
+
value=list(model_predictions.get(FoldingModel.OPENFOLD, {}).keys()),
|
534 |
+
)
|
535 |
+
solo_predictions = gr.CheckboxGroup(
|
536 |
+
visible=model_predictions.get(FoldingModel.SOLOSEQ) is not None,
|
537 |
+
choices=list(model_predictions.get(FoldingModel.SOLOSEQ, {}).keys()),
|
538 |
+
value=list(model_predictions.get(FoldingModel.SOLOSEQ, {}).keys()),
|
539 |
+
)
|
540 |
chai_predictions = gr.CheckboxGroup(
|
541 |
visible=model_predictions.get(FoldingModel.CHAI) is not None,
|
542 |
+
choices=list(model_predictions.get(FoldingModel.CHAI, {}).keys()),
|
543 |
+
value=list(model_predictions.get(FoldingModel.CHAI, {}).keys()),
|
544 |
)
|
545 |
boltz_predictions = gr.CheckboxGroup(
|
546 |
visible=model_predictions.get(FoldingModel.BOLTZ) is not None,
|
547 |
+
choices=list(model_predictions.get(FoldingModel.BOLTZ, {}).keys()),
|
548 |
+
value=list(model_predictions.get(FoldingModel.BOLTZ, {}).keys()),
|
549 |
)
|
550 |
protenix_predictions = gr.CheckboxGroup(
|
551 |
visible=model_predictions.get(FoldingModel.PROTENIX) is not None,
|
552 |
+
choices=list(model_predictions.get(FoldingModel.PROTENIX, {}).keys()),
|
553 |
+
value=list(model_predictions.get(FoldingModel.PROTENIX, {}).keys()),
|
554 |
)
|
555 |
|
556 |
return (
|
557 |
+
model_predictions,
|
558 |
+
af2_predictions,
|
559 |
+
openfold_predictions,
|
560 |
+
solo_predictions,
|
561 |
chai_predictions,
|
562 |
boltz_predictions,
|
563 |
protenix_predictions,
|
|
|
|
|
564 |
)
|