improve model comparison (#10)
Browse files- improve model comparison (6187a6421d0fce2de6e37381ac15baa56c12ea39)
Co-authored-by: Achille Soulie <[email protected]>
- folding_studio_demo/app.py +85 -29
- folding_studio_demo/model_fasta_validators.py +9 -9
- folding_studio_demo/models.py +207 -0
- folding_studio_demo/predict.py +240 -216
folding_studio_demo/app.py
CHANGED
@@ -4,19 +4,20 @@ import logging
|
|
4 |
|
5 |
import gradio as gr
|
6 |
import pandas as pd
|
|
|
7 |
from folding_studio_data_models import FoldingModel
|
8 |
from gradio_molecule3d import Molecule3D
|
9 |
|
10 |
from folding_studio_demo.correlate import (
|
11 |
-
SCORE_COLUMNS,
|
12 |
SCORE_COLUMN_NAMES,
|
|
|
|
|
13 |
fake_predict_and_correlate,
|
|
|
14 |
make_regression_plot,
|
15 |
-
compute_correlation_data,
|
16 |
plot_correlation_ranking,
|
17 |
-
get_score_description
|
18 |
)
|
19 |
-
from folding_studio_demo.predict import predict, predict_comparison
|
20 |
|
21 |
logger = logging.getLogger(__name__)
|
22 |
|
@@ -24,8 +25,8 @@ logger = logging.getLogger(__name__)
|
|
24 |
MOLECULE_REPS = [
|
25 |
{
|
26 |
"model": 0,
|
27 |
-
"chain": "",
|
28 |
-
"resname": "",
|
29 |
"style": "cartoon",
|
30 |
"color": "alphafold",
|
31 |
# "residue_range": "",
|
@@ -36,7 +37,6 @@ MOLECULE_REPS = [
|
|
36 |
}
|
37 |
]
|
38 |
|
39 |
-
DEFAULT_PROTEIN_SEQ = ">protein description\nMALWMRLLPLLALLALWGPDPAAA"
|
40 |
|
41 |
MODEL_CHOICES = [
|
42 |
# ("AlphaFold2", FoldingModel.AF2),
|
@@ -47,8 +47,24 @@ MODEL_CHOICES = [
|
|
47 |
("Protenix", FoldingModel.PROTENIX),
|
48 |
]
|
49 |
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
"""Sequence input component.
|
53 |
|
54 |
Returns:
|
@@ -56,10 +72,21 @@ def sequence_input() -> gr.Textbox:
|
|
56 |
"""
|
57 |
sequence = gr.Textbox(
|
58 |
label="Protein Sequence",
|
59 |
-
value=DEFAULT_PROTEIN_SEQ,
|
60 |
lines=2,
|
61 |
placeholder="Enter a protein sequence or upload a FASTA file",
|
62 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
file_input = gr.File(
|
64 |
label="Upload a FASTA file",
|
65 |
file_types=[".fasta", ".fa"],
|
@@ -104,7 +131,7 @@ def simple_prediction(api_key: str) -> None:
|
|
104 |
value=FoldingModel.BOLTZ,
|
105 |
)
|
106 |
with gr.Column():
|
107 |
-
sequence = sequence_input()
|
108 |
|
109 |
predict_btn = gr.Button(
|
110 |
"Predict",
|
@@ -132,10 +159,9 @@ def model_comparison(api_key: str) -> None:
|
|
132 |
"""
|
133 |
|
134 |
with gr.Row():
|
135 |
-
models = gr.
|
136 |
label="Model",
|
137 |
choices=MODEL_CHOICES,
|
138 |
-
multiselect=True,
|
139 |
scale=0,
|
140 |
min_width=300,
|
141 |
value=[FoldingModel.BOLTZ, FoldingModel.CHAI, FoldingModel.PROTENIX],
|
@@ -149,22 +175,46 @@ def model_comparison(api_key: str) -> None:
|
|
149 |
elem_id="compare-models-btn",
|
150 |
variant="primary",
|
151 |
)
|
152 |
-
|
|
|
|
|
|
|
153 |
with gr.Row():
|
154 |
mol_outputs = Molecule3D(
|
155 |
-
label="Protein Structure",
|
156 |
-
reps=MOLECULE_REPS,
|
157 |
-
file_count="multiple",
|
158 |
)
|
|
|
159 |
|
160 |
-
|
|
|
|
|
161 |
|
162 |
predict_btn.click(
|
163 |
fn=predict_comparison,
|
164 |
inputs=[sequence, api_key, models],
|
165 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
)
|
167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
def create_correlation_tab():
|
170 |
gr.Markdown("# Correlation with experimental binding affinity data")
|
@@ -221,7 +271,7 @@ def create_correlation_tab():
|
|
221 |
choices=["Spearman", "Pearson", "R²"],
|
222 |
value="Spearman",
|
223 |
label="Correlation Type",
|
224 |
-
interactive=True
|
225 |
)
|
226 |
with gr.Row():
|
227 |
correlation_ranking_plot = gr.Plot(label="Correlation ranking")
|
@@ -230,17 +280,24 @@ def create_correlation_tab():
|
|
230 |
with gr.Row():
|
231 |
# User can select the columns to display in the correlation plot
|
232 |
correlation_column = gr.Dropdown(
|
233 |
-
label="Score data to display",
|
|
|
|
|
|
|
234 |
)
|
235 |
# Add checkbox for log scale and update plot when either input changes
|
236 |
with gr.Row():
|
237 |
-
log_scale = gr.Checkbox(
|
|
|
|
|
238 |
with gr.Row():
|
239 |
-
score_description = gr.Markdown(
|
|
|
|
|
240 |
correlation_column.change(
|
241 |
fn=lambda x: get_score_description(x),
|
242 |
inputs=correlation_column,
|
243 |
-
outputs=score_description
|
244 |
)
|
245 |
with gr.Column():
|
246 |
correlation_plot = gr.Plot(label="Correlation with binding affinity")
|
@@ -252,10 +309,10 @@ def create_correlation_tab():
|
|
252 |
inputs=[correlation_type],
|
253 |
outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot],
|
254 |
)
|
255 |
-
|
256 |
def update_regression_plot(score, use_log):
|
257 |
return make_regression_plot(spr_data_with_scores, score, use_log)
|
258 |
-
|
259 |
def update_correlation_plot(correlation_type):
|
260 |
logger.info(f"Updating correlation plot for {correlation_type}")
|
261 |
corr_data = compute_correlation_data(spr_data_with_scores, SCORE_COLUMNS)
|
@@ -273,16 +330,15 @@ def create_correlation_tab():
|
|
273 |
inputs=[correlation_type],
|
274 |
outputs=correlation_ranking_plot,
|
275 |
)
|
276 |
-
|
277 |
log_scale.change(
|
278 |
fn=update_regression_plot,
|
279 |
-
inputs=[correlation_column, log_scale],
|
280 |
outputs=correlation_plot,
|
281 |
)
|
282 |
|
283 |
|
284 |
def __main__():
|
285 |
-
|
286 |
theme = gr.themes.Ocean(
|
287 |
primary_hue="blue",
|
288 |
secondary_hue="purple",
|
|
|
4 |
|
5 |
import gradio as gr
|
6 |
import pandas as pd
|
7 |
+
import plotly.graph_objects as go
|
8 |
from folding_studio_data_models import FoldingModel
|
9 |
from gradio_molecule3d import Molecule3D
|
10 |
|
11 |
from folding_studio_demo.correlate import (
|
|
|
12 |
SCORE_COLUMN_NAMES,
|
13 |
+
SCORE_COLUMNS,
|
14 |
+
compute_correlation_data,
|
15 |
fake_predict_and_correlate,
|
16 |
+
get_score_description,
|
17 |
make_regression_plot,
|
|
|
18 |
plot_correlation_ranking,
|
|
|
19 |
)
|
20 |
+
from folding_studio_demo.predict import filter_predictions, predict, predict_comparison
|
21 |
|
22 |
logger = logging.getLogger(__name__)
|
23 |
|
|
|
25 |
MOLECULE_REPS = [
|
26 |
{
|
27 |
"model": 0,
|
28 |
+
# "chain": "",
|
29 |
+
# "resname": "",
|
30 |
"style": "cartoon",
|
31 |
"color": "alphafold",
|
32 |
# "residue_range": "",
|
|
|
37 |
}
|
38 |
]
|
39 |
|
|
|
40 |
|
41 |
MODEL_CHOICES = [
|
42 |
# ("AlphaFold2", FoldingModel.AF2),
|
|
|
47 |
("Protenix", FoldingModel.PROTENIX),
|
48 |
]
|
49 |
|
50 |
+
DEFAULT_SEQ = "MALWMRLLPLLALLALWGPDPAAA"
|
51 |
+
MODEL_EXAMPLES = {
|
52 |
+
FoldingModel.BOLTZ: [
|
53 |
+
["Monomer", f">A|protein\n{DEFAULT_SEQ}"],
|
54 |
+
["Multimer", f">A|protein\n{DEFAULT_SEQ}\n>B|protein\n{DEFAULT_SEQ}"],
|
55 |
+
],
|
56 |
+
FoldingModel.CHAI: [
|
57 |
+
["Monomer", f">protein|name=A\n{DEFAULT_SEQ}"],
|
58 |
+
["Multimer", f">protein|name=A\n{DEFAULT_SEQ}\n>protein|name=B\n{DEFAULT_SEQ}"],
|
59 |
+
],
|
60 |
+
FoldingModel.PROTENIX: [
|
61 |
+
["Monomer", f">A|protein\n{DEFAULT_SEQ}"],
|
62 |
+
["Multimer", f">A|protein\n{DEFAULT_SEQ}\n>B|protein\n{DEFAULT_SEQ}"],
|
63 |
+
],
|
64 |
+
}
|
65 |
+
|
66 |
+
|
67 |
+
def sequence_input(dropdown: gr.Dropdown | None = None) -> gr.Textbox:
|
68 |
"""Sequence input component.
|
69 |
|
70 |
Returns:
|
|
|
72 |
"""
|
73 |
sequence = gr.Textbox(
|
74 |
label="Protein Sequence",
|
|
|
75 |
lines=2,
|
76 |
placeholder="Enter a protein sequence or upload a FASTA file",
|
77 |
)
|
78 |
+
dummy = gr.Textbox(label="Complex type", visible=False)
|
79 |
+
|
80 |
+
examples = gr.Examples(
|
81 |
+
examples=MODEL_EXAMPLES[FoldingModel.BOLTZ],
|
82 |
+
inputs=[dummy, sequence],
|
83 |
+
)
|
84 |
+
if dropdown is not None:
|
85 |
+
dropdown.change(
|
86 |
+
fn=lambda x: gr.Dataset(samples=MODEL_EXAMPLES[x]),
|
87 |
+
inputs=[dropdown],
|
88 |
+
outputs=[examples.dataset],
|
89 |
+
)
|
90 |
file_input = gr.File(
|
91 |
label="Upload a FASTA file",
|
92 |
file_types=[".fasta", ".fa"],
|
|
|
131 |
value=FoldingModel.BOLTZ,
|
132 |
)
|
133 |
with gr.Column():
|
134 |
+
sequence = sequence_input(dropdown)
|
135 |
|
136 |
predict_btn = gr.Button(
|
137 |
"Predict",
|
|
|
159 |
"""
|
160 |
|
161 |
with gr.Row():
|
162 |
+
models = gr.CheckboxGroup(
|
163 |
label="Model",
|
164 |
choices=MODEL_CHOICES,
|
|
|
165 |
scale=0,
|
166 |
min_width=300,
|
167 |
value=[FoldingModel.BOLTZ, FoldingModel.CHAI, FoldingModel.PROTENIX],
|
|
|
175 |
elem_id="compare-models-btn",
|
176 |
variant="primary",
|
177 |
)
|
178 |
+
with gr.Row():
|
179 |
+
chai_predictions = gr.CheckboxGroup(label="Chai", visible=False)
|
180 |
+
protenix_predictions = gr.CheckboxGroup(label="Protenix", visible=False)
|
181 |
+
boltz_predictions = gr.CheckboxGroup(label="Boltz", visible=False)
|
182 |
with gr.Row():
|
183 |
mol_outputs = Molecule3D(
|
184 |
+
label="Protein Structure", reps=MOLECULE_REPS, height=1000
|
|
|
|
|
185 |
)
|
186 |
+
metrics_plot = gr.Plot(label="pLDDT")
|
187 |
|
188 |
+
# Store the initial predictions
|
189 |
+
aligned_paths = gr.State()
|
190 |
+
plddt_fig = gr.State()
|
191 |
|
192 |
predict_btn.click(
|
193 |
fn=predict_comparison,
|
194 |
inputs=[sequence, api_key, models],
|
195 |
+
outputs=[
|
196 |
+
chai_predictions,
|
197 |
+
boltz_predictions,
|
198 |
+
protenix_predictions,
|
199 |
+
aligned_paths,
|
200 |
+
plddt_fig,
|
201 |
+
],
|
202 |
)
|
203 |
|
204 |
+
# Handle checkbox changes
|
205 |
+
for checkbox in [chai_predictions, boltz_predictions, protenix_predictions]:
|
206 |
+
checkbox.change(
|
207 |
+
fn=filter_predictions,
|
208 |
+
inputs=[
|
209 |
+
aligned_paths,
|
210 |
+
plddt_fig,
|
211 |
+
chai_predictions,
|
212 |
+
boltz_predictions,
|
213 |
+
protenix_predictions,
|
214 |
+
],
|
215 |
+
outputs=[mol_outputs, metrics_plot],
|
216 |
+
)
|
217 |
+
|
218 |
|
219 |
def create_correlation_tab():
|
220 |
gr.Markdown("# Correlation with experimental binding affinity data")
|
|
|
271 |
choices=["Spearman", "Pearson", "R²"],
|
272 |
value="Spearman",
|
273 |
label="Correlation Type",
|
274 |
+
interactive=True,
|
275 |
)
|
276 |
with gr.Row():
|
277 |
correlation_ranking_plot = gr.Plot(label="Correlation ranking")
|
|
|
280 |
with gr.Row():
|
281 |
# User can select the columns to display in the correlation plot
|
282 |
correlation_column = gr.Dropdown(
|
283 |
+
label="Score data to display",
|
284 |
+
choices=SCORE_COLUMNS,
|
285 |
+
multiselect=False,
|
286 |
+
value=SCORE_COLUMNS[0],
|
287 |
)
|
288 |
# Add checkbox for log scale and update plot when either input changes
|
289 |
with gr.Row():
|
290 |
+
log_scale = gr.Checkbox(
|
291 |
+
label="Display x-axis on logarithmic scale", value=False
|
292 |
+
)
|
293 |
with gr.Row():
|
294 |
+
score_description = gr.Markdown(
|
295 |
+
get_score_description(correlation_column.value)
|
296 |
+
)
|
297 |
correlation_column.change(
|
298 |
fn=lambda x: get_score_description(x),
|
299 |
inputs=correlation_column,
|
300 |
+
outputs=score_description,
|
301 |
)
|
302 |
with gr.Column():
|
303 |
correlation_plot = gr.Plot(label="Correlation with binding affinity")
|
|
|
309 |
inputs=[correlation_type],
|
310 |
outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot],
|
311 |
)
|
312 |
+
|
313 |
def update_regression_plot(score, use_log):
|
314 |
return make_regression_plot(spr_data_with_scores, score, use_log)
|
315 |
+
|
316 |
def update_correlation_plot(correlation_type):
|
317 |
logger.info(f"Updating correlation plot for {correlation_type}")
|
318 |
corr_data = compute_correlation_data(spr_data_with_scores, SCORE_COLUMNS)
|
|
|
330 |
inputs=[correlation_type],
|
331 |
outputs=correlation_ranking_plot,
|
332 |
)
|
333 |
+
|
334 |
log_scale.change(
|
335 |
fn=update_regression_plot,
|
336 |
+
inputs=[correlation_column, log_scale],
|
337 |
outputs=correlation_plot,
|
338 |
)
|
339 |
|
340 |
|
341 |
def __main__():
|
|
|
342 |
theme = gr.themes.Ocean(
|
343 |
primary_hue="blue",
|
344 |
secondary_hue="purple",
|
folding_studio_demo/model_fasta_validators.py
CHANGED
@@ -248,15 +248,15 @@ class ChaiFastaValidator(BaseFastaValidator):
|
|
248 |
)
|
249 |
seen_names.add(name)
|
250 |
# validate sequence format
|
251 |
-
sequence = str(record.seq).strip()
|
252 |
-
if (
|
253 |
-
|
254 |
-
|
255 |
-
):
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
|
261 |
return True, None
|
262 |
|
|
|
248 |
)
|
249 |
seen_names.add(name)
|
250 |
# validate sequence format
|
251 |
+
# sequence = str(record.seq).strip()
|
252 |
+
# if (
|
253 |
+
# entity_type in {EntityType.PEPTIDE, EntityType.PROTEIN}
|
254 |
+
# and not get_entity_type(sequence) == entity_type
|
255 |
+
# ):
|
256 |
+
# return (
|
257 |
+
# False,
|
258 |
+
# f"CHAI Validation Error: Sequence type mismatch. Expected '{entity_type}' but found '{get_entity_type(sequence)}'",
|
259 |
+
# )
|
260 |
|
261 |
return True, None
|
262 |
|
folding_studio_demo/models.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Models for the Folding Studio API."""
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Any
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
import numpy as np
|
10 |
+
from folding_studio.client import Client
|
11 |
+
from folding_studio.query import Query
|
12 |
+
from folding_studio.query.boltz import BoltzQuery
|
13 |
+
from folding_studio.query.chai import ChaiQuery
|
14 |
+
from folding_studio.query.protenix import ProtenixQuery
|
15 |
+
|
16 |
+
from folding_studio_demo.model_fasta_validators import (
|
17 |
+
BaseFastaValidator,
|
18 |
+
BoltzFastaValidator,
|
19 |
+
ChaiFastaValidator,
|
20 |
+
ProtenixFastaValidator,
|
21 |
+
)
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
class AF3Model:
    """Base class for the models queried through the Folding Studio API.

    A subclass supplies the query class and the FASTA validator of one model
    and implements :meth:`predictions` to locate the files written by
    :meth:`call`.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str,
        query: type[Query],
        validator: BaseFastaValidator,
    ):
        """Store the credentials and the model-specific helpers.

        Args:
            api_key (str): Folding Studio API key used to authenticate requests
            model_name (str): Name of the model, used in log messages
            query (type[Query]): Query class whose ``from_file`` builds the request
            validator (BaseFastaValidator): Validator/formatter for the input FASTA
        """
        self.api_key = api_key
        self.model_name = model_name
        self.query = query
        self.validator = validator

    def call(
        self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
    ) -> None:
        """Predict protein structure from amino acid sequence using AF3 model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
            format_fasta (bool): When the FASTA file fails validation, rewrite it
                in place with the validator's formatting instead of raising

        Raises:
            gr.Error: If the FASTA file is invalid and ``format_fasta`` is False
            KeyError: If the ``FOLDING_PROJECT_CODE`` environment variable is unset
        """
        # Validate FASTA format before calling
        is_valid, error_msg = self.check_file_description(seq_file)
        if not is_valid:
            if not format_fasta:
                logger.error(error_msg)
                raise gr.Error(error_msg)
            logger.info("Invalid FASTA file format, forcing formatting...")
            # NOTE(review): the rewritten file is not validated again before
            # being sent -- confirm transform_fasta always yields a valid file.
            self.format_fasta(seq_file)

        # Create a client using API key
        logger.info("Authenticating client with API key")
        client = Client.from_api_key(api_key=self.api_key)

        # Define query
        query: Query = self.query.from_file(path=seq_file, query_name="gradio")
        query.save_parameters(output_dir)

        logger.info("Payload: %s", query.payload)

        # Send a request
        logger.info("Sending %s request to Folding Studio API", self.model_name)
        response = client.send_request(
            query, project_code=os.environ["FOLDING_PROJECT_CODE"]
        )

        # Access confidence data
        logger.info("Confidence data: %s", response.confidence_data)

        response.download_results(output_dir=output_dir, force=True, unzip=True)
        logger.info("Results downloaded to %s", output_dir)

    def format_fasta(self, seq_file: Path | str) -> None:
        """Rewrite a FASTA file in place using the validator's formatting.

        Args:
            seq_file (Path | str): Path to FASTA file
        """
        # Compute the new content first: the file is read by the validator
        # before it is overwritten.
        formatted_fasta = self.validator.transform_fasta(seq_file)
        Path(seq_file).write_text(formatted_fasta)

    def predictions(
        self, output_dir: Path
    ) -> list[Path] | dict[int, dict[str, Any]]:
        """Get the predictions found in an output directory.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            list[Path] | dict[int, dict[str, Any]]: Predictions found; the
                concrete container depends on the subclass and is empty when
                nothing was found
        """
        raise NotImplementedError()

    def has_prediction(self, output_dir: Path) -> bool:
        """Check if prediction exists in output directory."""
        return len(self.predictions(output_dir)) > 0

    def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
        """Check if the file description is correct.

        Args:
            seq_file (Path | str): Path to FASTA file

        Returns:
            tuple[bool, str | None]: Tuple containing a boolean indicating if the format is correct and an error message if not
        """
        is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
        # The message is only meaningful on failure; always report None on success.
        return (True, None) if is_valid else (False, error_msg)
