jfaustin AchilleSoulieID committed on
Commit
b26b7a0
·
verified ·
1 Parent(s): e967c14

improve model comparison (#10)

Browse files

- improve model comparison (6187a6421d0fce2de6e37381ac15baa56c12ea39)


Co-authored-by: Achille Soulie <[email protected]>

folding_studio_demo/app.py CHANGED
@@ -4,19 +4,20 @@ import logging
4
 
5
  import gradio as gr
6
  import pandas as pd
 
7
  from folding_studio_data_models import FoldingModel
8
  from gradio_molecule3d import Molecule3D
9
 
10
  from folding_studio_demo.correlate import (
11
- SCORE_COLUMNS,
12
  SCORE_COLUMN_NAMES,
 
 
13
  fake_predict_and_correlate,
 
14
  make_regression_plot,
15
- compute_correlation_data,
16
  plot_correlation_ranking,
17
- get_score_description
18
  )
19
- from folding_studio_demo.predict import predict, predict_comparison
20
 
21
  logger = logging.getLogger(__name__)
22
 
@@ -24,8 +25,8 @@ logger = logging.getLogger(__name__)
24
  MOLECULE_REPS = [
25
  {
26
  "model": 0,
27
- "chain": "",
28
- "resname": "",
29
  "style": "cartoon",
30
  "color": "alphafold",
31
  # "residue_range": "",
@@ -36,7 +37,6 @@ MOLECULE_REPS = [
36
  }
37
  ]
38
 
39
- DEFAULT_PROTEIN_SEQ = ">protein description\nMALWMRLLPLLALLALWGPDPAAA"
40
 
41
  MODEL_CHOICES = [
42
  # ("AlphaFold2", FoldingModel.AF2),
@@ -47,8 +47,24 @@ MODEL_CHOICES = [
47
  ("Protenix", FoldingModel.PROTENIX),
48
  ]
49
 
50
-
51
- def sequence_input() -> gr.Textbox:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  """Sequence input component.
53
 
54
  Returns:
@@ -56,10 +72,21 @@ def sequence_input() -> gr.Textbox:
56
  """
57
  sequence = gr.Textbox(
58
  label="Protein Sequence",
59
- value=DEFAULT_PROTEIN_SEQ,
60
  lines=2,
61
  placeholder="Enter a protein sequence or upload a FASTA file",
62
  )
 
 
 
 
 
 
 
 
 
 
 
 
63
  file_input = gr.File(
64
  label="Upload a FASTA file",
65
  file_types=[".fasta", ".fa"],
@@ -104,7 +131,7 @@ def simple_prediction(api_key: str) -> None:
104
  value=FoldingModel.BOLTZ,
105
  )
106
  with gr.Column():
107
- sequence = sequence_input()
108
 
109
  predict_btn = gr.Button(
110
  "Predict",
@@ -132,10 +159,9 @@ def model_comparison(api_key: str) -> None:
132
  """
133
 
134
  with gr.Row():
135
- models = gr.Dropdown(
136
  label="Model",
137
  choices=MODEL_CHOICES,
138
- multiselect=True,
139
  scale=0,
140
  min_width=300,
141
  value=[FoldingModel.BOLTZ, FoldingModel.CHAI, FoldingModel.PROTENIX],
@@ -149,22 +175,46 @@ def model_comparison(api_key: str) -> None:
149
  elem_id="compare-models-btn",
150
  variant="primary",
151
  )
152
-
 
 
 
153
  with gr.Row():
154
  mol_outputs = Molecule3D(
155
- label="Protein Structure",
156
- reps=MOLECULE_REPS,
157
- file_count="multiple",
158
  )
 
159
 
160
- # metrics_plot = gr.Plot(label="pLDDT")
 
 
161
 
162
  predict_btn.click(
163
  fn=predict_comparison,
164
  inputs=[sequence, api_key, models],
165
- outputs=[mol_outputs],
 
 
 
 
 
 
166
  )
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  def create_correlation_tab():
170
  gr.Markdown("# Correlation with experimental binding affinity data")
@@ -221,7 +271,7 @@ def create_correlation_tab():
221
  choices=["Spearman", "Pearson", "R²"],
222
  value="Spearman",
223
  label="Correlation Type",
224
- interactive=True
225
  )
226
  with gr.Row():
227
  correlation_ranking_plot = gr.Plot(label="Correlation ranking")
@@ -230,17 +280,24 @@ def create_correlation_tab():
230
  with gr.Row():
231
  # User can select the columns to display in the correlation plot
232
  correlation_column = gr.Dropdown(
233
- label="Score data to display", choices=SCORE_COLUMNS, multiselect=False, value=SCORE_COLUMNS[0]
 
 
 
234
  )
235
  # Add checkbox for log scale and update plot when either input changes
236
  with gr.Row():
237
- log_scale = gr.Checkbox(label="Display x-axis on logarithmic scale", value=False)
 
 
238
  with gr.Row():
239
- score_description = gr.Markdown(get_score_description(correlation_column.value))
 
 
240
  correlation_column.change(
241
  fn=lambda x: get_score_description(x),
242
  inputs=correlation_column,
243
- outputs=score_description
244
  )
245
  with gr.Column():
246
  correlation_plot = gr.Plot(label="Correlation with binding affinity")
@@ -252,10 +309,10 @@ def create_correlation_tab():
252
  inputs=[correlation_type],
253
  outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot],
254
  )
255
-
256
  def update_regression_plot(score, use_log):
257
  return make_regression_plot(spr_data_with_scores, score, use_log)
258
-
259
  def update_correlation_plot(correlation_type):
260
  logger.info(f"Updating correlation plot for {correlation_type}")
261
  corr_data = compute_correlation_data(spr_data_with_scores, SCORE_COLUMNS)
@@ -273,16 +330,15 @@ def create_correlation_tab():
273
  inputs=[correlation_type],
274
  outputs=correlation_ranking_plot,
275
  )
276
-
277
  log_scale.change(
278
  fn=update_regression_plot,
279
- inputs=[correlation_column, log_scale],
280
  outputs=correlation_plot,
281
  )
282
 
283
 
284
  def __main__():
285
-
286
  theme = gr.themes.Ocean(
287
  primary_hue="blue",
288
  secondary_hue="purple",
 
4
 
5
  import gradio as gr
6
  import pandas as pd
7
+ import plotly.graph_objects as go
8
  from folding_studio_data_models import FoldingModel
9
  from gradio_molecule3d import Molecule3D
10
 
11
  from folding_studio_demo.correlate import (
 
12
  SCORE_COLUMN_NAMES,
13
+ SCORE_COLUMNS,
14
+ compute_correlation_data,
15
  fake_predict_and_correlate,
16
+ get_score_description,
17
  make_regression_plot,
 
18
  plot_correlation_ranking,
 
19
  )
20
+ from folding_studio_demo.predict import filter_predictions, predict, predict_comparison
21
 
22
  logger = logging.getLogger(__name__)
23
 
 
25
  MOLECULE_REPS = [
26
  {
27
  "model": 0,
28
+ # "chain": "",
29
+ # "resname": "",
30
  "style": "cartoon",
31
  "color": "alphafold",
32
  # "residue_range": "",
 
37
  }
38
  ]
39
 
 
40
 
41
  MODEL_CHOICES = [
42
  # ("AlphaFold2", FoldingModel.AF2),
 
47
  ("Protenix", FoldingModel.PROTENIX),
48
  ]
49
 
50
+ DEFAULT_SEQ = "MALWMRLLPLLALLALWGPDPAAA"
51
+ MODEL_EXAMPLES = {
52
+ FoldingModel.BOLTZ: [
53
+ ["Monomer", f">A|protein\n{DEFAULT_SEQ}"],
54
+ ["Multimer", f">A|protein\n{DEFAULT_SEQ}\n>B|protein\n{DEFAULT_SEQ}"],
55
+ ],
56
+ FoldingModel.CHAI: [
57
+ ["Monomer", f">protein|name=A\n{DEFAULT_SEQ}"],
58
+ ["Multimer", f">protein|name=A\n{DEFAULT_SEQ}\n>protein|name=B\n{DEFAULT_SEQ}"],
59
+ ],
60
+ FoldingModel.PROTENIX: [
61
+ ["Monomer", f">A|protein\n{DEFAULT_SEQ}"],
62
+ ["Multimer", f">A|protein\n{DEFAULT_SEQ}\n>B|protein\n{DEFAULT_SEQ}"],
63
+ ],
64
+ }
65
+
66
+
67
+ def sequence_input(dropdown: gr.Dropdown | None = None) -> gr.Textbox:
68
  """Sequence input component.
69
 
70
  Returns:
 
72
  """
73
  sequence = gr.Textbox(
74
  label="Protein Sequence",
 
75
  lines=2,
76
  placeholder="Enter a protein sequence or upload a FASTA file",
77
  )
78
+ dummy = gr.Textbox(label="Complex type", visible=False)
79
+
80
+ examples = gr.Examples(
81
+ examples=MODEL_EXAMPLES[FoldingModel.BOLTZ],
82
+ inputs=[dummy, sequence],
83
+ )
84
+ if dropdown is not None:
85
+ dropdown.change(
86
+ fn=lambda x: gr.Dataset(samples=MODEL_EXAMPLES[x]),
87
+ inputs=[dropdown],
88
+ outputs=[examples.dataset],
89
+ )
90
  file_input = gr.File(
91
  label="Upload a FASTA file",
92
  file_types=[".fasta", ".fa"],
 
131
  value=FoldingModel.BOLTZ,
132
  )
133
  with gr.Column():
134
+ sequence = sequence_input(dropdown)
135
 
136
  predict_btn = gr.Button(
137
  "Predict",
 
159
  """
160
 
161
  with gr.Row():
162
+ models = gr.CheckboxGroup(
163
  label="Model",
164
  choices=MODEL_CHOICES,
 
165
  scale=0,
166
  min_width=300,
167
  value=[FoldingModel.BOLTZ, FoldingModel.CHAI, FoldingModel.PROTENIX],
 
175
  elem_id="compare-models-btn",
176
  variant="primary",
177
  )
178
+ with gr.Row():
179
+ chai_predictions = gr.CheckboxGroup(label="Chai", visible=False)
180
+ protenix_predictions = gr.CheckboxGroup(label="Protenix", visible=False)
181
+ boltz_predictions = gr.CheckboxGroup(label="Boltz", visible=False)
182
  with gr.Row():
183
  mol_outputs = Molecule3D(
184
+ label="Protein Structure", reps=MOLECULE_REPS, height=1000
 
 
185
  )
186
+ metrics_plot = gr.Plot(label="pLDDT")
187
 
188
+ # Store the initial predictions
189
+ aligned_paths = gr.State()
190
+ plddt_fig = gr.State()
191
 
192
  predict_btn.click(
193
  fn=predict_comparison,
194
  inputs=[sequence, api_key, models],
195
+ outputs=[
196
+ chai_predictions,
197
+ boltz_predictions,
198
+ protenix_predictions,
199
+ aligned_paths,
200
+ plddt_fig,
201
+ ],
202
  )
203
 
204
+ # Handle checkbox changes
205
+ for checkbox in [chai_predictions, boltz_predictions, protenix_predictions]:
206
+ checkbox.change(
207
+ fn=filter_predictions,
208
+ inputs=[
209
+ aligned_paths,
210
+ plddt_fig,
211
+ chai_predictions,
212
+ boltz_predictions,
213
+ protenix_predictions,
214
+ ],
215
+ outputs=[mol_outputs, metrics_plot],
216
+ )
217
+
218
 
219
  def create_correlation_tab():
220
  gr.Markdown("# Correlation with experimental binding affinity data")
 
271
  choices=["Spearman", "Pearson", "R²"],
272
  value="Spearman",
273
  label="Correlation Type",
274
+ interactive=True,
275
  )
276
  with gr.Row():
277
  correlation_ranking_plot = gr.Plot(label="Correlation ranking")
 
280
  with gr.Row():
281
  # User can select the columns to display in the correlation plot
282
  correlation_column = gr.Dropdown(
283
+ label="Score data to display",
284
+ choices=SCORE_COLUMNS,
285
+ multiselect=False,
286
+ value=SCORE_COLUMNS[0],
287
  )
288
  # Add checkbox for log scale and update plot when either input changes
289
  with gr.Row():
290
+ log_scale = gr.Checkbox(
291
+ label="Display x-axis on logarithmic scale", value=False
292
+ )
293
  with gr.Row():
294
+ score_description = gr.Markdown(
295
+ get_score_description(correlation_column.value)
296
+ )
297
  correlation_column.change(
298
  fn=lambda x: get_score_description(x),
299
  inputs=correlation_column,
300
+ outputs=score_description,
301
  )
302
  with gr.Column():
303
  correlation_plot = gr.Plot(label="Correlation with binding affinity")
 
309
  inputs=[correlation_type],
310
  outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot],
311
  )
312
+
313
  def update_regression_plot(score, use_log):
314
  return make_regression_plot(spr_data_with_scores, score, use_log)
315
+
316
  def update_correlation_plot(correlation_type):
317
  logger.info(f"Updating correlation plot for {correlation_type}")
318
  corr_data = compute_correlation_data(spr_data_with_scores, SCORE_COLUMNS)
 
330
  inputs=[correlation_type],
331
  outputs=correlation_ranking_plot,
332
  )
