jfaustin AchilleSoulieID commited on
Commit
f601557
·
verified ·
1 Parent(s): 142035a

Add model comparison (#3)

Browse files

- add model comparison (485bd69b80145fab090b2e9d368082247e7a4d93)


Co-authored-by: Achille Soulie <[email protected]>

folding_studio_demo/app.py CHANGED
@@ -3,12 +3,16 @@
3
  import logging
4
 
5
  import gradio as gr
 
6
  from folding_studio_data_models import FoldingModel
7
  from gradio_molecule3d import Molecule3D
8
- import pandas as pd
9
 
10
- from folding_studio_demo.predict import predict
11
- from folding_studio_demo.correlate import fake_predict_and_correlate, SCORE_COLUMNS, select_correlation_plot
 
 
 
 
12
 
13
  logger = logging.getLogger(__name__)
14
 
@@ -119,7 +123,7 @@ def model_comparison(api_key: str) -> None:
119
  """
120
 
121
  with gr.Row():
122
- model = gr.Dropdown(
123
  label="Model",
124
  choices=MODEL_CHOICES,
125
  multiselect=True,
@@ -133,13 +137,18 @@ def model_comparison(api_key: str) -> None:
133
  predict_btn = gr.Button("Compare Models")
134
 
135
  with gr.Row():
136
- mol_output = Molecule3D(label="Protein Structure", reps=MOLECULE_REPS)
137
- metrics_plot = gr.Plot(label="pLDDT")
 
 
 
 
 
138
 
139
  predict_btn.click(
140
- fn=predict,
141
- inputs=[sequence, api_key, model],
142
- outputs=[mol_output, metrics_plot],
143
  )
144
 
145
 
@@ -147,12 +156,12 @@ def create_correlation_tab():
147
  gr.Markdown("# Correlation with experimental binding affinity data")
148
  spr_data_with_scores = pd.read_csv("spr_af_scores_mapped.csv")
149
  prettified_columns = {
150
- "antibody_name": "Antibody Name",
151
- "KD (nM)": "KD (nM)",
152
- "antibody_vh_sequence": "Antibody VH Sequence",
153
- "antibody_vl_sequence": "Antibody VL Sequence",
154
- "antigen_sequence": "Antigen Sequence"
155
- }
156
  spr_data_with_scores = spr_data_with_scores.rename(columns=prettified_columns)
157
  with gr.Row():
158
  columns = [
@@ -160,10 +169,13 @@ def create_correlation_tab():
160
  "KD (nM)",
161
  "Antibody VH Sequence",
162
  "Antibody VL Sequence",
163
- "Antigen Sequence"
164
  ]
165
  # Display dataframe with floating point values rounded to 2 decimal places
166
- spr_data = gr.DataFrame(value=spr_data_with_scores[columns].round(2), label="Experimental Antibody-Antigen Binding Affinity Data")
 
 
 
167
 
168
  gr.Markdown("# Prediction and correlation")
169
  with gr.Row():
@@ -174,22 +186,27 @@ def create_correlation_tab():
174
  correlation_ranking_plot = gr.Plot(label="Correlation ranking")
175
  with gr.Row():
176
  # User can select the columns to display in the correlation plot
177
- correlation_column = gr.Dropdown(label="Score data to display", choices=SCORE_COLUMNS, multiselect=False)
 
 
178
  correlation_plot = gr.Plot(label="Correlation with binding affinity")
179
 
180
  fake_predict_btn.click(
181
- fn=lambda x: fake_predict_and_correlate(spr_data_with_scores, SCORE_COLUMNS, ["Antibody Name", "KD (nM)"]),
 
 
182
  inputs=None,
183
- outputs=[prediction_dataframe, correlation_ranking_plot]
184
  )
185
 
186
  # Call function to update the correlation plot when the user selects the columns
187
  correlation_column.change(
188
  fn=lambda score: select_correlation_plot(spr_data_with_scores, score),
189
  inputs=correlation_column,
190
- outputs=correlation_plot
191
  )
192
-
 
193
  def __main__():
194
  with gr.Blocks(title="Folding Studio Demo") as demo:
195
  gr.Markdown(
 
3
  import logging
4
 
5
  import gradio as gr
6
+ import pandas as pd
7
  from folding_studio_data_models import FoldingModel
8
  from gradio_molecule3d import Molecule3D
 
9
 
10
+ from folding_studio_demo.correlate import (
11
+ SCORE_COLUMNS,
12
+ fake_predict_and_correlate,
13
+ select_correlation_plot,
14
+ )
15
+ from folding_studio_demo.predict import predict, predict_comparison
16
 
17
  logger = logging.getLogger(__name__)
18
 
 
123
  """
124
 
125
  with gr.Row():
126
+ models = gr.Dropdown(
127
  label="Model",
128
  choices=MODEL_CHOICES,
129
  multiselect=True,
 
137
  predict_btn = gr.Button("Compare Models")
138
 
139
  with gr.Row():
140
+ mol_outputs = Molecule3D(
141
+ label="Protein Structure",
142
+ reps=MOLECULE_REPS,
143
+ file_count="multiple",
144
+ )
145
+
146
+ # metrics_plot = gr.Plot(label="pLDDT")
147
 
148
  predict_btn.click(
149
+ fn=predict_comparison,
150
+ inputs=[sequence, api_key, models],
151
+ outputs=[mol_outputs],
152
  )
153
 
154
 
 
156
  gr.Markdown("# Correlation with experimental binding affinity data")
157
  spr_data_with_scores = pd.read_csv("spr_af_scores_mapped.csv")
158
  prettified_columns = {
159
+ "antibody_name": "Antibody Name",
160
+ "KD (nM)": "KD (nM)",
161
+ "antibody_vh_sequence": "Antibody VH Sequence",
162
+ "antibody_vl_sequence": "Antibody VL Sequence",
163
+ "antigen_sequence": "Antigen Sequence",
164
+ }
165
  spr_data_with_scores = spr_data_with_scores.rename(columns=prettified_columns)
166
  with gr.Row():
167
  columns = [
 
169
  "KD (nM)",
170
  "Antibody VH Sequence",
171
  "Antibody VL Sequence",
172
+ "Antigen Sequence",
173
  ]
174
  # Display dataframe with floating point values rounded to 2 decimal places
175
+ spr_data = gr.DataFrame(
176
+ value=spr_data_with_scores[columns].round(2),
177
+ label="Experimental Antibody-Antigen Binding Affinity Data",
178
+ )
179
 
180
  gr.Markdown("# Prediction and correlation")
181
  with gr.Row():
 
186
  correlation_ranking_plot = gr.Plot(label="Correlation ranking")
187
  with gr.Row():
188
  # User can select the columns to display in the correlation plot
189
+ correlation_column = gr.Dropdown(
190
+ label="Score data to display", choices=SCORE_COLUMNS, multiselect=False
191
+ )
192
  correlation_plot = gr.Plot(label="Correlation with binding affinity")
193
 
194
  fake_predict_btn.click(
195
+ fn=lambda x: fake_predict_and_correlate(
196
+ spr_data_with_scores, SCORE_COLUMNS, ["Antibody Name", "KD (nM)"]
197
+ ),
198
  inputs=None,
199
+ outputs=[prediction_dataframe, correlation_ranking_plot],
200
  )
201
 
202
  # Call function to update the correlation plot when the user selects the columns
203
  correlation_column.change(
204
  fn=lambda score: select_correlation_plot(spr_data_with_scores, score),
205
  inputs=correlation_column,
206
+ outputs=correlation_plot,
207
  )
208
+
209
+
210
  def __main__():
211
  with gr.Blocks(title="Folding Studio Demo") as demo:
212
  gr.Markdown(
folding_studio_demo/predict.py CHANGED
@@ -3,13 +3,15 @@
3
  import hashlib
4
  import logging
5
  import os
 
6
  from pathlib import Path
 
7
 
8
  import gradio as gr
9
  import numpy as np
10
  import plotly.graph_objects as go
11
  from Bio import SeqIO
12
- from Bio.PDB import PDBIO, MMCIFParser
13
  from folding_studio.client import Client
14
  from folding_studio.query import Query
15
  from folding_studio.query.boltz import BoltzQuery
@@ -50,18 +52,21 @@ def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
50
  io.save(pdb_path)
51
 
52
 
53
- def add_plddt_plot(plddt_vals: list[float]) -> str:
54
  """Create a plot of metrics."""
55
  visible = True
56
- plddt_trace = go.Scatter(
57
- x=np.arange(len(plddt_vals)),
58
- y=plddt_vals,
59
- hovertemplate="<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}<br>",
60
- name="seq",
61
- visible=visible,
62
- )
 
 
 
63
 
64
- plddt_fig = go.Figure(data=[plddt_trace])
65
  plddt_fig.update_layout(
66
  title="pLDDT",
67
  xaxis_title="Residue index",
@@ -85,7 +90,13 @@ def _write_fasta_file(
85
  Returns:
86
  tuple[str, Path]: Tuple containing the sequence ID and the path to the FASTA file
87
  """
88
- seq_id = hashlib.sha1(sequence.encode()).hexdigest()
 
 
 
 
 
 
89
  seq_file = directory / f"sequence_{seq_id}.fasta"
90
  with open(seq_file, "w") as f:
91
  f.write(sequence)
@@ -146,7 +157,7 @@ class AF3Model:
146
 
147
  def has_prediction(self, output_dir: Path) -> bool:
148
  """Check if prediction exists in output directory."""
149
- return any(self.predictions(output_dir))
150
 
151
  def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
152
  """Check if the file description is correct.
@@ -157,10 +168,6 @@ class AF3Model:
157
  Returns:
158
  tuple[bool, str | None]: Tuple containing a boolean indicating if the format is correct and an error message if not
159
  """
160
- input_rep = list(SeqIO.parse(seq_file, "fasta"))
161
- if not input_rep:
162
- error_msg = f"{self.model_name.upper()} Validation Error: No sequence found"
163
- return False, error_msg
164
 
165
  is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
166
  if not is_valid:
@@ -182,9 +189,41 @@ class ChaiModel(AF3Model):
182
  """
183
  super().call(seq_file, output_dir)
184
 
185
- def predictions(self, output_dir: Path) -> list[Path]:
 
 
 
 
 
 
 
 
 
 
 
186
  """Get the path to the prediction."""
187
- return list(output_dir.rglob("*_model_[0-9].cif"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
 
190
  class ProtenixModel(AF3Model):
@@ -221,7 +260,33 @@ class BoltzModel(AF3Model):
221
 
222
  def predictions(self, output_dir: Path) -> list[Path]:
223
  """Get the path to the prediction."""
224
- return list(output_dir.rglob("*_model_[0-9].cif"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
 
227
  def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str, str]:
@@ -235,6 +300,8 @@ def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str,
235
  Returns:
236
  tuple[str, str]: Tuple containing the path to the PDB file and the pLDDT plot
237
  """
 
 
238
 
239
  # Set up unique output directory based on sequence hash
240
  seq_id, seq_file = _write_fasta_file(sequence)
@@ -265,15 +332,93 @@ def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str,
265
  if not model.has_prediction(output_dir):
266
  raise gr.Error("No prediction found")
267
 
268
- pred_cif = model.predictions(output_dir)[0]
269
- logger.info("Output file: %s", pred_cif)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
- converted_pdb_path = str(output_dir / f"pred_{seq_id}.pdb")
272
- convert_cif_to_pdb(str(pred_cif), str(converted_pdb_path))
273
- logger.info("Converted PDB file: %s", converted_pdb_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
- plddt_file = list(pred_cif.parent.glob("plddt_*.npz"))[0]
276
- logger.info("plddt file: %s", plddt_file)
277
- plddt_vals = np.load(plddt_file)["plddt"]
278
 
279
- return converted_pdb_path, add_plddt_plot(plddt_vals=plddt_vals)
 
3
  import hashlib
4
  import logging
5
  import os
6
+ from io import StringIO
7
  from pathlib import Path
8
+ from typing import Any
9
 
10
  import gradio as gr
11
  import numpy as np
12
  import plotly.graph_objects as go
13
  from Bio import SeqIO
14
+ from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
15
  from folding_studio.client import Client
16
  from folding_studio.query import Query
17
  from folding_studio.query.boltz import BoltzQuery
 
52
  io.save(pdb_path)
53
 
54
 
55
+ def add_plddt_plot(plddt_vals: list[list[float]], model_name: str) -> go.Figure:
56
  """Create a plot of metrics."""
57
  visible = True
58
+ plddt_traces = [
59
+ go.Scatter(
60
+ x=np.arange(len(plddt_val)),
61
+ y=plddt_val,
62
+ hovertemplate="<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}<br>",
63
+ name=f"{model_name} {i}",
64
+ visible=visible,
65
+ )
66
+ for i, plddt_val in enumerate(plddt_vals)
67
+ ]
68
 
69
+ plddt_fig = go.Figure(data=plddt_traces)
70
  plddt_fig.update_layout(
71
  title="pLDDT",
72
  xaxis_title="Residue index",
 
90
  Returns:
91
  tuple[str, Path]: Tuple containing the sequence ID and the path to the FASTA file
92
  """
93
+ input_rep = list(SeqIO.parse(StringIO(sequence), "fasta"))
94
+ if not input_rep:
95
+ raise gr.Error("No sequence found")
96
+
97
+ seq_id = hashlib.sha256(
98
+ "_".join([str(records.seq) for records in input_rep]).encode()
99
+ ).hexdigest()
100
  seq_file = directory / f"sequence_{seq_id}.fasta"
101
  with open(seq_file, "w") as f:
102
  f.write(sequence)
 
157
 
158
  def has_prediction(self, output_dir: Path) -> bool:
159
  """Check if prediction exists in output directory."""
160
+ return len(self.predictions(output_dir)) > 0
161
 
162
  def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
163
  """Check if the file description is correct.
 
168
  Returns:
169
  tuple[bool, str | None]: Tuple containing a boolean indicating if the format is correct and an error message if not
170
  """
 
 
 
 
171
 
172
  is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
173
  if not is_valid:
 
189
  """
190
  super().call(seq_file, output_dir)
191
 
192
+ def _get_chai_paired_files(self, directory: Path) -> list[tuple[Path, Path]]:
193
+ """Get pairs of .cif and .npz files with matching model indices.
194
+
195
+ Args:
196
+ directory (Path): Directory containing the prediction files
197
+
198
+ Returns:
199
+ list[tuple[Path, Path]]: List of tuples containing (cif_path, npz_path) pairs
200
+ """
201
+ # Get all cif files and extract their indices
202
+
203
+ def predictions(self, output_dir: Path) -> dict[Path, dict[str, Any]]:
204
  """Get the path to the prediction."""
205
+ prediction = next(output_dir.rglob("pred.model_idx_[0-9].cif"), None)
206
+ if prediction is None:
207
+ return {}
208
+
209
+ cif_files = {
210
+ int(f.stem.split("model_idx_")[1]): f
211
+ for f in prediction.parent.glob("pred.model_idx_*.cif")
212
+ }
213
+
214
+ # Get all npz files and extract their indices
215
+ npz_files = {
216
+ int(f.stem.split("model_idx_")[1]): f
217
+ for f in prediction.parent.glob("scores.model_idx_*.npz")
218
+ }
219
+
220
+ # Find common indices and create pairs
221
+ common_indices = sorted(set(cif_files.keys()) & set(npz_files.keys()))
222
+
223
+ return {
224
+ idx: {"prediction_path": cif_files[idx], "metrics": np.load(npz_files[idx])}
225
+ for idx in common_indices
226
+ }
227
 
228
 
229
  class ProtenixModel(AF3Model):
 
260
 
261
  def predictions(self, output_dir: Path) -> list[Path]:
262
  """Get the path to the prediction."""
263
+ prediction_paths = list(output_dir.rglob("*_model_[0-9].cif"))
264
+ return {
265
+ int(cif_path.stem[-1]): {
266
+ "prediction_path": cif_path,
267
+ "metrics": np.load(list(cif_path.parent.glob("plddt_*.npz"))[0]),
268
+ }
269
+ for cif_path in prediction_paths
270
+ }
271
+
272
+
273
+ def extract_plddt_from_cif(cif_path):
274
+ structure = MMCIFParser().get_structure("structure", cif_path)
275
+
276
+ # Dictionary to store pLDDT values per residue
277
+ plddt_values = []
278
+
279
+ # Iterate through all atoms
280
+ for model in structure:
281
+ for chain in model:
282
+ for residue in chain:
283
+ # Get the first atom of each residue (usually CA atom)
284
+ if "CA" in residue:
285
+ # The B-factor contains the pLDDT value
286
+ plddt = residue["CA"].get_bfactor()
287
+ plddt_values.append(plddt)
288
+
289
+ return plddt_values
290
 
291
 
292
  def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str, str]:
 
300
  Returns:
301
  tuple[str, str]: Tuple containing the path to the PDB file and the pLDDT plot
302
  """
303
+ if not api_key:
304
+ raise gr.Error("Missing API key, please enter a valid API key")
305
 
306
  # Set up unique output directory based on sequence hash
307
  seq_id, seq_file = _write_fasta_file(sequence)
 
332
  if not model.has_prediction(output_dir):
333
  raise gr.Error("No prediction found")
334
 
335
+ predictions = model.predictions(output_dir)
336
+ pdb_paths = []
337
+ model_plddt_vals = []
338
+ for model_idx, prediction in predictions.items():
339
+ cif_path = prediction["prediction_path"]
340
+ logger.info(
341
+ "CIF file: %s",
342
+ )
343
+
344
+ converted_pdb_path = str(
345
+ output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
346
+ )
347
+ convert_cif_to_pdb(str(cif_path), str(converted_pdb_path))
348
+ plddt_vals = extract_plddt_from_cif(cif_path)
349
+ pdb_paths.append(converted_pdb_path)
350
+ model_plddt_vals.append(plddt_vals)
351
+ plddt_plot = add_plddt_plot(
352
+ plddt_vals=model_plddt_vals, model_name=model.model_name
353
+ )
354
+ return pdb_paths, plddt_plot
355
+
356
+
357
+ def align_structures(pdb_paths: list[str]) -> list[str]:
358
+ """Align multiple PDB structures to the first structure.
359
 
360
+ Args:
361
+ pdb_paths (list[str]): List of paths to PDB files to align
362
+
363
+ Returns:
364
+ list[str]: List of paths to aligned PDB files
365
+ """
366
+
367
+ parser = PDBParser()
368
+ io = PDBIO()
369
+
370
+ # Parse the reference structure (first one)
371
+ ref_structure = parser.get_structure("reference", pdb_paths[0])
372
+ ref_atoms = [atom for atom in ref_structure.get_atoms() if atom.get_name() == "CA"]
373
+
374
+ aligned_paths = [pdb_paths[0]] # First structure is already aligned
375
+
376
+ # Align each subsequent structure to the reference
377
+ for i, pdb_path in enumerate(pdb_paths[1:], start=1):
378
+ # Parse the structure to align
379
+ structure = parser.get_structure(f"model_{i}", pdb_path)
380
+ atoms = [atom for atom in structure.get_atoms() if atom.get_name() == "CA"]
381
+
382
+ # Create superimposer
383
+ sup = Superimposer()
384
+
385
+ # Set the reference and moving atoms
386
+ sup.set_atoms(ref_atoms, atoms)
387
+
388
+ # Apply the transformation to all atoms in the structure
389
+ sup.apply(structure.get_atoms())
390
+
391
+ # Save the aligned structure
392
+ aligned_path = str(Path(pdb_path).parent / f"aligned_{Path(pdb_path).name}")
393
+ io.set_structure(structure)
394
+ io.save(aligned_path)
395
+ aligned_paths.append(aligned_path)
396
+
397
+ return aligned_paths
398
+
399
+
400
+ def predict_comparison(
401
+ sequence: str, api_key: str, model_types: list[FoldingModel]
402
+ ) -> tuple[str, str]:
403
+ """Predict protein structure from amino acid sequence using Boltz model.
404
+
405
+ Args:
406
+ sequence (str): Amino acid sequence to predict structure for
407
+ api_key (str): Folding API key
408
+ model (FoldingModel): Folding model to use
409
+
410
+ Returns:
411
+ tuple[str, str]: Tuple containing the path to the PDB file and the pLDDT plot
412
+ """
413
+ if not api_key:
414
+ raise gr.Error("Missing API key, please enter a valid API key")
415
+
416
+ # Set up unique output directory based on sequence hash
417
+ pdb_paths = []
418
+ for model_type in model_types:
419
+ model_pdb_paths, _ = predict(sequence, api_key, model_type)
420
+ pdb_paths += model_pdb_paths
421
 
422
+ aligned_paths = align_structures(pdb_paths)
 
 
423
 
424
+ return aligned_paths