|
116 |
+
|
117 |
+
|
118 |
+
class ChaiModel(AF3Model):
    """Chai model served by the Folding Studio API.

    ``call`` is inherited unchanged from :class:`AF3Model`.
    """

    def __init__(self, api_key: str):
        """Configure the base model with the Chai query class and validator.

        Args:
            api_key (str): Folding Studio API key
        """
        super().__init__(api_key, "Chai", ChaiQuery, ChaiFastaValidator())

    @staticmethod
    def _index_by_model(paths) -> dict[int, Path]:
        """Map the integer following ``model_idx_`` in each file stem to its path."""
        return {int(path.stem.split("model_idx_")[1]): path for path in paths}

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the Chai predictions found in an output directory.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            dict[int, dict[str, Any]]: For each model index that has both a
                structure and a scores file, ``prediction_path`` (the ``.cif``
                file) and ``metrics`` (the loaded ``.npz`` scores); empty when
                no structure file is found
        """
        # Locate the results directory through any one structure file.
        first_cif = next(output_dir.rglob("pred.model_idx_[0-9].cif"), None)
        if first_cif is None:
            return {}
        results_dir = first_cif.parent

        cif_files = self._index_by_model(results_dir.glob("pred.model_idx_*.cif"))
        npz_files = self._index_by_model(results_dir.glob("scores.model_idx_*.npz"))

        # Keep only the indices for which both files exist, in ascending order.
        common_indices = sorted(cif_files.keys() & npz_files.keys())

        return {
            idx: {"prediction_path": cif_files[idx], "metrics": np.load(npz_files[idx])}
            for idx in common_indices
        }
|
158 |
+
|
159 |
+
|
160 |
+
class ProtenixModel(AF3Model):
    """Protenix model served by the Folding Studio API.

    ``call`` is inherited unchanged from :class:`AF3Model`.
    """

    def __init__(self, api_key: str):
        """Configure the base model with the Protenix query class and validator.

        Args:
            api_key (str): Folding Studio API key
        """
        super().__init__(api_key, "Protenix", ProtenixQuery, ProtenixFastaValidator())

    def predictions(self, output_dir: Path) -> list[Path]:
        """Get the paths to the Protenix predictions.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            list[Path]: Paths of the ``*_model_<digit>.cif`` files found under
                ``output_dir``, sorted so the order does not depend on the
                filesystem
        """
        # NOTE(review): this returns a list of paths, not a dict keyed by model
        # index -- confirm callers handle that.
        return sorted(output_dir.rglob("*_model_[0-9].cif"))
|
179 |
+
|
180 |
+
|
181 |
+
class BoltzModel(AF3Model):
    """Boltz model served by the Folding Studio API.

    ``call`` is inherited unchanged from :class:`AF3Model`.
    """

    def __init__(self, api_key: str):
        """Configure the base model with the Boltz query class and validator.

        Args:
            api_key (str): Folding Studio API key
        """
        super().__init__(api_key, "Boltz", BoltzQuery, BoltzFastaValidator())

    @staticmethod
    def _plddt_file(cif_path: Path) -> Path:
        """Return the pLDDT ``.npz`` file that belongs to a structure file.

        Prefers ``plddt_<cif stem>.npz`` next to the structure so each model
        index gets its own scores. Falls back to the first ``plddt_*.npz`` in
        the directory when that file does not exist.

        Raises:
            IndexError: If the directory contains no ``plddt_*.npz`` file
        """
        exact = cif_path.parent / f"plddt_{cif_path.stem}.npz"
        if exact.is_file():
            return exact
        return sorted(cif_path.parent.glob("plddt_*.npz"))[0]

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the Boltz predictions found in an output directory.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            dict[int, dict[str, Any]]: For each ``*_model_<digit>.cif`` file,
                keyed by that digit, ``prediction_path`` (the ``.cif`` file)
                and ``metrics`` (the loaded pLDDT ``.npz`` file); empty when no
                structure file is found
        """
        prediction_paths = list(output_dir.rglob("*_model_[0-9].cif"))
        return {
            # The glob guarantees the stem ends with a single digit.
            int(cif_path.stem[-1]): {
                "prediction_path": cif_path,
                "metrics": np.load(self._plddt_file(cif_path)),
            }
            for cif_path in prediction_paths
        }