333
+
334
  log_scale.change(
335
  fn=update_regression_plot,
336
+ inputs=[correlation_column, log_scale],
337
  outputs=correlation_plot,
338
  )
339
 
340
 
341
  def __main__():
 
342
  theme = gr.themes.Ocean(
343
  primary_hue="blue",
344
  secondary_hue="purple",
folding_studio_demo/model_fasta_validators.py CHANGED
@@ -248,15 +248,15 @@ class ChaiFastaValidator(BaseFastaValidator):
248
  )
249
  seen_names.add(name)
250
  # validate sequence format
251
- sequence = str(record.seq).strip()
252
- if (
253
- entity_type in {EntityType.PEPTIDE, EntityType.PROTEIN}
254
- and not get_entity_type(sequence) == entity_type
255
- ):
256
- return (
257
- False,
258
- f"CHAI Validation Error: Sequence type mismatch. Expected '{entity_type}' but found '{get_entity_type(sequence)}'",
259
- )
260
 
261
  return True, None
262
 
 
248
  )
249
  seen_names.add(name)
250
  # validate sequence format
251
+ # sequence = str(record.seq).strip()
252
+ # if (
253
+ # entity_type in {EntityType.PEPTIDE, EntityType.PROTEIN}
254
+ # and not get_entity_type(sequence) == entity_type
255
+ # ):
256
+ # return (
257
+ # False,
258
+ # f"CHAI Validation Error: Sequence type mismatch. Expected '{entity_type}' but found '{get_entity_type(sequence)}'",
259
+ # )
260
 
