AchilleSoulieID committed on
Commit
3886d2a
·
1 Parent(s): 250a4a2

big update

Browse files
folding-studio/folding_studio/api_call/predict/simple_predict.py CHANGED
@@ -23,6 +23,7 @@ def single_job_prediction(
23
  fasta_file: Path,
24
  parameters: AF2Parameters | OpenFoldParameters | None = None,
25
  project_code: str | None = None,
 
26
  *,
27
  ignore_cache: bool = False,
28
  **kwargs,
@@ -74,7 +75,7 @@ def single_job_prediction(
74
  if parameters.templates_masks_file
75
  else None,
76
  )
77
- _ = custom_files.upload()
78
 
79
  params = parameters.model_dump(mode="json")
80
  pdb_ids, _ = partition_template_pdb_from_file(
@@ -107,7 +108,7 @@ def single_job_prediction(
107
  response = requests.post(
108
  url,
109
  data=params,
110
- headers=get_auth_headers(),
111
  files=[("fasta_file", fasta_file.open("rb"))],
112
  params={"project_code": project_code},
113
  timeout=REQUEST_TIMEOUT,
 
23
  fasta_file: Path,
24
  parameters: AF2Parameters | OpenFoldParameters | None = None,
25
  project_code: str | None = None,
26
+ api_key: str | None = None,
27
  *,
28
  ignore_cache: bool = False,
29
  **kwargs,
 
75
  if parameters.templates_masks_file
76
  else None,
77
  )
78
+ _ = custom_files.upload(api_key=api_key)
79
 
80
  params = parameters.model_dump(mode="json")
81
  pdb_ids, _ = partition_template_pdb_from_file(
 
108
  response = requests.post(
109
  url,
110
  data=params,
111
+ headers=get_auth_headers(api_key),
112
  files=[("fasta_file", fasta_file.open("rb"))],
113
  params={"project_code": project_code},
114
  timeout=REQUEST_TIMEOUT,
folding-studio/folding_studio/commands/experiment.py CHANGED
@@ -35,6 +35,7 @@ def _download_file_from_signed_url(
35
  output: Path,
36
  force: bool,
37
  unzip: bool = False,
 
38
  ) -> None:
39
  """Download a zip file from an experiment id.
40
 
@@ -71,7 +72,7 @@ def _download_file_from_signed_url(
71
  )
72
  raise typer.Exit(code=1)
73
 
74
- headers = get_auth_headers()
75
  url = API_URL + endpoint
76
 
77
  response = requests.get(
@@ -104,9 +105,10 @@ def _download_file_from_signed_url(
104
  @app.command()
105
  def status(
106
  exp_id: Annotated[str, experiment_ID_argument],
 
107
  ):
108
  """Get an experiment status."""
109
- headers = get_auth_headers()
110
  url = API_URL + "getExperimentStatus"
111
  response = requests.get(
112
  url,
@@ -224,6 +226,7 @@ def features(
224
  @app.command()
225
  def results(
226
  exp_id: Annotated[str, experiment_ID_argument],
 
227
  output: Annotated[
228
  Optional[Path],
229
  typer.Option(
@@ -254,6 +257,7 @@ def results(
254
  output=output,
255
  force=force,
256
  unzip=unzip,
 
257
  )
258
 
259
 
 
35
  output: Path,
36
  force: bool,
37
  unzip: bool = False,
38
+ api_key: str | None = None,
39
  ) -> None:
40
  """Download a zip file from an experiment id.
41
 
 
72
  )
73
  raise typer.Exit(code=1)
74
 
75
+ headers = get_auth_headers(api_key)
76
  url = API_URL + endpoint
77
 
78
  response = requests.get(
 
105
  @app.command()
106
  def status(
107
  exp_id: Annotated[str, experiment_ID_argument],
108
+ api_key: Annotated[str, typer.Option("--api-key", "-k")],
109
  ):
110
  """Get an experiment status."""
111
+ headers = get_auth_headers(api_key)
112
  url = API_URL + "getExperimentStatus"
113
  response = requests.get(
114
  url,
 
226
  @app.command()
227
  def results(
228
  exp_id: Annotated[str, experiment_ID_argument],
229
+ api_key: Annotated[str, typer.Option("--api-key", "-k")],
230
  output: Annotated[
231
  Optional[Path],
232
  typer.Option(
 
257
  output=output,
258
  force=force,
259
  unzip=unzip,
260
+ api_key=api_key,
261
  )
262
 
263
 
folding-studio/folding_studio/utils/data_model.py CHANGED
@@ -207,7 +207,7 @@ class PredictRequestCustomFiles(BaseModel):
207
  f"Unsupported file type {batch_jobs_file.suffix}: {batch_jobs_file}"
208
  )
209
 
210
- def upload(self) -> None:
211
  """Upload local custom paths to GCP through an API request.
212
  Returns:
213
  A dict mapping local to uploaded files path.
@@ -218,7 +218,7 @@ class PredictRequestCustomFiles(BaseModel):
218
 
219
  local_to_uploaded = {}
220
 
221
- headers = get_auth_headers()
222
  if len(self.templates) > 0:
223
  _, templates_to_upload = partition_template_pdb_from_file(
224
  custom_templates=self.templates
 
207
  f"Unsupported file type {batch_jobs_file.suffix}: {batch_jobs_file}"
208
  )
209
 
210
+ def upload(self, api_key: str | None = None) -> None:
211
  """Upload local custom paths to GCP through an API request.
212
  Returns:
213
  A dict mapping local to uploaded files path.
 
218
 
219
  local_to_uploaded = {}
220
 
221
+ headers = get_auth_headers(api_key)
222
  if len(self.templates) > 0:
223
  _, templates_to_upload = partition_template_pdb_from_file(
224
  custom_templates=self.templates
folding-studio/folding_studio/utils/headers.py CHANGED
@@ -4,7 +4,7 @@ from folding_studio.config import FOLDING_API_KEY
4
  from folding_studio.utils.gcp import get_id_token
5
 
6
 
7
- def get_auth_headers() -> dict[str, str]:
8
  """
9
  Create authentication headers based on available credentials.
10
 
@@ -14,6 +14,9 @@ def get_auth_headers() -> dict[str, str]:
14
  Returns:
15
  dict: Authentication headers for API requests.
16
  """
 
 
 
17
  if FOLDING_API_KEY:
18
  return {"X-API-Key": FOLDING_API_KEY}
19
 
 
4
  from folding_studio.utils.gcp import get_id_token
5
 
6
 
7
+ def get_auth_headers(api_key: str | None = None) -> dict[str, str]:
8
  """
9
  Create authentication headers based on available credentials.
10
 
 
14
  Returns:
15
  dict: Authentication headers for API requests.
16
  """
17
+ if api_key is not None:
18
+ return {"X-API-Key": api_key}
19
+
20
  if FOLDING_API_KEY:
21
  return {"X-API-Key": FOLDING_API_KEY}
22
 
folding_studio_demo/app.py CHANGED
@@ -4,7 +4,6 @@ import logging
4
 
5
  import gradio as gr
6
  import pandas as pd
7
- import plotly.graph_objects as go
8
  from folding_studio_data_models import FoldingModel
9
  from gradio_molecule3d import Molecule3D
10
 
@@ -47,30 +46,12 @@ MODEL_CHOICES = [
47
  ("Protenix", FoldingModel.PROTENIX),
48
  ]
49
 
50
- DEFAULT_SEQ = "MALWMRLLPLLALLALWGPDPAAA"
51
- MODEL_EXAMPLES = {
52
- FoldingModel.AF2: [
53
- ["Monomer", f">A\n{DEFAULT_SEQ}"],
54
- ["Multimer", f">A\n{DEFAULT_SEQ}\n>B\n{DEFAULT_SEQ}"],
55
- ],
56
- FoldingModel.OPENFOLD: [
57
- ["Monomer", f">A\n{DEFAULT_SEQ}"],
58
- ["Multimer", f">A\n{DEFAULT_SEQ}\n>B\n{DEFAULT_SEQ}"],
59
- ],
60
- FoldingModel.SOLOSEQ: [["Monomer", f">A\n{DEFAULT_SEQ}"]],
61
- FoldingModel.BOLTZ: [
62
- ["Monomer", f">A|protein\n{DEFAULT_SEQ}"],
63
- ["Multimer", f">A|protein\n{DEFAULT_SEQ}\n>B|protein\n{DEFAULT_SEQ}"],
64
- ],
65
- FoldingModel.CHAI: [
66
- ["Monomer", f">protein|name=A\n{DEFAULT_SEQ}"],
67
- ["Multimer", f">protein|name=A\n{DEFAULT_SEQ}\n>protein|name=B\n{DEFAULT_SEQ}"],
68
- ],
69
- FoldingModel.PROTENIX: [
70
- ["Monomer", f">A|protein\n{DEFAULT_SEQ}"],
71
- ["Multimer", f">A|protein\n{DEFAULT_SEQ}\n>B|protein\n{DEFAULT_SEQ}"],
72
- ],
73
- }
74
 
75
 
76
  def sequence_input(dropdown: gr.Dropdown | None = None) -> gr.Textbox:
@@ -79,31 +60,43 @@ def sequence_input(dropdown: gr.Dropdown | None = None) -> gr.Textbox:
79
  Returns:
80
  gr.Textbox: Sequence input component
81
  """
82
- with gr.Row(equal_height=True):
83
- with gr.Column():
84
- sequence = gr.Textbox(
85
- label="Protein Sequence",
86
- lines=2,
87
- placeholder="Enter a protein sequence or upload a FASTA file",
88
- )
89
- dummy = gr.Textbox(label="Complex type", visible=False)
90
-
91
- examples = gr.Examples(
92
- examples=MODEL_EXAMPLES[FoldingModel.BOLTZ],
93
- inputs=[dummy, sequence],
94
- )
95
- file_input = gr.File(
96
- label="Upload a FASTA file",
97
- file_types=[".fasta", ".fa"],
98
- scale=0,
99
- )
 
 
100
 
101
- if dropdown is not None:
102
- dropdown.change(
103
- fn=lambda x: gr.Dataset(samples=MODEL_EXAMPLES[x]),
104
- inputs=[dropdown],
105
- outputs=[examples.dataset],
106
- )
 
 
 
 
 
 
 
 
 
 
107
 
108
  def _process_file(file: gr.File | None) -> gr.Textbox:
109
  if file is None:
@@ -158,7 +151,7 @@ def simple_prediction(api_key: str) -> None:
158
  metrics_plot = gr.Plot(label="pLDDT")
159
 
160
  predict_btn.click(
161
- fn=predict,
162
  inputs=[sequence, api_key, dropdown],
163
  outputs=[mol_output, metrics_plot],
164
  )
@@ -174,13 +167,12 @@ def model_comparison(api_key: str) -> None:
174
  """
175
  ## Compare Folding Models
176
 
177
- Select multiple models to compare their predictions on your protein sequence.
178
- You can either enter the sequence directly or upload a FASTA file.
179
 
180
- The selected models will run in parallel and generate:
181
- - 3D structures of your protein that you can visualize and compare
182
- - pLDDT confidence scores plotted for each residue
183
-
184
  """
185
  )
186
  with gr.Row():
@@ -188,7 +180,7 @@ def model_comparison(api_key: str) -> None:
188
  label="Model",
189
  choices=MODEL_CHOICES,
190
  scale=0,
191
- min_width=300,
192
  value=[FoldingModel.BOLTZ, FoldingModel.CHAI, FoldingModel.PROTENIX],
193
  )
194
  with gr.Column():
@@ -201,12 +193,28 @@ def model_comparison(api_key: str) -> None:
201
  variant="primary",
202
  )
203
  with gr.Row():
204
- af2_predictions = gr.CheckboxGroup(label="AlphaFold2", visible=False)
205
- openfold_predictions = gr.CheckboxGroup(label="OpenFold", visible=False)
206
- solo_predictions = gr.CheckboxGroup(label="SoloSeq", visible=False)
207
- chai_predictions = gr.CheckboxGroup(label="Chai", visible=False)
208
- protenix_predictions = gr.CheckboxGroup(label="Protenix", visible=False)
209
- boltz_predictions = gr.CheckboxGroup(label="Boltz", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  with gr.Row():
211
  mol_outputs = Molecule3D(
212
  label="Protein Structure", reps=MOLECULE_REPS, height=1000
@@ -306,7 +314,7 @@ def create_antibody_discovery_tab():
306
  "Antigen Sequence",
307
  ]
308
  # Display dataframe with floating point values rounded to 2 decimal places
309
- spr_data = gr.DataFrame(
310
  value=spr_data_with_scores[columns].round(2),
311
  label="Experimental Antibody-Antigen Binding Affinity Data",
312
  )
@@ -350,7 +358,6 @@ def create_antibody_discovery_tab():
350
  correlation_ranking_plot = gr.Plot(label="Correlation ranking")
351
  with gr.Row(visible=False) as regression_row:
352
  with gr.Column(scale=0):
353
-
354
  # User can select the columns to display in the correlation plot
355
  correlation_column = gr.Dropdown(
356
  label="Score data to display",
@@ -375,7 +382,7 @@ def create_antibody_discovery_tab():
375
  spr_data_with_scores, SCORE_COLUMNS, ["Antibody Name", "KD (nM)"]
376
  ),
377
  gr.Row(visible=True),
378
- gr.Row(visible=True)
379
  ),
380
  inputs=[correlation_type],
381
  outputs=[
@@ -391,7 +398,9 @@ def create_antibody_discovery_tab():
391
  logger.info(f"Updating correlation plot for {correlation_type}")
392
  corr_data = compute_correlation_data(spr_data_with_scores, SCORE_COLUMNS)
393
  logger.info(f"Correlation data: {corr_data}")
394
- corr_ranking_plot = plot_correlation_ranking(corr_data, correlation_type, kd_col="KD (nM)" if not use_log else "log_kd")
 
 
395
  regression_plot = make_regression_plot(spr_data_with_scores, score, use_log)
396
  return regression_plot, corr_ranking_plot
397
 
@@ -426,14 +435,21 @@ def __main__():
426
  Folding Studio is a platform for protein structure prediction.
427
  It uses the latest AI-powered folding models to predict the structure of a protein.
428
 
429
- Available models are : AlphaFold2, OpenFold, SoloSeq, Boltz-1, Chai and Protenix.
430
-
431
- ## API Key
432
- To use the Folding Studio API, you need to provide an API key.
433
- You can get your API key by asking to the Folding Studio team.
434
  """
435
  )
436
- api_key = gr.Textbox(label="Folding Studio API Key", type="password")
 
 
 
 
 
 
 
 
 
 
 
437
  gr.Markdown("## Demo Usage")
438
  with gr.Tab("🚀 Basic Folding"):
439
  simple_prediction(api_key)
 
4
 
5
  import gradio as gr
6
  import pandas as pd
 
7
  from folding_studio_data_models import FoldingModel
8
  from gradio_molecule3d import Molecule3D
9
 
 
46
  ("Protenix", FoldingModel.PROTENIX),
47
  ]
48
 
49
+ MONOMER_SEQ_EXAMPLE = ">A|protein\nMALWMRLLPLLALLALWGPDPAAA"
50
+ MULTIMER_SEQ_EXAMPLE = ">A|protein\nSQIPASEQETLVRPKPLLLKLLKSVGAQKDTYTMKEVLFYLGQYIMTKRLYDAAQQHIVYCSNDLLGDLFGVPSFSVKEHRKIYTMIYRNLVVVNQQESSDSGTSVSEN\n>B|protein\nSQETFSDLWKLLPEN"
51
+ EXAMPLES = [
52
+ ["Monomer", MONOMER_SEQ_EXAMPLE],
53
+ ["Multimer", MULTIMER_SEQ_EXAMPLE],
54
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
  def sequence_input(dropdown: gr.Dropdown | None = None) -> gr.Textbox:
 
60
  Returns:
61
  gr.Textbox: Sequence input component
62
  """
63
+ with gr.Column():
64
+ with gr.Row():
65
+ with gr.Row():
66
+ with gr.Column():
67
+ sequence = gr.Textbox(
68
+ label="Protein Sequence",
69
+ placeholder="Enter a protein sequence or upload a FASTA file",
70
+ value=MONOMER_SEQ_EXAMPLE,
71
+ lines=5,
72
+ )
73
+ gr.Markdown(
74
+ "Select an example below, enter a sequence manually or upload a FASTA file."
75
+ )
76
+
77
+ file_input = gr.File(
78
+ label="Upload a FASTA file",
79
+ file_types=[".fasta", ".fa"],
80
+ scale=0,
81
+ height=150,
82
+ )
83
 
84
+ with gr.Row(equal_height=True):
85
+ with gr.Column():
86
+ with gr.Row():
87
+ gr.Markdown("**Monomer Example:**")
88
+ gr.Markdown("**Multimer Example:**")
89
+ with gr.Row():
90
+ gr.Markdown("```\n" + MONOMER_SEQ_EXAMPLE + "\n```")
91
+ gr.Markdown("```\n" + MULTIMER_SEQ_EXAMPLE + "\n```")
92
+ with gr.Row():
93
+ gr.Button("Load Monomer Example", size="md").click(
94
+ fn=lambda: MONOMER_SEQ_EXAMPLE,
95
+ outputs=[sequence],
96
+ )
97
+ gr.Button("Load Multimer Example", size="md").click(
98
+ fn=lambda: MULTIMER_SEQ_EXAMPLE, outputs=[sequence]
99
+ )
100
 
101
  def _process_file(file: gr.File | None) -> gr.Textbox:
102
  if file is None:
 
151
  metrics_plot = gr.Plot(label="pLDDT")
152
 
153
  predict_btn.click(
154
+ fn=lambda x, y, z: predict(x, y, z, format_fasta=True),
155
  inputs=[sequence, api_key, dropdown],
156
  outputs=[mol_output, metrics_plot],
157
  )
 
167
  """
168
  ## Compare Folding Models
169
 
170
+ This tab allows you to compare predictions from multiple protein folding models side by side.
171
+ Follow these steps to get started:
172
 
173
+ 1. **Select Models**: Choose one or more models from the list on the left
174
+ 2. **Input Sequence** : Either select an example sequence, enter your protein sequence directly in the text box or upload a FASTA file.
175
+ 3. **Run Comparison**: Click "Compare Models" to start the prediction
 
176
  """
177
  )
178
  with gr.Row():
 
180
  label="Model",
181
  choices=MODEL_CHOICES,
182
  scale=0,
183
+ min_width=150,
184
  value=[FoldingModel.BOLTZ, FoldingModel.CHAI, FoldingModel.PROTENIX],
185
  )
186
  with gr.Column():
 
193
  variant="primary",
194
  )
195
  with gr.Row():
196
+ with gr.Column():
197
+ gr.Markdown(
198
+ """
199
+ ### Understanding the Outputs:
200
+ - **3D Structure**: The molecular viewer shows the predicted protein structure
201
+ - **pLDDT Score**: A confidence score (0-100) for each residue:
202
+ - Very high (>90): Highly accurate
203
+ - Confident (70-90): Good accuracy
204
+ - Low (50-70): Limited accuracy
205
+ - Very low (<50): Poor accuracy
206
+ """
207
+ )
208
+ gr.Markdown(
209
+ "### Model Predictions\nUse the checkboxes to toggle which model predictions to compare:"
210
+ )
211
+ with gr.Row():
212
+ af2_predictions = gr.CheckboxGroup(label="AlphaFold2", visible=False)
213
+ openfold_predictions = gr.CheckboxGroup(label="OpenFold", visible=False)
214
+ solo_predictions = gr.CheckboxGroup(label="SoloSeq", visible=False)
215
+ chai_predictions = gr.CheckboxGroup(label="Chai", visible=False)
216
+ protenix_predictions = gr.CheckboxGroup(label="Protenix", visible=False)
217
+ boltz_predictions = gr.CheckboxGroup(label="Boltz", visible=False)
218
  with gr.Row():
219
  mol_outputs = Molecule3D(
220
  label="Protein Structure", reps=MOLECULE_REPS, height=1000
 
314
  "Antigen Sequence",
315
  ]
316
  # Display dataframe with floating point values rounded to 2 decimal places
317
+ gr.DataFrame(
318
  value=spr_data_with_scores[columns].round(2),
319
  label="Experimental Antibody-Antigen Binding Affinity Data",
320
  )
 
358
  correlation_ranking_plot = gr.Plot(label="Correlation ranking")
359
  with gr.Row(visible=False) as regression_row:
360
  with gr.Column(scale=0):
 
361
  # User can select the columns to display in the correlation plot
362
  correlation_column = gr.Dropdown(
363
  label="Score data to display",
 
382
  spr_data_with_scores, SCORE_COLUMNS, ["Antibody Name", "KD (nM)"]
383
  ),
384
  gr.Row(visible=True),
385
+ gr.Row(visible=True),
386
  ),
387
  inputs=[correlation_type],
388
  outputs=[
 
398
  logger.info(f"Updating correlation plot for {correlation_type}")
399
  corr_data = compute_correlation_data(spr_data_with_scores, SCORE_COLUMNS)
400
  logger.info(f"Correlation data: {corr_data}")
401
+ corr_ranking_plot = plot_correlation_ranking(
402
+ corr_data, correlation_type, kd_col="KD (nM)" if not use_log else "log_kd"
403
+ )
404
  regression_plot = make_regression_plot(spr_data_with_scores, score, use_log)
405
  return regression_plot, corr_ranking_plot
406
 
 
435
  Folding Studio is a platform for protein structure prediction.
436
  It uses the latest AI-powered folding models to predict the structure of a protein.
437
 
438
+ Available models are : AlphaFold2, OpenFold, Boltz-1, Chai and Protenix.
 
 
 
 
439
  """
440
  )
441
+ with gr.Accordion("API Key", open=False):
442
+ gr.Markdown(
443
+ """
444
+ To use the Folding Studio API, you need to provide an API key.
445
+ You can get your API key by asking to the Folding Studio team.
446
+ """
447
+ )
448
+ api_key = gr.Textbox(
449
+ placeholder="Enter your Folding Studio API key",
450
+ type="password",
451
+ show_label=False,
452
+ )
453
  gr.Markdown("## Demo Usage")
454
  with gr.Tab("🚀 Basic Folding"):
455
  simple_prediction(api_key)
folding_studio_demo/models.py CHANGED
@@ -9,6 +9,7 @@ from io import StringIO
9
  from pathlib import Path
10
  from typing import Any
11
 
 
12
  import gradio as gr
13
  import numpy as np
14
  from folding_studio import single_job_prediction
@@ -202,7 +203,33 @@ class ProtenixModel(AF3Model):
202
 
203
  def predictions(self, output_dir: Path) -> list[Path]:
204
  """Get the path to the prediction."""
205
- return list(output_dir.rglob("*_model_[0-9].cif"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
 
208
  class BoltzModel(AF3Model):
@@ -259,12 +286,13 @@ class OldModel:
259
  output = single_job_prediction(
260
  fasta_file=seq_file,
261
  parameters=parameters,
 
262
  )
263
  experiment_id = output["message"]["experiment_id"]
264
  done = False
265
  while not done:
266
  with Capturing() as output:
267
- get_status(experiment_id)
268
  status = output[0]
269
  logger.info(f"Experiment {experiment_id} status: {status}")
270
  if status == "Done":
@@ -275,6 +303,7 @@ class OldModel:
275
  force=True,
276
  unzip=True,
277
  output=output_dir / "results.zip",
 
278
  )
279
  logger.info("Results downloaded to %s", output_dir)
280
  else:
 
9
  from pathlib import Path
10
  from typing import Any
11
 
12
+ import folding_studio
13
  import gradio as gr
14
  import numpy as np
15
  from folding_studio import single_job_prediction
 
203
 
204
  def predictions(self, output_dir: Path) -> list[Path]:
205
  """Get the path to the prediction."""
206
+ prediction = next(output_dir.rglob("sequence_*_sample_[0-9].cif"), None)
207
+ if prediction is None:
208
+ return {}
209
+
210
+ cif_files = {
211
+ int(f.stem[-1]): f
212
+ for f in prediction.parent.glob("sequence_*_sample_[0-9].cif")
213
+ }
214
+
215
+ # Get all summary-confidence JSON files and extract their sample indices
216
+ json_files = {
217
+ int(f.stem[-1]): f
218
+ for f in prediction.parent.glob(
219
+ "sequence_*_summary_confidence_sample_[0-9].json"
220
+ )
221
+ }
222
+
223
+ # Find common indices and create pairs
224
+ common_indices = sorted(set(cif_files.keys()) & set(json_files.keys()))
225
+
226
+ return {
227
+ idx: {
228
+ "prediction_path": cif_files[idx],
229
+ "metrics": json.load(open(json_files[idx])),
230
+ }
231
+ for idx in common_indices
232
+ }
233
 
234
 
235
  class BoltzModel(AF3Model):
 
286
  output = single_job_prediction(
287
  fasta_file=seq_file,
288
  parameters=parameters,
289
+ api_key=self.api_key,
290
  )
291
  experiment_id = output["message"]["experiment_id"]
292
  done = False
293
  while not done:
294
  with Capturing() as output:
295
+ get_status(experiment_id, api_key=self.api_key)
296
  status = output[0]
297
  logger.info(f"Experiment {experiment_id} status: {status}")
298
  if status == "Done":
 
303
  force=True,
304
  unzip=True,
305
  output=output_dir / "results.zip",
306
+ api_key=self.api_key,
307
  )
308
  logger.info("Results downloaded to %s", output_dir)
309
  else:
folding_studio_demo/predict.py CHANGED
@@ -91,34 +91,32 @@ def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
91
 
92
 
93
  def create_plddt_figure(
94
- plddt_vals: list[list[float]],
95
  model_name: str,
96
  indexes: list[int],
97
- residue_codes: list[list[str]] = None,
98
  ) -> go.Figure:
99
  """Create a plot of metrics."""
100
  plddt_traces = []
101
 
102
- for i, (plddt_val, index) in enumerate(zip(plddt_vals, indexes)):
103
- # Create hover text with residue codes if available
104
- if residue_codes and i < len(residue_codes):
105
- hover_text = [
106
- f"<i>{model_name} {index}</i><br><i>pLDDT</i>: {plddt:.2f}<br><i>Residue:</i> {code} {idx}"
107
- for idx, (plddt, code) in enumerate(zip(plddt_val, residue_codes[i]))
108
- ]
109
- else:
110
- hover_text = [
111
- f"<i>{model_name} {index}</i><br><i>pLDDT</i>: {plddt:.2f}<br><i>Residue index:</i> {idx}"
112
- for idx, plddt in enumerate(plddt_val)
113
  ]
114
 
115
  plddt_traces.append(
116
  go.Scatter(
117
- x=np.arange(len(plddt_val)),
118
- y=plddt_val,
119
  hovertemplate="%{text}<extra></extra>",
120
  text=hover_text,
121
- name=f"{model_name} {index}",
122
  visible=True,
123
  )
124
  )
@@ -160,7 +158,9 @@ def _write_fasta_file(
160
  return seq_id, seq_file
161
 
162
 
163
- def extract_plddt_from_structure(structure_path: str) -> tuple[list[float], list[str]]:
 
 
164
  """Extract pLDDT values and residue codes from a structure file.
165
 
166
  Args:
@@ -175,22 +175,24 @@ def extract_plddt_from_structure(structure_path: str) -> tuple[list[float], list
175
  structure = PDBParser().get_structure("structure", structure_path)
176
 
177
  # Lists to store pLDDT values and residue codes
178
- plddt_values = []
179
- residue_codes = []
180
 
181
  # Iterate through all atoms
182
  for model in structure:
183
  for chain in model:
 
184
  for residue in chain:
185
  # Get the first atom of each residue (usually CA atom)
186
  if "CA" in residue:
187
  # The B-factor contains the pLDDT value
188
  plddt = residue["CA"].get_bfactor()
189
- plddt_values.append(plddt)
190
  # Get residue code and convert to one-letter code
191
- residue_codes.append(convert_to_one_letter(residue.get_resname()))
 
 
192
 
193
- return plddt_values, residue_codes
194
 
195
 
196
  def predict(
@@ -253,7 +255,6 @@ def predict(
253
  predictions = model.predictions(output_dir)
254
  pdb_paths = []
255
  model_plddt_vals = []
256
- model_residue_codes = []
257
 
258
  total_predictions = len(predictions)
259
  for i, (model_idx, prediction) in enumerate(predictions.items()):
@@ -270,9 +271,8 @@ def predict(
270
  pdb_paths.append(converted_pdb_path)
271
  else:
272
  pdb_paths.append(str(prediction_path))
273
- plddt_vals, residue_codes = extract_plddt_from_structure(prediction_path)
274
  model_plddt_vals.append(plddt_vals)
275
- model_residue_codes.append(residue_codes)
276
 
277
  progress(0.8, desc="Generating plots...")
278
  indexes = []
@@ -290,7 +290,6 @@ def predict(
290
  plddt_vals=model_plddt_vals,
291
  model_name=model.model_name,
292
  indexes=indexes,
293
- residue_codes=model_residue_codes,
294
  )
295
 
296
  progress(1.0, desc="Done!")
@@ -434,9 +433,8 @@ def run_prediction(
434
  model_pdb_paths, model_plddt_traces = predict(
435
  sequence, api_key, model_type, format_fasta=format_fasta
436
  )
437
- model_pdb_paths = sorted(model_pdb_paths)
438
  model_predictions = {}
439
- for pdb_path, plddt_trace in zip(model_pdb_paths, model_plddt_traces.data):
440
  if model_type in [
441
  FoldingModel.AF2,
442
  FoldingModel.OPENFOLD,
@@ -446,7 +444,8 @@ def run_prediction(
446
  else:
447
  index = int(Path(pdb_path).stem[-1])
448
 
449
- model_predictions[index] = {"pdb_path": pdb_path, "plddt_trace": plddt_trace}
 
450
  return model_predictions
451
 
452
 
 
91
 
92
 
93
  def create_plddt_figure(
94
+ plddt_vals: list[dict[str, dict[str, list[float]]]],
95
  model_name: str,
96
  indexes: list[int],
 
97
  ) -> go.Figure:
98
  """Create a plot of metrics."""
99
  plddt_traces = []
100
 
101
+ for i, (pred_plddt, index) in enumerate(zip(plddt_vals, indexes)):
102
+ hover_text = []
103
+ plddt_values = []
104
+ for chain_id, plddt_val in pred_plddt.items():
105
+ plddt_values += plddt_val["values"]
106
+ hover_text += [
107
+ f"<i>{model_name} {index} - Chain {chain_id}</i><br><i>pLDDT</i>: {plddt:.2f}<br><i>Residue:</i> {code} {idx}"
108
+ for idx, (plddt, code) in enumerate(
109
+ zip(plddt_val["values"], plddt_val["residue_codes"])
110
+ )
 
111
  ]
112
 
113
  plddt_traces.append(
114
  go.Scatter(
115
+ x=np.arange(len(plddt_values)),
116
+ y=plddt_values,
117
  hovertemplate="%{text}<extra></extra>",
118
  text=hover_text,
119
+ name=f"{model_name} {index} - Chain {chain_id}",
120
  visible=True,
121
  )
122
  )
 
158
  return seq_id, seq_file
159
 
160
 
161
+ def extract_plddt_from_structure(
162
+ structure_path: str,
163
+ ) -> dict[str, dict[str, list[float]]]:
164
  """Extract pLDDT values and residue codes from a structure file.
165
 
166
  Args:
 
175
  structure = PDBParser().get_structure("structure", structure_path)
176
 
177
  # Lists to store pLDDT values and residue codes
178
+ plddt_values = {}
 
179
 
180
  # Iterate through all atoms
181
  for model in structure:
182
  for chain in model:
183
+ plddt_values[chain.id] = {"values": [], "residue_codes": []}
184
  for residue in chain:
185
  # Get the first atom of each residue (usually CA atom)
186
  if "CA" in residue:
187
  # The B-factor contains the pLDDT value
188
  plddt = residue["CA"].get_bfactor()
189
+ plddt_values[chain.id]["values"].append(plddt)
190
  # Get residue code and convert to one-letter code
191
+ plddt_values[chain.id]["residue_codes"].append(
192
+ convert_to_one_letter(residue.get_resname())
193
+ )
194
 
195
+ return plddt_values
196
 
197
 
198
  def predict(
 
255
  predictions = model.predictions(output_dir)
256
  pdb_paths = []
257
  model_plddt_vals = []
 
258
 
259
  total_predictions = len(predictions)
260
  for i, (model_idx, prediction) in enumerate(predictions.items()):
 
271
  pdb_paths.append(converted_pdb_path)
272
  else:
273
  pdb_paths.append(str(prediction_path))
274
+ plddt_vals = extract_plddt_from_structure(prediction_path)
275
  model_plddt_vals.append(plddt_vals)
 
276
 
277
  progress(0.8, desc="Generating plots...")
278
  indexes = []
 
290
  plddt_vals=model_plddt_vals,
291
  model_name=model.model_name,
292
  indexes=indexes,
 
293
  )
294
 
295
  progress(1.0, desc="Done!")
 
433
  model_pdb_paths, model_plddt_traces = predict(
434
  sequence, api_key, model_type, format_fasta=format_fasta
435
  )
 
436
  model_predictions = {}
437
+ for pdb_path, plddt_traces in zip(model_pdb_paths, model_plddt_traces.data):
438
  if model_type in [
439
  FoldingModel.AF2,
440
  FoldingModel.OPENFOLD,
 
444
  else:
445
  index = int(Path(pdb_path).stem[-1])
446
 
447
+ model_predictions[index] = {"pdb_path": pdb_path, "plddt_trace": plddt_traces}
448
+
449
  return model_predictions
450
 
451