|
folding_studio_demo/predict.py
CHANGED
@@ -2,29 +2,17 @@
|
|
2 |
|
3 |
import hashlib
|
4 |
import logging
|
5 |
-
import os
|
6 |
from io import StringIO
|
7 |
from pathlib import Path
|
8 |
-
from typing import Any
|
9 |
|
10 |
import gradio as gr
|
11 |
import numpy as np
|
12 |
import plotly.graph_objects as go
|
13 |
from Bio import SeqIO
|
14 |
from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
|
15 |
-
from folding_studio.client import Client
|
16 |
-
from folding_studio.query import Query
|
17 |
-
from folding_studio.query.boltz import BoltzQuery
|
18 |
-
from folding_studio.query.chai import ChaiQuery
|
19 |
-
from folding_studio.query.protenix import ProtenixQuery
|
20 |
from folding_studio_data_models import FoldingModel
|
21 |
|
22 |
-
from folding_studio_demo.
|
23 |
-
BaseFastaValidator,
|
24 |
-
BoltzFastaValidator,
|
25 |
-
ChaiFastaValidator,
|
26 |
-
ProtenixFastaValidator,
|
27 |
-
)
|
28 |
|
29 |
logger = logging.getLogger(__name__)
|
30 |
|
@@ -34,6 +22,48 @@ SEQUENCE_DIR.mkdir(parents=True, exist_ok=True)
|
|
34 |
OUTPUT_DIR = Path("output")
|
35 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
|
39 |
"""Convert a .cif file to .pdb format using Biopython.
|
@@ -52,29 +82,46 @@ def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
|
|
52 |
io.save(pdb_path)
|
53 |
|
54 |
|
55 |
-
def
|
|
|
|
|
|
|
|
|
56 |
"""Create a plot of metrics."""
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
)
|
66 |
-
for i, plddt_val in enumerate(plddt_vals)
|
67 |
-
]
|
68 |
-
|
69 |
plddt_fig = go.Figure(data=plddt_traces)
|
70 |
plddt_fig.update_layout(
|
71 |
title="pLDDT",
|
72 |
-
xaxis_title="Residue
|
73 |
yaxis_title="pLDDT",
|
74 |
height=500,
|
75 |
template="simple_white",
|
76 |
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
|
77 |
)
|
|
|
78 |
return plddt_fig
|
79 |
|
80 |
|
@@ -103,178 +150,12 @@ def _write_fasta_file(
|
|
103 |
return seq_id, seq_file
|
104 |
|
105 |
|
106 |
-
class AF3Model:
|
107 |
-
def __init__(
|
108 |
-
self, api_key: str, model_name: str, query: Query, validator: BaseFastaValidator
|
109 |
-
):
|
110 |
-
self.api_key = api_key
|
111 |
-
self.model_name = model_name
|
112 |
-
self.query = query
|
113 |
-
self.validator = validator
|
114 |
-
|
115 |
-
def call(self, seq_file: Path | str, output_dir: Path) -> None:
|
116 |
-
"""Predict protein structure from amino acid sequence using AF3 model.
|
117 |
-
|
118 |
-
Args:
|
119 |
-
seq_file (Path | str): Path to FASTA file containing amino acid sequence
|
120 |
-
output_dir (Path): Path to output directory
|
121 |
-
"""
|
122 |
-
# Validate FASTA format before calling
|
123 |
-
is_valid, error_msg = self.check_file_description(seq_file)
|
124 |
-
if not is_valid:
|
125 |
-
logger.error(error_msg)
|
126 |
-
raise gr.Error(error_msg)
|
127 |
-
|
128 |
-
# Create a client using API key
|
129 |
-
logger.info("Authenticating client with API key")
|
130 |
-
client = Client.from_api_key(api_key=self.api_key)
|
131 |
-
|
132 |
-
# Define query
|
133 |
-
query: Query = self.query.from_file(path=seq_file, query_name="gradio")
|
134 |
-
query.save_parameters(output_dir)
|
135 |
-
|
136 |
-
logger.info("Payload: %s", query.payload)
|
137 |
-
|
138 |
-
# Send a request
|
139 |
-
logger.info(f"Sending {self.model_name} request to Folding Studio API")
|
140 |
-
response = client.send_request(
|
141 |
-
query, project_code=os.environ["FOLDING_PROJECT_CODE"]
|
142 |
-
)
|
143 |
-
|
144 |
-
# Access confidence data
|
145 |
-
logger.info("Confidence data: %s", response.confidence_data)
|
146 |
-
|
147 |
-
response.download_results(output_dir=output_dir, force=True, unzip=True)
|
148 |
-
logger.info("Results downloaded to %s", output_dir)
|
149 |
-
|
150 |
-
def format_fasta(self, sequence: str) -> str:
|
151 |
-
"""Format sequence to FASTA format."""
|
152 |
-
return f">{self.model_name}\n{sequence}"
|
153 |
-
|
154 |
-
def predictions(self, output_dir: Path) -> list[Path]:
|
155 |
-
"""Get the path to the prediction."""
|
156 |
-
raise NotImplementedError("Not implemented")
|
157 |
-
|
158 |
-
def has_prediction(self, output_dir: Path) -> bool:
|
159 |
-
"""Check if prediction exists in output directory."""
|
160 |
-
return len(self.predictions(output_dir)) > 0
|
161 |
-
|
162 |
-
def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
|
163 |
-
"""Check if the file description is correct.
|
164 |
-
|
165 |
-
Args:
|
166 |
-
seq_file (Path | str): Path to FASTA file
|
167 |
-
|
168 |
-
Returns:
|
169 |
-
tuple[bool, str | None]: Tuple containing a boolean indicating if the format is correct and an error message if not
|
170 |
-
"""
|
171 |
-
|
172 |
-
is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
|
173 |
-
if not is_valid:
|
174 |
-
return False, error_msg
|
175 |
-
|
176 |
-
return True, None
|
177 |
-
|
178 |
-
|
179 |
-
class ChaiModel(AF3Model):
|
180 |
-
def __init__(self, api_key: str):
|
181 |
-
super().__init__(api_key, "Chai", ChaiQuery, ChaiFastaValidator())
|
182 |
-
|
183 |
-
def call(self, seq_file: Path | str, output_dir: Path) -> None:
|
184 |
-
"""Predict protein structure from amino acid sequence using Chai model.
|
185 |
-
|
186 |
-
Args:
|
187 |
-
seq_file (Path | str): Path to FASTA file containing amino acid sequence
|
188 |
-
output_dir (Path): Path to output directory
|
189 |
-
"""
|
190 |
-
super().call(seq_file, output_dir)
|
191 |
-
|
192 |
-
def _get_chai_paired_files(self, directory: Path) -> list[tuple[Path, Path]]:
|
193 |
-
"""Get pairs of .cif and .npz files with matching model indices.
|
194 |
-
|
195 |
-
Args:
|
196 |
-
directory (Path): Directory containing the prediction files
|
197 |
-
|
198 |
-
Returns:
|
199 |
-
list[tuple[Path, Path]]: List of tuples containing (cif_path, npz_path) pairs
|
200 |
-
"""
|
201 |
-
# Get all cif files and extract their indices
|
202 |
-
|
203 |
-
def predictions(self, output_dir: Path) -> dict[Path, dict[str, Any]]:
|
204 |
-
"""Get the path to the prediction."""
|
205 |
-
prediction = next(output_dir.rglob("pred.model_idx_[0-9].cif"), None)
|
206 |
-
if prediction is None:
|
207 |
-
return {}
|
208 |
-
|
209 |
-
cif_files = {
|
210 |
-
int(f.stem.split("model_idx_")[1]): f
|
211 |
-
for f in prediction.parent.glob("pred.model_idx_*.cif")
|
212 |
-
}
|
213 |
-
|
214 |
-
# Get all npz files and extract their indices
|
215 |
-
npz_files = {
|
216 |
-
int(f.stem.split("model_idx_")[1]): f
|
217 |
-
for f in prediction.parent.glob("scores.model_idx_*.npz")
|
218 |
-
}
|
219 |
-
|
220 |
-
# Find common indices and create pairs
|
221 |
-
common_indices = sorted(set(cif_files.keys()) & set(npz_files.keys()))
|
222 |
-
|
223 |
-
return {
|
224 |
-
idx: {"prediction_path": cif_files[idx], "metrics": np.load(npz_files[idx])}
|
225 |
-
for idx in common_indices
|
226 |
-
}
|
227 |
-
|
228 |
-
|
229 |
-
class ProtenixModel(AF3Model):
|
230 |
-
def __init__(self, api_key: str):
|
231 |
-
super().__init__(api_key, "Protenix", ProtenixQuery, ProtenixFastaValidator())
|
232 |
-
|
233 |
-
def call(self, seq_file: Path | str, output_dir: Path) -> None:
|
234 |
-
"""Predict protein structure from amino acid sequence using Protenix model.
|
235 |
-
|
236 |
-
Args:
|
237 |
-
seq_file (Path | str): Path to FASTA file containing amino acid sequence
|
238 |
-
output_dir (Path): Path to output directory
|
239 |
-
"""
|
240 |
-
super().call(seq_file, output_dir)
|
241 |
-
|
242 |
-
def predictions(self, output_dir: Path) -> list[Path]:
|
243 |
-
"""Get the path to the prediction."""
|
244 |
-
return list(output_dir.rglob("*_model_[0-9].cif"))
|
245 |
-
|
246 |
-
|
247 |
-
class BoltzModel(AF3Model):
|
248 |
-
def __init__(self, api_key: str):
|
249 |
-
super().__init__(api_key, "Boltz", BoltzQuery, BoltzFastaValidator())
|
250 |
-
|
251 |
-
def call(self, seq_file: Path | str, output_dir: Path) -> None:
|
252 |
-
"""Predict protein structure from amino acid sequence using Boltz model.
|
253 |
-
|
254 |
-
Args:
|
255 |
-
seq_file (Path | str): Path to FASTA file containing amino acid sequence
|
256 |
-
output_dir (Path): Path to output directory
|
257 |
-
"""
|
258 |
-
|
259 |
-
super().call(seq_file, output_dir)
|
260 |
-
|
261 |
-
def predictions(self, output_dir: Path) -> list[Path]:
|
262 |
-
"""Get the path to the prediction."""
|
263 |
-
prediction_paths = list(output_dir.rglob("*_model_[0-9].cif"))
|
264 |
-
return {
|
265 |
-
int(cif_path.stem[-1]): {
|
266 |
-
"prediction_path": cif_path,
|
267 |
-
"metrics": np.load(list(cif_path.parent.glob("plddt_*.npz"))[0]),
|
268 |
-
}
|
269 |
-
for cif_path in prediction_paths
|
270 |
-
}
|
271 |
-
|
272 |
-
|
273 |
def extract_plddt_from_cif(cif_path):
|
274 |
structure = MMCIFParser().get_structure("structure", cif_path)
|
275 |
|
276 |
-
#
|
277 |
plddt_values = []
|
|
|
278 |
|
279 |
# Iterate through all atoms
|
280 |
for model in structure:
|
@@ -285,17 +166,27 @@ def extract_plddt_from_cif(cif_path):
|
|
285 |
# The B-factor contains the pLDDT value
|
286 |
plddt = residue["CA"].get_bfactor()
|
287 |
plddt_values.append(plddt)
|
|
|
|
|
288 |
|
289 |
-
return plddt_values
|
290 |
|
291 |
|
292 |
-
def predict(
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
"""Predict protein structure from amino acid sequence using Boltz model.
|
294 |
|
295 |
Args:
|
296 |
sequence (str): Amino acid sequence to predict structure for
|
297 |
api_key (str): Folding API key
|
298 |
model (FoldingModel): Folding model to use
|
|
|
|
|
299 |
|
300 |
Returns:
|
301 |
tuple[str, str]: Tuple containing the path to the PDB file and the pLDDT plot
|
@@ -303,6 +194,7 @@ def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str,
|
|
303 |
if not api_key:
|
304 |
raise gr.Error("Missing API key, please enter a valid API key")
|
305 |
|
|
|
306 |
# Set up unique output directory based on sequence hash
|
307 |
seq_id, seq_file = _write_fasta_file(sequence)
|
308 |
output_dir = OUTPUT_DIR / seq_id / model_type
|
@@ -319,15 +211,16 @@ def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str,
|
|
319 |
|
320 |
# Check if prediction already exists
|
321 |
if not model.has_prediction(output_dir):
|
322 |
-
|
|
|
323 |
logger.info(f"Predicting {seq_id}")
|
324 |
-
model.call(seq_file=seq_file, output_dir=output_dir)
|
325 |
logger.info("Prediction done. Output directory: %s", output_dir)
|
326 |
else:
|
|
|
327 |
logger.info("Prediction already exists. Output directory: %s", output_dir)
|
328 |
|
329 |
-
|
330 |
-
|
331 |
# Convert output CIF to PDB
|
332 |
if not model.has_prediction(output_dir):
|
333 |
raise gr.Error("No prediction found")
|
@@ -335,23 +228,34 @@ def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str,
|
|
335 |
predictions = model.predictions(output_dir)
|
336 |
pdb_paths = []
|
337 |
model_plddt_vals = []
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
|
|
|
|
342 |
)
|
|
|
|
|
343 |
|
344 |
converted_pdb_path = str(
|
345 |
output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
|
346 |
)
|
347 |
convert_cif_to_pdb(str(cif_path), str(converted_pdb_path))
|
348 |
-
plddt_vals = extract_plddt_from_cif(cif_path)
|
349 |
pdb_paths.append(converted_pdb_path)
|
350 |
model_plddt_vals.append(plddt_vals)
|
351 |
-
|
352 |
-
|
|
|
|
|
|
|
|
|
|
|
353 |
)
|
354 |
-
|
|
|
|
|
355 |
|
356 |
|
357 |
def align_structures(pdb_paths: list[str]) -> list[str]:
|
@@ -397,28 +301,148 @@ def align_structures(pdb_paths: list[str]) -> list[str]:
|
|
397 |
return aligned_paths
|
398 |
|
399 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
def predict_comparison(
|
401 |
-
sequence: str, api_key: str, model_types: list[FoldingModel]
|
402 |
-
) -> tuple[
|
403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
|
405 |
Args:
|
406 |
sequence (str): Amino acid sequence to predict structure for
|
407 |
api_key (str): Folding API key
|
408 |
-
|
|
|
409 |
|
410 |
Returns:
|
411 |
-
tuple
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
412 |
"""
|
413 |
if not api_key:
|
414 |
raise gr.Error("Missing API key, please enter a valid API key")
|
415 |
|
416 |
# Set up unique output directory based on sequence hash
|
417 |
pdb_paths = []
|
418 |
-
|
419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
pdb_paths += model_pdb_paths
|
|
|
|
|
421 |
|
|
|
422 |
aligned_paths = align_structures(pdb_paths)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
423 |
|
424 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
import hashlib
|
4 |
import logging
|
|
|
5 |
from io import StringIO
|
6 |
from pathlib import Path
|
|
|
7 |
|
8 |
import gradio as gr
|
9 |
import numpy as np
|
10 |
import plotly.graph_objects as go
|
11 |
from Bio import SeqIO
|
12 |
from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
|
|
|
|
|
|
|
|
|
|
|
13 |
from folding_studio_data_models import FoldingModel
|
14 |
|
15 |
+
from folding_studio_demo.models import BoltzModel, ChaiModel, ProtenixModel
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
logger = logging.getLogger(__name__)
|
18 |
|
|
|
22 |
OUTPUT_DIR = Path("output")
|
23 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
24 |
|
25 |
+
# Lookup table from three-letter residue names to one-letter amino acid codes.
THREE_TO_ONE_LETTER = {
    "ALA": "A",
    "ARG": "R",
    "ASN": "N",
    "ASP": "D",
    "CYS": "C",
    "GLN": "Q",
    "GLU": "E",
    "GLY": "G",
    "HIS": "H",
    "ILE": "I",
    "LEU": "L",
    "LYS": "K",
    "MET": "M",
    "PHE": "F",
    "PRO": "P",
    "SER": "S",
    "THR": "T",
    "TRP": "W",
    "TYR": "Y",
    "VAL": "V",
    # Less common and ambiguity codes
    "SEC": "U",
    "PYL": "O",
    "ASX": "B",
    "GLX": "Z",
    "XAA": "X",
    "XLE": "J",
    "UNK": "X",
}


def convert_to_one_letter(resname: str) -> str:
    """Convert three-letter amino acid code to one-letter code.

    The lookup ignores letter case and surrounding whitespace, so "ala" and
    " ALA " resolve the same way as "ALA".

    Args:
        resname (str): Three-letter amino acid code

    Returns:
        str: One-letter amino acid code, or "X" when the code is not listed in
            THREE_TO_ONE_LETTER
    """
    # Normalise before the lookup: the table keys are upper-case and unpadded.
    return THREE_TO_ONE_LETTER.get(resname.strip().upper(), "X")
|
66 |
+
|
67 |
|
68 |
def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
|
69 |
"""Convert a .cif file to .pdb format using Biopython.
|
|
|
82 |
io.save(pdb_path)
|
83 |
|
84 |
|
85 |
+
def create_plddt_figure(
    plddt_vals: list[list[float]],
    model_name: str,
    residue_codes: list[list[str]] = None,
) -> go.Figure:
    """Create a plot of per-residue pLDDT values, one trace per prediction.

    Args:
        plddt_vals (list[list[float]]): One list of pLDDT values per prediction
        model_name (str): Prefix of the trace names; trace ``i`` is named
            ``"{model_name} {i}"``
        residue_codes (list[list[str]]): Optional residue codes, parallel to
            ``plddt_vals``. Residues without a matching code get an index-only
            hover label.

    Returns:
        go.Figure: Figure holding one scatter trace per entry of ``plddt_vals``
    """
    plddt_traces = []
    for i, plddt_val in enumerate(plddt_vals):
        # Codes for this prediction, or an empty list when none were supplied.
        has_codes = bool(residue_codes) and i < len(residue_codes)
        codes = (residue_codes[i] or []) if has_codes else []

        # Build exactly one hover label per pLDDT value. Falling back to the
        # index-only label keeps the text aligned with the points even when
        # the code list is shorter than the pLDDT list.
        hover_text = [
            f"<i>pLDDT</i>: {plddt:.2f}<br><i>Residue:</i> {codes[idx]} {idx}"
            if idx < len(codes)
            else f"<i>pLDDT</i>: {plddt:.2f}<br><i>Residue index:</i> {idx}"
            for idx, plddt in enumerate(plddt_val)
        ]

        plddt_traces.append(
            go.Scatter(
                x=np.arange(len(plddt_val)),
                y=plddt_val,
                # Show only the prepared text; <extra></extra> hides the
                # secondary hover box.
                hovertemplate="%{text}<extra></extra>",
                text=hover_text,
                name=f"{model_name} {i}",
                visible=True,
            )
        )

    plddt_fig = go.Figure(data=plddt_traces)
    plddt_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
    )

    return plddt_fig