261
  return True, None
262
 
folding_studio_demo/models.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Models for the Folding Studio API."""
2
+
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ from folding_studio.client import Client
11
+ from folding_studio.query import Query
12
+ from folding_studio.query.boltz import BoltzQuery
13
+ from folding_studio.query.chai import ChaiQuery
14
+ from folding_studio.query.protenix import ProtenixQuery
15
+
16
+ from folding_studio_demo.model_fasta_validators import (
17
+ BaseFastaValidator,
18
+ BoltzFastaValidator,
19
+ ChaiFastaValidator,
20
+ ProtenixFastaValidator,
21
+ )
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class AF3Model:
27
+ def __init__(
28
+ self, api_key: str, model_name: str, query: Query, validator: BaseFastaValidator
29
+ ):
30
+ self.api_key = api_key
31
+ self.model_name = model_name
32
+ self.query = query
33
+ self.validator = validator
34
+
35
+ def call(
36
+ self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
37
+ ) -> None:
38
+ """Predict protein structure from amino acid sequence using AF3 model.
39
+
40
+ Args:
41
+ seq_file (Path | str): Path to FASTA file containing amino acid sequence
42
+ output_dir (Path): Path to output directory
43
+ format_fasta (bool): Whether to reformat the FASTA file in place when it fails validation
44
+ """
45
+ # Validate FASTA format before calling
46
+ is_valid, error_msg = self.check_file_description(seq_file)
47
+ if format_fasta and not is_valid:
48
+ logger.info("Invalid FASTA file format, forcing formatting...")
49
+ self.format_fasta(seq_file)
50
+ elif not is_valid:
51
+ logger.error(error_msg)
52
+ raise gr.Error(error_msg)
53
+
54
+ # Create a client using API key
55
+ logger.info("Authenticating client with API key")
56
+ client = Client.from_api_key(api_key=self.api_key)
57
+
58
+ # Define query
59
+ query: Query = self.query.from_file(path=seq_file, query_name="gradio")
60
+ query.save_parameters(output_dir)
61
+
62
+ logger.info("Payload: %s", query.payload)
63
+
64
+ # Send a request
65
+ logger.info(f"Sending {self.model_name} request to Folding Studio API")
66
+ response = client.send_request(
67
+ query, project_code=os.environ["FOLDING_PROJECT_CODE"]
68
+ )
69
+
70
+ # Access confidence data
71
+ logger.info("Confidence data: %s", response.confidence_data)
72
+
73
+ response.download_results(output_dir=output_dir, force=True, unzip=True)
74
+ logger.info("Results downloaded to %s", output_dir)
75
+
76
+ def format_fasta(self, seq_file: Path | str) -> None:
77
+ """Format sequence to FASTA format.
78
+
79
+ Args:
80
+ seq_file (Path | str): Path to FASTA file
81
+ """
82
+ formatted_fasta = self.validator.transform_fasta(seq_file)
83
+ with open(seq_file, "w") as f:
84
+ f.write(formatted_fasta)
85
+
86
+ def predictions(self, output_dir: Path) -> list[Path]:
87
+ """Get the path to the prediction.
88
+
89
+ Args:
90
+ output_dir (Path): Path to output directory
91
+
92
+ Returns:
93
+ list[Path]: List of paths to predictions
94
+ """
95
+ raise NotImplementedError()
96
+
97
+ def has_prediction(self, output_dir: Path) -> bool:
98
+ """Check if prediction exists in output directory."""
99
+ return len(self.predictions(output_dir)) > 0
100
+
101
+ def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
102
+ """Check if the file description is correct.
103
+
104
+ Args:
105
+ seq_file (Path | str): Path to FASTA file
106
+
107
+ Returns:
108
+ tuple[bool, str | None]: Tuple containing a boolean indicating if the format is correct and an error message if not
109
+ """
110
+
111
+ is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
112
+ if not is_valid:
113
+ return False, error_msg
114
+
115
+ return True, None
116
+
117
+
118
+ class ChaiModel(AF3Model):
119
+ def __init__(self, api_key: str):
120
+ super().__init__(api_key, "Chai", ChaiQuery, ChaiFastaValidator())
121
+
122
+ def call(
123
+ self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
124
+ ) -> None:
125
+ """Predict protein structure from amino acid sequence using Chai model.
126
+
127
+ Args:
128
+ seq_file (Path | str): Path to FASTA file containing amino acid sequence
129
+ output_dir (Path): Path to output directory
130
+ format_fasta (bool): Whether to format the FASTA file
131
+ """
132
+ super().call(seq_file, output_dir, format_fasta)
133
+
134
+ def predictions(self, output_dir: Path) -> dict[Path, dict[str, Any]]:
135
+ """Get the path to the prediction."""
136
+ prediction = next(output_dir.rglob("pred.model_idx_[0-9].cif"), None)
137
+ if prediction is None:
138
+ return {}
139
+
140
+ cif_files = {
141
+ int(f.stem.split("model_idx_")[1]): f
142
+ for f in prediction.parent.glob("pred.model_idx_*.cif")
143
+ }
144
+
145
+ # Get all npz files and extract their indices
146
+ npz_files = {
147
+ int(f.stem.split("model_idx_")[1]): f
148
+ for f in prediction.parent.glob("scores.model_idx_*.npz")
149
+ }
150
+
151
+ # Find common indices and create pairs
152
+ common_indices = sorted(set(cif_files.keys()) & set(npz_files.keys()))
153
+
154
+ return {
155
+ idx: {"prediction_path": cif_files[idx], "metrics": np.load(npz_files[idx])}
156
+ for idx in common_indices
157
+ }
158
+
159
+
160
+ class ProtenixModel(AF3Model):
161
+ def __init__(self, api_key: str):
162
+ super().__init__(api_key, "Protenix", ProtenixQuery, ProtenixFastaValidator())
163
+
164
+ def call(
165
+ self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
166
+ ) -> None:
167
+ """Predict protein structure from amino acid sequence using Protenix model.
168
+
169
+ Args:
170
+ seq_file (Path | str): Path to FASTA file containing amino acid sequence
171
+ output_dir (Path): Path to output directory
172
+ format_fasta (bool): Whether to format the FASTA file
173
+ """
174
+ super().call(seq_file, output_dir, format_fasta)
175
+
176
+ def predictions(self, output_dir: Path) -> list[Path]:
177
+ """Get the path to the prediction."""
178
+ return list(output_dir.rglob("*_model_[0-9].cif"))
179
+
180
+
181
+ class BoltzModel(AF3Model):
182
+ def __init__(self, api_key: str):
183
+ super().__init__(api_key, "Boltz", BoltzQuery, BoltzFastaValidator())
184
+
185
+ def call(
186
+ self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
187
+ ) -> None:
188
+ """Predict protein structure from amino acid sequence using Boltz model.
189
+
190
+ Args:
191
+ seq_file (Path | str): Path to FASTA file containing amino acid sequence
192
+ output_dir (Path): Path to output directory
193
+ format_fasta (bool): Whether to format the FASTA file
194
+ """
195
+
196
+ super().call(seq_file, output_dir, format_fasta)
197
+
198
+ def predictions(self, output_dir: Path) -> list[Path]:
199
+ """Get the path to the prediction."""
200
+ prediction_paths = list(output_dir.rglob("*_model_[0-9].cif"))
201
+ return {
202
+ int(cif_path.stem[-1]): {
203
+ "prediction_path": cif_path,
204
+ "metrics": np.load(list(cif_path.parent.glob("plddt_*.npz"))[0]),
205
+ }
206
+ for cif_path in prediction_paths
207
+ }
folding_studio_demo/predict.py CHANGED
@@ -2,29 +2,17 @@
2
 
3
  import hashlib
4
  import logging
5
- import os
6
  from io import StringIO
7
  from pathlib import Path
8
- from typing import Any
9
 
10
  import gradio as gr
11
  import numpy as np
12
  import plotly.graph_objects as go
13
  from Bio import SeqIO
14
  from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
15
- from folding_studio.client import Client
16
- from folding_studio.query import Query
17
- from folding_studio.query.boltz import BoltzQuery
18
- from folding_studio.query.chai import ChaiQuery
19
- from folding_studio.query.protenix import ProtenixQuery
20
  from folding_studio_data_models import FoldingModel
21
 
22
- from folding_studio_demo.model_fasta_validators import (
23
- BaseFastaValidator,
24
- BoltzFastaValidator,
25
- ChaiFastaValidator,
26
- ProtenixFastaValidator,
27
- )
28
 
29
  logger = logging.getLogger(__name__)
30
 
@@ -34,6 +22,48 @@ SEQUENCE_DIR.mkdir(parents=True, exist_ok=True)
34
  OUTPUT_DIR = Path("output")
35
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
39
  """Convert a .cif file to .pdb format using Biopython.
@@ -52,29 +82,46 @@ def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
52
  io.save(pdb_path)
53
 
54
 
55
- def add_plddt_plot(plddt_vals: list[list[float]], model_name: str) -> go.Figure:
 
 
 
 
56
  """Create a plot of metrics."""
57
- visible = True
58
- plddt_traces = [
59
- go.Scatter(
60
- x=np.arange(len(plddt_val)),
61
- y=plddt_val,
62
- hovertemplate="<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}<br>",
63
- name=f"{model_name} {i}",
64
- visible=visible,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  )
66
- for i, plddt_val in enumerate(plddt_vals)
67
- ]
68
-
69
  plddt_fig = go.Figure(data=plddt_traces)
70
  plddt_fig.update_layout(
71
  title="pLDDT",
72
- xaxis_title="Residue index",
73
  yaxis_title="pLDDT",
74
  height=500,
75
  template="simple_white",
76
  legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
77
  )
 
78
  return plddt_fig
79
 
80
 
@@ -103,178 +150,12 @@ def _write_fasta_file(
103
  return seq_id, seq_file
104
 
105
 
106
- class AF3Model:
107
- def __init__(
108
- self, api_key: str, model_name: str, query: Query, validator: BaseFastaValidator
109
- ):
110
- self.api_key = api_key
111
- self.model_name = model_name
112
- self.query = query
113
- self.validator = validator
114
-
115
- def call(self, seq_file: Path | str, output_dir: Path) -> None:
116
- """Predict protein structure from amino acid sequence using AF3 model.
117
-
118
- Args:
119
- seq_file (Path | str): Path to FASTA file containing amino acid sequence
120
- output_dir (Path): Path to output directory
121
- """
122
- # Validate FASTA format before calling
123
- is_valid, error_msg = self.check_file_description(seq_file)
124
- if not is_valid:
125
- logger.error(error_msg)
126
- raise gr.Error(error_msg)
127
-
128
- # Create a client using API key
129
- logger.info("Authenticating client with API key")
130
- client = Client.from_api_key(api_key=self.api_key)
131
-
132
- # Define query
133
- query: Query = self.query.from_file(path=seq_file, query_name="gradio")
134
- query.save_parameters(output_dir)
135
-
136
- logger.info("Payload: %s", query.payload)
137
-
138
- # Send a request
139
- logger.info(f"Sending {self.model_name} request to Folding Studio API")
140
- response = client.send_request(
141
- query, project_code=os.environ["FOLDING_PROJECT_CODE"]
142
- )
143
-
144
- # Access confidence data
145
- logger.info("Confidence data: %s", response.confidence_data)
146
-
147
- response.download_results(output_dir=output_dir, force=True, unzip=True)
148
- logger.info("Results downloaded to %s", output_dir)
149
-
150
- def format_fasta(self, sequence: str) -> str:
151
- """Format sequence to FASTA format."""
152
- return f">{self.model_name}\n{sequence}"
153
-
154
- def predictions(self, output_dir: Path) -> list[Path]:
155
- """Get the path to the prediction."""
156
- raise NotImplementedError("Not implemented")
157
-
158
- def has_prediction(self, output_dir: Path) -> bool:
159
- """Check if prediction exists in output directory."""
160
- return len(self.predictions(output_dir)) > 0
161
-
162
- def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
163
- """Check if the file description is correct.
164
-
165
- Args:
166
- seq_file (Path | str): Path to FASTA file
167
-
168
- Returns:
169
- tuple[bool, str | None]: Tuple containing a boolean indicating if the format is correct and an error message if not
170
- """
171
-
172
- is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
173
- if not is_valid:
174
- return False, error_msg
175
-
176
- return True, None
177
-
178
-
179
- class ChaiModel(AF3Model):
180
- def __init__(self, api_key: str):
181
- super().__init__(api_key, "Chai", ChaiQuery, ChaiFastaValidator())
182
-
183
- def call(self, seq_file: Path | str, output_dir: Path) -> None:
184
- """Predict protein structure from amino acid sequence using Chai model.
185
-
186
- Args:
187
- seq_file (Path | str): Path to FASTA file containing amino acid sequence
188
- output_dir (Path): Path to output directory
189
- """
190
- super().call(seq_file, output_dir)
191
-
192
- def _get_chai_paired_files(self, directory: Path) -> list[tuple[Path, Path]]:
193
- """Get pairs of .cif and .npz files with matching model indices.
194
-
195
- Args:
196
- directory (Path): Directory containing the prediction files
197
-
198
- Returns:
199
- list[tuple[Path, Path]]: List of tuples containing (cif_path, npz_path) pairs
200
- """
201
- # Get all cif files and extract their indices
202
-
203
- def predictions(self, output_dir: Path) -> dict[Path, dict[str, Any]]:
204
- """Get the path to the prediction."""
205
- prediction = next(output_dir.rglob("pred.model_idx_[0-9].cif"), None)
206
- if prediction is None:
207
- return {}
208
-
209
- cif_files = {
210
- int(f.stem.split("model_idx_")[1]): f
211
- for f in prediction.parent.glob("pred.model_idx_*.cif")
212
- }
213
-
214
- # Get all npz files and extract their indices
215
- npz_files = {
216
- int(f.stem.split("model_idx_")[1]): f
217
- for f in prediction.parent.glob("scores.model_idx_*.npz")
218
- }
219
-
220
- # Find common indices and create pairs
221
- common_indices = sorted(set(cif_files.keys()) & set(npz_files.keys()))
222
-
223
- return {
224
- idx: {"prediction_path": cif_files[idx], "metrics": np.load(npz_files[idx])}
225
- for idx in common_indices
226
- }
227
-
228
-
229
- class ProtenixModel(AF3Model):
230
- def __init__(self, api_key: str):
231
- super().__init__(api_key, "Protenix", ProtenixQuery, ProtenixFastaValidator())
232
-
233
- def call(self, seq_file: Path | str, output_dir: Path) -> None:
234
- """Predict protein structure from amino acid sequence using Protenix model.
235
-
236
- Args:
237
- seq_file (Path | str): Path to FASTA file containing amino acid sequence
238
- output_dir (Path): Path to output directory
239
- """
240
- super().call(seq_file, output_dir)
241
-
242
- def predictions(self, output_dir: Path) -> list[Path]:
243
- """Get the path to the prediction."""
244
- return list(output_dir.rglob("*_model_[0-9].cif"))
245
-
246
-
247
- class BoltzModel(AF3Model):
248
- def __init__(self, api_key: str):
249
- super().__init__(api_key, "Boltz", BoltzQuery, BoltzFastaValidator())
250
-
251
- def call(self, seq_file: Path | str, output_dir: Path) -> None:
252
- """Predict protein structure from amino acid sequence using Boltz model.
253
-
254
- Args:
255
- seq_file (Path | str): Path to FASTA file containing amino acid sequence
256
- output_dir (Path): Path to output directory
257
- """
258
-
259
- super().call(seq_file, output_dir)
260
-
261
- def predictions(self, output_dir: Path) -> list[Path]:
262
- """Get the path to the prediction."""
263
- prediction_paths = list(output_dir.rglob("*_model_[0-9].cif"))
264
- return {
265
- int(cif_path.stem[-1]): {
266
- "prediction_path": cif_path,
267
- "metrics": np.load(list(cif_path.parent.glob("plddt_*.npz"))[0]),
268
- }
269
- for cif_path in prediction_paths
270
- }
271
-
272
-
273
  def extract_plddt_from_cif(cif_path):
274
  structure = MMCIFParser().get_structure("structure", cif_path)
275
 
276
- # Dictionary to store pLDDT values per residue
277
  plddt_values = []
 
278
 
279
  # Iterate through all atoms
280
  for model in structure:
@@ -285,17 +166,27 @@ def extract_plddt_from_cif(cif_path):
285
  # The B-factor contains the pLDDT value
286
  plddt = residue["CA"].get_bfactor()
287
  plddt_values.append(plddt)
 
 
288
 
289
- return plddt_values
290
 
291
 
292
- def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str, str]:
 
 
 
 
 
 
293
  """Predict protein structure from amino acid sequence using Boltz model.
294
 
295
  Args:
296
  sequence (str): Amino acid sequence to predict structure for
297
  api_key (str): Folding API key
298
  model (FoldingModel): Folding model to use
 
 
299
 
300
  Returns:
301
  tuple[str, str]: Tuple containing the path to the PDB file and the pLDDT plot
@@ -303,6 +194,7 @@ def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str,
303
  if not api_key:
304
  raise gr.Error("Missing API key, please enter a valid API key")
305
 
 
306
  # Set up unique output directory based on sequence hash
307
  seq_id, seq_file = _write_fasta_file(sequence)
308
  output_dir = OUTPUT_DIR / seq_id / model_type
@@ -319,15 +211,16 @@ def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str,
319
 
320
  # Check if prediction already exists
321
  if not model.has_prediction(output_dir):
322
- # Run Boltz prediction
 
323
  logger.info(f"Predicting {seq_id}")
324
- model.call(seq_file=seq_file, output_dir=output_dir)
325
  logger.info("Prediction done. Output directory: %s", output_dir)
326
  else:
 
327
  logger.info("Prediction already exists. Output directory: %s", output_dir)
328
 
329
- # output_dir = Path("boltz_results") # debug
330
-
331
  # Convert output CIF to PDB
332
  if not model.has_prediction(output_dir):
333
  raise gr.Error("No prediction found")
@@ -335,23 +228,34 @@ def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str,
335
  predictions = model.predictions(output_dir)
336
  pdb_paths = []
337
  model_plddt_vals = []
338
- for model_idx, prediction in predictions.items():
339
- cif_path = prediction["prediction_path"]
340
- logger.info(
341
- "CIF file: %s",
 
 
342
  )
 
 
343
 
344
  converted_pdb_path = str(
345
  output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
346
  )
347
  convert_cif_to_pdb(str(cif_path), str(converted_pdb_path))
348
- plddt_vals = extract_plddt_from_cif(cif_path)
349
  pdb_paths.append(converted_pdb_path)
350
  model_plddt_vals.append(plddt_vals)
351
- plddt_plot = add_plddt_plot(
352
- plddt_vals=model_plddt_vals, model_name=model.model_name
 
 
 
 
 
353
  )
354
- return pdb_paths, plddt_plot
 
 
355
 
356
 
357
  def align_structures(pdb_paths: list[str]) -> list[str]:
@@ -397,28 +301,148 @@ def align_structures(pdb_paths: list[str]) -> list[str]:
397
  return aligned_paths
398
 
399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  def predict_comparison(
401
- sequence: str, api_key: str, model_types: list[FoldingModel]
402
- ) -> tuple[str, str]:
403
- """Predict protein structure from amino acid sequence using Boltz model.
 
 
 
 
 
 
 
 
 