|
126 |
|
127 |
|
|
|
150 |
return seq_id, seq_file
|
151 |
|
152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
def extract_plddt_from_cif(cif_path):
|
154 |
structure = MMCIFParser().get_structure("structure", cif_path)
|
155 |
|
156 |
+
# Lists to store pLDDT values and residue codes
|
157 |
plddt_values = []
|
158 |
+
residue_codes = []
|
159 |
|
160 |
# Iterate through all atoms
|
161 |
for model in structure:
|
|
|
166 |
# The B-factor contains the pLDDT value
|
167 |
plddt = residue["CA"].get_bfactor()
|
168 |
plddt_values.append(plddt)
|
169 |
+
# Get residue code and convert to one-letter code
|
170 |
+
residue_codes.append(convert_to_one_letter(residue.get_resname()))
|
171 |
|
172 |
+
return plddt_values, residue_codes
|
173 |
|
174 |
|
175 |
+
def predict(
|
176 |
+
sequence: str,
|
177 |
+
api_key: str,
|
178 |
+
model_type: FoldingModel,
|
179 |
+
format_fasta: bool = False,
|
180 |
+
progress=gr.Progress(),
|
181 |
+
) -> tuple[str, str]:
|
182 |
"""Predict protein structure from amino acid sequence using Boltz model.
|
183 |
|
184 |
Args:
|
185 |
sequence (str): Amino acid sequence to predict structure for
|
186 |
api_key (str): Folding API key
|
187 |
model (FoldingModel): Folding model to use
|
188 |
+
format_fasta (bool): Whether to format the FASTA file
|
189 |
+
progress (gr.Progress): Gradio progress tracker
|
190 |
|
191 |
Returns:
|
192 |
tuple[str, str]: Tuple containing the path to the PDB file and the pLDDT plot
|
|
|
194 |
if not api_key:
|
195 |
raise gr.Error("Missing API key, please enter a valid API key")
|
196 |
|
197 |
+
progress(0, desc="Setting up prediction...")
|
198 |
# Set up unique output directory based on sequence hash
|
199 |
seq_id, seq_file = _write_fasta_file(sequence)
|
200 |
output_dir = OUTPUT_DIR / seq_id / model_type
|
|
|
211 |
|
212 |
# Check if prediction already exists
|
213 |
if not model.has_prediction(output_dir):
|
214 |
+
progress(0.2, desc="Running prediction...")
|
215 |
+
# Run prediction
|
216 |
logger.info(f"Predicting {seq_id}")
|
217 |
+
model.call(seq_file=seq_file, output_dir=output_dir, format_fasta=format_fasta)
|
218 |
logger.info("Prediction done. Output directory: %s", output_dir)
|
219 |
else:
|
220 |
+
progress(0.2, desc="Using existing prediction...")
|
221 |
logger.info("Prediction already exists. Output directory: %s", output_dir)
|
222 |
|
223 |
+
progress(0.4, desc="Processing results...")
|
|
|
224 |
# Convert output CIF to PDB
|
225 |
if not model.has_prediction(output_dir):
|
226 |
raise gr.Error("No prediction found")
|
|
|
228 |
predictions = model.predictions(output_dir)
|
229 |
pdb_paths = []
|
230 |
model_plddt_vals = []
|
231 |
+
model_residue_codes = []
|
232 |
+
|
233 |
+
total_predictions = len(predictions)
|
234 |
+
for i, (model_idx, prediction) in enumerate(predictions.items()):
|
235 |
+
progress(
|
236 |
+
0.4 + (0.4 * i / total_predictions), desc=f"Converting model {model_idx}..."
|
237 |
)
|
238 |
+
cif_path = prediction["prediction_path"]
|
239 |
+
logger.info(f"CIF file: {cif_path}")
|
240 |
|
241 |
converted_pdb_path = str(
|
242 |
output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
|
243 |
)
|
244 |
convert_cif_to_pdb(str(cif_path), str(converted_pdb_path))
|
245 |
+
plddt_vals, residue_codes = extract_plddt_from_cif(cif_path)
|
246 |
pdb_paths.append(converted_pdb_path)
|
247 |
model_plddt_vals.append(plddt_vals)
|
248 |
+
model_residue_codes.append(residue_codes)
|
249 |
+
|
250 |
+
progress(0.8, desc="Generating plots...")
|
251 |
+
plddt_fig = create_plddt_figure(
|
252 |
+
plddt_vals=model_plddt_vals,
|
253 |
+
model_name=model.model_name,
|
254 |
+
residue_codes=model_residue_codes,
|
255 |
)
|
256 |
+
|
257 |
+
progress(1.0, desc="Done!")
|
258 |
+
return pdb_paths, plddt_fig
|
259 |
|
260 |
|
261 |
def align_structures(pdb_paths: list[str]) -> list[str]:
|
|
|
301 |
return aligned_paths
|
302 |
|
303 |
|
304 |
+
def filter_predictions(
    aligned_paths: list[str],
    plddt_fig: go.Figure,
    chai_selected: list[int],
    boltz_selected: list[int],
    protenix_selected: list[int],
) -> tuple[list[str], go.Figure]:
    """Filter predictions based on selected checkboxes.

    Traces of ``plddt_fig`` and entries of ``aligned_paths`` are matched by
    position. A trace is kept when its name has the form
    ``"<model name> <index>"`` and the index is selected for that model.

    Args:
        aligned_paths (list[str]): List of aligned PDB paths
        plddt_fig (go.Figure): Original pLDDT plot
        chai_selected (list[int]): Selected Chai model indices
        boltz_selected (list[int]): Selected Boltz model indices
        protenix_selected (list[int]): Selected Protenix model indices

    Returns:
        tuple[list[str], go.Figure]: Filtered PDB paths and updated pLDDT plot
    """
    # Selected indices keyed by the model name that prefixes each trace name.
    # A missing (None) selection is treated as "nothing selected".
    selected_by_model = {
        "Chai": set(chai_selected or ()),
        "Boltz": set(boltz_selected or ()),
        "Protenix": set(protenix_selected or ()),
    }

    def should_show_trace(trace_name) -> bool:
        """Tell whether the trace called ``trace_name`` is currently selected."""
        parts = (trace_name or "").split()
        if len(parts) < 2:
            # Unnamed trace or no index token: cannot belong to a selection.
            return False
        try:
            # The index is the last token, so a model name made of several
            # words does not shift it.
            model_idx = int(parts[-1])
        except ValueError:
            return False
        return model_idx in selected_by_model.get(parts[0], ())

    # Create a new figure with only the selected traces, collecting the
    # structure paths at the same positions.
    filtered_fig = go.Figure()
    visible_paths = []
    # zip stops at the shorter input, so a trace without a matching path is
    # skipped instead of raising IndexError.
    for trace, path in zip(plddt_fig.data, aligned_paths):
        if should_show_trace(trace.name):
            filtered_fig.add_trace(trace)
            visible_paths.append(path)

    # Update layout
    filtered_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
    )

    return visible_paths, filtered_fig
|
360 |
+
|
361 |
+
|
362 |
def _prediction_index(pdb_path: str) -> int:
    """Return the prediction index given by the trailing digits of a PDB file name."""
    stem = Path(pdb_path).stem
    # Keep every trailing digit so that indices of 10 and above are not
    # reduced to their last digit.
    digits = stem[len(stem.rstrip("0123456789")):]
    return int(digits)


def _prediction_checkbox_group(model_predictions: dict, model_type) -> gr.CheckboxGroup:
    """Build the checkbox group listing the prediction indices of one model."""
    indices = model_predictions.get(model_type)
    # The group is shown only for models that were part of this comparison;
    # every available prediction starts selected.
    return gr.CheckboxGroup(
        visible=indices is not None,
        choices=list(indices or []),
        value=list(indices or []),
    )


def predict_comparison(
    sequence: str, api_key: str, model_types: list[FoldingModel], progress=gr.Progress()
) -> tuple[
    gr.CheckboxGroup,
    gr.CheckboxGroup,
    gr.CheckboxGroup,
    list[str],
    go.Figure,
]:
    """Predict protein structure from amino acid sequence using multiple models.

    Args:
        sequence (str): Amino acid sequence to predict structure for
        api_key (str): Folding API key
        model_types (list[FoldingModel]): List of folding models to use
        progress (gr.Progress): Gradio progress tracker

    Returns:
        tuple containing, in this order:
            - gr.CheckboxGroup: Chai predictions checkbox group
            - gr.CheckboxGroup: Boltz predictions checkbox group
            - gr.CheckboxGroup: Protenix predictions checkbox group
            - list[str]: Aligned PDB paths
            - go.Figure: pLDDT plot with the traces of all models

    Raises:
        gr.Error: If no API key is provided
    """
    if not api_key:
        raise gr.Error("Missing API key, please enter a valid API key")

    # Accumulate structures, pLDDT traces and prediction indices over all models
    pdb_paths = []
    plddt_traces = []
    total_models = len(model_types)
    model_predictions = {}

    for i, model_type in enumerate(model_types):
        progress(i / total_models, desc=f"Running {model_type} prediction...")
        model_pdb_paths, model_plddt_traces = predict(
            sequence, api_key, model_type, format_fasta=True
        )
        pdb_paths += model_pdb_paths
        plddt_traces.extend(model_plddt_traces.data)
        # NOTE(review): these indices come from the file names, while trace
        # names number predictions by position; confirm both always agree.
        model_predictions[model_type] = [
            _prediction_index(p) for p in model_pdb_paths
        ]

    progress(0.9, desc="Aligning structures...")
    aligned_paths = align_structures(pdb_paths)
    plddt_fig = go.Figure(data=plddt_traces)
    plddt_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
    )

    progress(1.0, desc="Done!")

    # Create checkbox groups for each model type
    chai_predictions = _prediction_checkbox_group(model_predictions, FoldingModel.CHAI)
    boltz_predictions = _prediction_checkbox_group(model_predictions, FoldingModel.BOLTZ)
    protenix_predictions = _prediction_checkbox_group(
        model_predictions, FoldingModel.PROTENIX
    )

    return (
        chai_predictions,
        boltz_predictions,
        protenix_predictions,
        aligned_paths,
        plddt_fig,
    )
|