404
 
405
  Args:
406
  sequence (str): Amino acid sequence to predict structure for
407
  api_key (str): Folding API key
408
- model (FoldingModel): Folding model to use
 
409
 
410
  Returns:
411
- tuple[str, str]: Tuple containing the path to the PDB file and the pLDDT plot
 
 
 
 
 
 
 
 
412
  """
413
  if not api_key:
414
  raise gr.Error("Missing API key, please enter a valid API key")
415
 
416
  # Set up unique output directory based on sequence hash
417
  pdb_paths = []
418
- for model_type in model_types:
419
- model_pdb_paths, _ = predict(sequence, api_key, model_type)
 
 
 
 
 
 
 
420
  pdb_paths += model_pdb_paths
 
 
421
 
 
422
  aligned_paths = align_structures(pdb_paths)
 
 
 
 
 
 
 
 
 
423
 
424
- return aligned_paths
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import hashlib
4
  import logging
 
5
  from io import StringIO
6
  from pathlib import Path
 
7
 
8
  import gradio as gr
9
  import numpy as np
10
  import plotly.graph_objects as go
11
  from Bio import SeqIO
12
  from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
 
 
 
 
 
13
  from folding_studio_data_models import FoldingModel
14
 
15
+ from folding_studio_demo.models import BoltzModel, ChaiModel, ProtenixModel
 
 
 
 
 
16
 
17
  logger = logging.getLogger(__name__)
18
 
 
22
  OUTPUT_DIR = Path("output")
23
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
24
 
25
# Three-letter residue names mapped to one-letter codes: the 20 standard amino
# acids, selenocysteine (SEC) and pyrrolysine (PYL), the ambiguity codes
# ASX/GLX/XLE, and the "unknown" placeholders XAA/UNK.
THREE_TO_ONE_LETTER = {
    "ALA": "A",
    "ARG": "R",
    "ASN": "N",
    "ASP": "D",
    "CYS": "C",
    "GLN": "Q",
    "GLU": "E",
    "GLY": "G",
    "HIS": "H",
    "ILE": "I",
    "LEU": "L",
    "LYS": "K",
    "MET": "M",
    "PHE": "F",
    "PRO": "P",
    "SER": "S",
    "THR": "T",
    "TRP": "W",
    "TYR": "Y",
    "VAL": "V",
    "SEC": "U",
    "PYL": "O",
    "ASX": "B",
    "GLX": "Z",
    "XAA": "X",
    "XLE": "J",
    "UNK": "X",
}


def convert_to_one_letter(resname: str) -> str:
    """Convert three-letter amino acid code to one-letter code.

    The lookup ignores letter case and surrounding whitespace, so ``"ala"``
    and ``" ALA "`` both resolve to ``"A"``.

    Args:
        resname (str): Three-letter amino acid code

    Returns:
        str: One-letter amino acid code, or ``"X"`` when the name is not in
            ``THREE_TO_ONE_LETTER``
    """
    # Normalise before the lookup: the table keys are upper-case and unpadded.
    return THREE_TO_ONE_LETTER.get(resname.strip().upper(), "X")
66
+
67
 
68
  def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
69
  """Convert a .cif file to .pdb format using Biopython.
 
82
  io.save(pdb_path)
83
 
84
 
85
def create_plddt_figure(
    plddt_vals: list[list[float]],
    model_name: str,
    residue_codes: list[list[str]] = None,
) -> go.Figure:
    """Build a pLDDT plot with one scatter trace per prediction.

    Args:
        plddt_vals (list[list[float]]): pLDDT values, one inner list per prediction
        model_name (str): Prefix of every trace name; trace ``i`` is named
            ``"<model_name> <i>"``
        residue_codes (list[list[str]]): Optional residue codes, parallel to
            ``plddt_vals``. When given, the hover text shows the residue code
            next to its position; otherwise only the position is shown.

    Returns:
        go.Figure: Figure containing one trace per entry of ``plddt_vals``
    """

    def _build_trace(trace_idx: int, scores: list[float]) -> go.Scatter:
        # Residue codes are used only if a list exists for this prediction.
        has_codes = bool(residue_codes) and trace_idx < len(residue_codes)
        if has_codes:
            labels = [
                f"<i>pLDDT</i>: {score:.2f}<br><i>Residue:</i> {letter} {pos}"
                for pos, (score, letter) in enumerate(
                    zip(scores, residue_codes[trace_idx])
                )
            ]
        else:
            labels = [
                f"<i>pLDDT</i>: {score:.2f}<br><i>Residue index:</i> {pos}"
                for pos, score in enumerate(scores)
            ]

        return go.Scatter(
            x=np.arange(len(scores)),
            y=scores,
            hovertemplate="%{text}<extra></extra>",
            text=labels,
            name=f"{model_name} {trace_idx}",
            visible=True,
        )

    figure = go.Figure(
        data=[_build_trace(idx, scores) for idx, scores in enumerate(plddt_vals)]
    )

    layout_options = {
        "title": "pLDDT",
        "xaxis_title": "Residue",
        "yaxis_title": "pLDDT",
        "height": 500,
        "template": "simple_white",
        "legend": {"yanchor": "bottom", "y": 0.01, "xanchor": "left", "x": 0.99},
    }
    figure.update_layout(**layout_options)

    return figure
126
 
127
 
 
150
  return seq_id, seq_file
151
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def extract_plddt_from_cif(cif_path):
154
  structure = MMCIFParser().get_structure("structure", cif_path)
155
 
156
+ # Lists to store pLDDT values and residue codes
157
  plddt_values = []
158
+ residue_codes = []
159
 
160
  # Iterate through all atoms
161
  for model in structure:
 
166
  # The B-factor contains the pLDDT value
167
  plddt = residue["CA"].get_bfactor()
168
  plddt_values.append(plddt)
169
+ # Get residue code and convert to one-letter code
170
+ residue_codes.append(convert_to_one_letter(residue.get_resname()))
171
 
172
+ return plddt_values, residue_codes
173
 
174
 
175
+ def predict(
176
+ sequence: str,
177
+ api_key: str,
178
+ model_type: FoldingModel,
179
+ format_fasta: bool = False,
180
+ progress=gr.Progress(),
181
+ ) -> tuple[str, str]:
182
  """Predict protein structure from amino acid sequence using Boltz model.
183
 
184
  Args:
185
  sequence (str): Amino acid sequence to predict structure for
186
  api_key (str): Folding API key
187
  model (FoldingModel): Folding model to use
188
+ format_fasta (bool): Whether to format the FASTA file
189
+ progress (gr.Progress): Gradio progress tracker
190
 
191
  Returns:
192
  tuple[str, str]: Tuple containing the path to the PDB file and the pLDDT plot
 
194
  if not api_key:
195
  raise gr.Error("Missing API key, please enter a valid API key")
196
 
197
+ progress(0, desc="Setting up prediction...")
198
  # Set up unique output directory based on sequence hash
199
  seq_id, seq_file = _write_fasta_file(sequence)
200
  output_dir = OUTPUT_DIR / seq_id / model_type
 
211
 
212
  # Check if prediction already exists
213
  if not model.has_prediction(output_dir):
214
+ progress(0.2, desc="Running prediction...")
215
+ # Run prediction
216
  logger.info(f"Predicting {seq_id}")
217
+ model.call(seq_file=seq_file, output_dir=output_dir, format_fasta=format_fasta)
218
  logger.info("Prediction done. Output directory: %s", output_dir)
219
  else:
220
+ progress(0.2, desc="Using existing prediction...")
221
  logger.info("Prediction already exists. Output directory: %s", output_dir)
222
 
223
+ progress(0.4, desc="Processing results...")
 
224
  # Convert output CIF to PDB
225
  if not model.has_prediction(output_dir):
226
  raise gr.Error("No prediction found")
 
228
  predictions = model.predictions(output_dir)
229
  pdb_paths = []
230
  model_plddt_vals = []
231
+ model_residue_codes = []
232
+
233
+ total_predictions = len(predictions)
234
+ for i, (model_idx, prediction) in enumerate(predictions.items()):
235
+ progress(
236
+ 0.4 + (0.4 * i / total_predictions), desc=f"Converting model {model_idx}..."
237
  )
238
+ cif_path = prediction["prediction_path"]
239
+ logger.info(f"CIF file: {cif_path}")
240
 
241
  converted_pdb_path = str(
242
  output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
243
  )
244
  convert_cif_to_pdb(str(cif_path), str(converted_pdb_path))
245
+ plddt_vals, residue_codes = extract_plddt_from_cif(cif_path)
246
  pdb_paths.append(converted_pdb_path)
247
  model_plddt_vals.append(plddt_vals)
248
+ model_residue_codes.append(residue_codes)
249
+
250
+ progress(0.8, desc="Generating plots...")
251
+ plddt_fig = create_plddt_figure(
252
+ plddt_vals=model_plddt_vals,
253
+ model_name=model.model_name,
254
+ residue_codes=model_residue_codes,
255
  )
256
+
257
+ progress(1.0, desc="Done!")
258
+ return pdb_paths, plddt_fig
259
 
260
 
261
  def align_structures(pdb_paths: list[str]) -> list[str]:
 
301
  return aligned_paths
302
 
303
 
304
def filter_predictions(
    aligned_paths: list[str],
    plddt_fig: go.Figure,
    chai_selected: list[int],
    boltz_selected: list[int],
    protenix_selected: list[int],
) -> tuple[list[str], go.Figure]:
    """Filter predictions based on selected checkboxes.

    Traces of ``plddt_fig`` are paired with ``aligned_paths`` by position. Each
    trace name is read as ``"<model name> <index>"``; the trace (and the path
    at the same position) is kept when the model name is ``"Chai"``,
    ``"Boltz"`` or ``"Protenix"`` and the index is in that model's selection.

    Args:
        aligned_paths (list[str]): List of aligned PDB paths
        plddt_fig (go.Figure): Original pLDDT plot
        chai_selected (list[int]): Selected Chai model indices
        boltz_selected (list[int]): Selected Boltz model indices
        protenix_selected (list[int]): Selected Protenix model indices

    Returns:
        tuple[list[str], go.Figure]: Filtered PDB paths and updated pLDDT plot
    """
    # A selection may arrive as None (nothing ticked); treat it as empty
    # instead of failing on the membership test.
    selected_by_model = {
        "Chai": chai_selected or [],
        "Boltz": boltz_selected or [],
        "Protenix": protenix_selected or [],
    }

    # Helper function to check if a trace should be visible
    def should_show_trace(trace_name: str) -> bool:
        try:
            name_parts = trace_name.split()
            model_name, model_idx = name_parts[0], int(name_parts[1])
        except (AttributeError, IndexError, ValueError):
            # Unnamed or unparsable trace: it cannot match any checkbox.
            return False
        return model_idx in selected_by_model.get(model_name, [])

    # Create a new figure with only selected traces
    filtered_fig = go.Figure()
    visible_paths = []

    # Filter traces and paths
    for i, trace in enumerate(plddt_fig.data):
        if should_show_trace(trace.name):
            filtered_fig.add_trace(trace)
            visible_paths.append(aligned_paths[i])

    # Update layout
    filtered_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
    )

    return visible_paths, filtered_fig
360
+
361
+
def _prediction_index(pdb_path: str) -> int:
    """Return the integer formed by the trailing digits of a PDB file stem."""
    stem = Path(pdb_path).stem
    # Take every trailing digit, not just the last character, so that an
    # index of 10 or more is not truncated to its final digit.
    trailing_digits = stem[len(stem.rstrip("0123456789")):]
    return int(trailing_digits)


def predict_comparison(
    sequence: str, api_key: str, model_types: list[FoldingModel], progress=gr.Progress()
) -> tuple[
    gr.CheckboxGroup,
    gr.CheckboxGroup,
    gr.CheckboxGroup,
    list[str],
    go.Figure,
]:
    """Predict protein structure from amino acid sequence using multiple models.

    Runs ``predict`` once per requested model, aligns all resulting structures
    and merges the per-model pLDDT traces into a single figure.

    Args:
        sequence (str): Amino acid sequence to predict structure for
        api_key (str): Folding API key
        model_types (list[FoldingModel]): List of folding models to use
        progress (gr.Progress): Gradio progress tracker

    Returns:
        tuple containing, in this order:
            - gr.CheckboxGroup: Chai predictions checkbox group
            - gr.CheckboxGroup: Boltz predictions checkbox group
            - gr.CheckboxGroup: Protenix predictions checkbox group
            - list[str]: Aligned PDB paths
            - go.Figure: pLDDT plot with the traces of every model

    Raises:
        gr.Error: If ``api_key`` is empty
    """
    if not api_key:
        raise gr.Error("Missing API key, please enter a valid API key")

    # Collect PDB paths and pLDDT traces across all requested models
    pdb_paths = []
    plddt_traces = []
    total_models = len(model_types)
    model_predictions = {}

    for i, model_type in enumerate(model_types):
        progress(i / total_models, desc=f"Running {model_type} prediction...")
        model_pdb_paths, model_plddt_traces = predict(
            sequence, api_key, model_type, format_fasta=True
        )
        pdb_paths += model_pdb_paths
        plddt_traces += model_plddt_traces.data
        # The prediction index is recovered from the end of each PDB file stem.
        model_predictions[model_type] = [
            _prediction_index(p) for p in model_pdb_paths
        ]

    progress(0.9, desc="Aligning structures...")
    aligned_paths = align_structures(pdb_paths)
    plddt_fig = go.Figure(data=plddt_traces)
    plddt_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
    )

    progress(1.0, desc="Done!")

    # Create checkbox groups for each model type; a group is hidden when its
    # model was not part of this comparison.
    chai_predictions = gr.CheckboxGroup(
        visible=model_predictions.get(FoldingModel.CHAI) is not None,
        choices=model_predictions.get(FoldingModel.CHAI, []),
        value=model_predictions.get(FoldingModel.CHAI, []),
    )
    boltz_predictions = gr.CheckboxGroup(
        visible=model_predictions.get(FoldingModel.BOLTZ) is not None,
        choices=model_predictions.get(FoldingModel.BOLTZ, []),
        value=model_predictions.get(FoldingModel.BOLTZ, []),
    )
    protenix_predictions = gr.CheckboxGroup(
        visible=model_predictions.get(FoldingModel.PROTENIX) is not None,
        choices=model_predictions.get(FoldingModel.PROTENIX, []),
        value=model_predictions.get(FoldingModel.PROTENIX, []),
    )

    return (
        chai_predictions,
        boltz_predictions,
        protenix_predictions,
        aligned_paths,
        plddt_fig,
    )