jfaustin committed on
Commit d861d5c · 2 Parent(s): 24f13c2 6354ea8

Merge remote-tracking branch 'origin/main' into pr/12

folding_studio_demo/app.py CHANGED
@@ -39,8 +39,8 @@ MOLECULE_REPS = [


  MODEL_CHOICES = [
- # ("AlphaFold2", FoldingModel.AF2),
- # ("OpenFold", FoldingModel.OPENFOLD),
  # ("SoloSeq", FoldingModel.SOLOSEQ),
  ("Boltz-1", FoldingModel.BOLTZ),
  ("Chai-1", FoldingModel.CHAI),
@@ -49,6 +49,15 @@ MODEL_CHOICES = [
49
 
50
  DEFAULT_SEQ = "MALWMRLLPLLALLALWGPDPAAA"
51
  MODEL_EXAMPLES = {
52
  FoldingModel.BOLTZ: [
53
  ["Monomer", f">A|protein\n{DEFAULT_SEQ}"],
54
  ["Multimer", f">A|protein\n{DEFAULT_SEQ}\n>B|protein\n{DEFAULT_SEQ}"],
@@ -70,27 +79,31 @@ def sequence_input(dropdown: gr.Dropdown | None = None) -> gr.Textbox:
70
  Returns:
71
  gr.Textbox: Sequence input component
72
  """
73
- sequence = gr.Textbox(
74
- label="Protein Sequence",
75
- lines=2,
76
- placeholder="Enter a protein sequence or upload a FASTA file",
77
- )
78
- dummy = gr.Textbox(label="Complex type", visible=False)
79
 
80
- examples = gr.Examples(
81
- examples=MODEL_EXAMPLES[FoldingModel.BOLTZ],
82
- inputs=[dummy, sequence],
83
- )
84
  if dropdown is not None:
85
  dropdown.change(
86
  fn=lambda x: gr.Dataset(samples=MODEL_EXAMPLES[x]),
87
  inputs=[dropdown],
88
  outputs=[examples.dataset],
89
  )
90
- file_input = gr.File(
91
- label="Upload a FASTA file",
92
- file_types=[".fasta", ".fa"],
93
- )
94
 
95
  def _process_file(file: gr.File | None) -> gr.Textbox:
96
  if file is None:
@@ -115,7 +128,7 @@ def simple_prediction(api_key: str) -> None:
115
  """
116
  gr.Markdown(
117
  """
118
- ### Predict a Protein Structure
119
 
120
  It will be run in the background and the results will be displayed in the output section.
121
  The output will contain the protein structure and the pLDDT plot.
@@ -157,7 +170,19 @@ def model_comparison(api_key: str) -> None:
157
  Args:
158
  api_key (str): Folding Studio API key
159
  """
161
  with gr.Row():
162
  models = gr.CheckboxGroup(
163
  label="Model",
@@ -176,6 +201,9 @@ def model_comparison(api_key: str) -> None:
176
  variant="primary",
177
  )
178
  with gr.Row():
179
  chai_predictions = gr.CheckboxGroup(label="Chai", visible=False)
180
  protenix_predictions = gr.CheckboxGroup(label="Protenix", visible=False)
181
  boltz_predictions = gr.CheckboxGroup(label="Boltz", visible=False)
@@ -186,28 +214,50 @@ def model_comparison(api_key: str) -> None:
186
  metrics_plot = gr.Plot(label="pLDDT")
187
 
188
  # Store the initial predictions
189
- aligned_paths = gr.State()
190
- plddt_fig = gr.State()
191
 
192
  predict_btn.click(
193
  fn=predict_comparison,
194
  inputs=[sequence, api_key, models],
195
  outputs=[
196
  chai_predictions,
197
  boltz_predictions,
198
  protenix_predictions,
199
- aligned_paths,
200
- plddt_fig,
201
  ],
 
202
  )
203
 
204
  # Handle checkbox changes
205
- for checkbox in [chai_predictions, boltz_predictions, protenix_predictions]:
206
  checkbox.change(
207
  fn=filter_predictions,
208
  inputs=[
209
- aligned_paths,
210
- plddt_fig,
 
 
211
  chai_predictions,
212
  boltz_predictions,
213
  protenix_predictions,
@@ -242,63 +292,64 @@ def create_correlation_tab():
242
  "antigen_sequence": "Antigen Sequence",
243
  }
244
  spr_data_with_scores = spr_data_with_scores.rename(columns=prettified_columns)
245
- with gr.Row():
246
- columns = [
247
- "Antibody Name",
248
- "KD (nM)",
249
- "Antibody VH Sequence",
250
- "Antibody VL Sequence",
251
- "Antigen Sequence",
252
- ]
253
- # Display dataframe with floating point values rounded to 2 decimal places
254
- spr_data = gr.DataFrame(
255
- value=spr_data_with_scores[columns].round(2),
256
- label="Experimental Antibody-Antigen Binding Affinity Data",
257
- )
258
 
259
  gr.Markdown("# Prediction and correlation")
260
  with gr.Row():
261
- fake_predict_btn = gr.Button(
262
- "Predict structures of all complexes",
263
- elem_classes="gradient-button",
264
- variant="primary",
 
 
265
  )
 
266
  with gr.Row():
267
- prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
268
- with gr.Row():
269
- with gr.Row():
270
- correlation_type = gr.Radio(
271
- choices=["Spearman", "Pearson", "R²"],
272
- value="Spearman",
273
- label="Correlation Type",
274
- interactive=True,
275
  )
276
- with gr.Row():
277
- correlation_ranking_plot = gr.Plot(label="Correlation ranking")
278
- with gr.Row():
279
- with gr.Column():
280
- with gr.Row():
281
- # User can select the columns to display in the correlation plot
282
- correlation_column = gr.Dropdown(
283
- label="Score data to display",
284
- choices=SCORE_COLUMNS,
285
- multiselect=False,
286
- value=SCORE_COLUMNS[0],
287
- )
288
- # Add checkbox for log scale and update plot when either input changes
289
- with gr.Row():
290
- log_scale = gr.Checkbox(
291
- label="Display x-axis on logarithmic scale", value=False
292
- )
293
- with gr.Row():
294
- score_description = gr.Markdown(
295
- get_score_description(correlation_column.value)
296
- )
297
- correlation_column.change(
298
- fn=lambda x: get_score_description(x),
299
- inputs=correlation_column,
300
- outputs=score_description,
301
- )
302
  with gr.Column():
303
  regression_plot = gr.Plot(label="Correlation with binding affinity")
304
 
@@ -333,7 +384,7 @@ def create_correlation_tab():
333
 
334
  log_scale.change(
335
  fn=update_regression_plot,
336
- inputs=[correlation_column, log_scale],
337
  outputs=regression_plot,
338
  )
339
 
@@ -360,7 +411,7 @@ def __main__():
360
  )
361
  api_key = gr.Textbox(label="Folding Studio API Key", type="password")
362
  gr.Markdown("## Demo Usage")
363
- with gr.Tab("🚀 Simple Prediction"):
364
  simple_prediction(api_key)
365
  with gr.Tab("📊 Model Comparison"):
366
  model_comparison(api_key)
 


  MODEL_CHOICES = [
+ ("AlphaFold2", FoldingModel.AF2),
+ ("OpenFold", FoldingModel.OPENFOLD),
  # ("SoloSeq", FoldingModel.SOLOSEQ),
  ("Boltz-1", FoldingModel.BOLTZ),
  ("Chai-1", FoldingModel.CHAI),


  DEFAULT_SEQ = "MALWMRLLPLLALLALWGPDPAAA"
  MODEL_EXAMPLES = {
+ FoldingModel.AF2: [
+ ["Monomer", f">A\n{DEFAULT_SEQ}"],
+ ["Multimer", f">A\n{DEFAULT_SEQ}\n>B\n{DEFAULT_SEQ}"],
+ ],
+ FoldingModel.OPENFOLD: [
+ ["Monomer", f">A\n{DEFAULT_SEQ}"],
+ ["Multimer", f">A\n{DEFAULT_SEQ}\n>B\n{DEFAULT_SEQ}"],
+ ],
+ FoldingModel.SOLOSEQ: [["Monomer", f">A\n{DEFAULT_SEQ}"]],
  FoldingModel.BOLTZ: [
  ["Monomer", f">A|protein\n{DEFAULT_SEQ}"],
  ["Multimer", f">A|protein\n{DEFAULT_SEQ}\n>B|protein\n{DEFAULT_SEQ}"],
 
79
  Returns:
80
  gr.Textbox: Sequence input component
81
  """
82
+ with gr.Row(equal_height=True):
83
+ with gr.Column():
84
+ sequence = gr.Textbox(
85
+ label="Protein Sequence",
86
+ lines=2,
87
+ placeholder="Enter a protein sequence or upload a FASTA file",
88
+ )
89
+ dummy = gr.Textbox(label="Complex type", visible=False)
90
+
91
+ examples = gr.Examples(
92
+ examples=MODEL_EXAMPLES[FoldingModel.BOLTZ],
93
+ inputs=[dummy, sequence],
94
+ )
95
+ file_input = gr.File(
96
+ label="Upload a FASTA file",
97
+ file_types=[".fasta", ".fa"],
98
+ scale=0,
99
+ )
100
 
101
  if dropdown is not None:
102
  dropdown.change(
103
  fn=lambda x: gr.Dataset(samples=MODEL_EXAMPLES[x]),
104
  inputs=[dropdown],
105
  outputs=[examples.dataset],
106
  )
107
 
108
  def _process_file(file: gr.File | None) -> gr.Textbox:
109
  if file is None:
 
128
  """
129
  gr.Markdown(
130
  """
131
+ ## Predict a Protein Structure
132
 
133
  It will be run in the background and the results will be displayed in the output section.
134
  The output will contain the protein structure and the pLDDT plot.
 
170
  Args:
171
  api_key (str): Folding Studio API key
172
  """
173
+ gr.Markdown(
174
+ """
175
+ ## Compare Folding Models
176
 
177
+ Select multiple models to compare their predictions on your protein sequence.
178
+ You can either enter the sequence directly or upload a FASTA file.
179
+
180
+ The selected models will run in parallel and generate:
181
+ - 3D structures of your protein that you can visualize and compare
182
+ - pLDDT confidence scores plotted for each residue
183
+
184
+ """
185
+ )
186
  with gr.Row():
187
  models = gr.CheckboxGroup(
188
  label="Model",
 
201
  variant="primary",
202
  )
203
  with gr.Row():
204
+ af2_predictions = gr.CheckboxGroup(label="AlphaFold2", visible=False)
205
+ openfold_predictions = gr.CheckboxGroup(label="OpenFold", visible=False)
206
+ solo_predictions = gr.CheckboxGroup(label="SoloSeq", visible=False)
207
  chai_predictions = gr.CheckboxGroup(label="Chai", visible=False)
208
  protenix_predictions = gr.CheckboxGroup(label="Protenix", visible=False)
209
  boltz_predictions = gr.CheckboxGroup(label="Boltz", visible=False)
 
214
  metrics_plot = gr.Plot(label="pLDDT")
215
 
216
  # Store the initial predictions
217
+ prediction_outputs = gr.State()
 
218
 
219
  predict_btn.click(
220
  fn=predict_comparison,
221
  inputs=[sequence, api_key, models],
222
  outputs=[
223
+ prediction_outputs,
224
+ af2_predictions,
225
+ openfold_predictions,
226
+ solo_predictions,
227
+ chai_predictions,
228
+ boltz_predictions,
229
+ protenix_predictions,
230
+ ],
231
+ ).then(
232
+ fn=filter_predictions,
233
+ inputs=[
234
+ prediction_outputs,
235
+ af2_predictions,
236
+ openfold_predictions,
237
+ solo_predictions,
238
  chai_predictions,
239
  boltz_predictions,
240
  protenix_predictions,
 
 
241
  ],
242
+ outputs=[mol_outputs, metrics_plot],
243
  )
244
 
245
  # Handle checkbox changes
246
+ for checkbox in [
247
+ af2_predictions,
248
+ openfold_predictions,
249
+ solo_predictions,
250
+ chai_predictions,
251
+ boltz_predictions,
252
+ protenix_predictions,
253
+ ]:
254
  checkbox.change(
255
  fn=filter_predictions,
256
  inputs=[
257
+ prediction_outputs,
258
+ af2_predictions,
259
+ openfold_predictions,
260
+ solo_predictions,
261
  chai_predictions,
262
  boltz_predictions,
263
  protenix_predictions,
 
292
  "antigen_sequence": "Antigen Sequence",
293
  }
294
  spr_data_with_scores = spr_data_with_scores.rename(columns=prettified_columns)
295
+ columns = [
296
+ "Antibody Name",
297
+ "KD (nM)",
298
+ "Antibody VH Sequence",
299
+ "Antibody VL Sequence",
300
+ "Antigen Sequence",
301
+ ]
302
+ # Display dataframe with floating point values rounded to 2 decimal places
303
+ spr_data = gr.DataFrame(
304
+ value=spr_data_with_scores[columns].round(2),
305
+ label="Experimental Antibody-Antigen Binding Affinity Data",
306
+ )
 
307
 
308
  gr.Markdown("# Prediction and correlation")
309
+
310
+ fake_predict_btn = gr.Button(
311
+ "Predict structures of all complexes",
312
+ elem_classes="gradient-button",
313
+ variant="primary",
314
+ )
315
+ prediction_dataframe = gr.Dataframe(
316
+ label="Predicted Structures Data", visible=False
317
+ )
318
+ prediction_dataframe.change(
319
+ fn=lambda x: gr.Dataframe(x, visible=True),
320
+ inputs=[prediction_dataframe],
321
+ outputs=[prediction_dataframe],
322
+ )
323
  with gr.Row():
324
+ correlation_type = gr.Radio(
325
+ choices=["Spearman", "Pearson", "R²"],
326
+ value="Spearman",
327
+ label="Correlation Type",
328
+ interactive=True,
329
+ scale=0,
330
  )
331
+ correlation_ranking_plot = gr.Plot(label="Correlation ranking")
332
  with gr.Row():
333
+ with gr.Column(scale=0):
334
+ # User can select the columns to display in the correlation plot
335
+ correlation_column = gr.Dropdown(
336
+ label="Score data to display",
337
+ choices=SCORE_COLUMNS,
338
+ multiselect=False,
339
+ value=SCORE_COLUMNS[0],
340
+ )
341
+ # Add checkbox for log scale and update plot when either input changes
342
+ log_scale = gr.Checkbox(
343
+ label="Display x-axis on logarithmic scale", value=False
344
+ )
345
+ score_description = gr.Markdown(
346
+ get_score_description(correlation_column.value)
347
+ )
348
+ correlation_column.change(
349
+ fn=lambda x: get_score_description(x),
350
+ inputs=correlation_column,
351
+ outputs=score_description,
352
  )
353
  with gr.Column():
354
  regression_plot = gr.Plot(label="Correlation with binding affinity")
355
 
 
384
 
385
  log_scale.change(
386
  fn=update_regression_plot,
387
+ inputs=[correlation_column, log_scale],
388
  outputs=regression_plot,
389
  )
390
 
 
411
  )
412
  api_key = gr.Textbox(label="Folding Studio API Key", type="password")
413
  gr.Markdown("## Demo Usage")
414
+ with gr.Tab("🚀 Basic Folding"):
415
  simple_prediction(api_key)
416
  with gr.Tab("📊 Model Comparison"):
417
  model_comparison(api_key)
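
In the app.py changes above, the two separate gr.State holders (aligned_paths and plddt_fig) are replaced by a single prediction_outputs state, and the predict button chains a .then() call so filter_predictions runs as soon as predict_comparison has filled that state. A minimal, self-contained sketch of this Gradio pattern; the run/filter_items functions and component names are placeholders rather than the demo's real ones:

    import gradio as gr

    def run(n_items: float) -> dict:
        # Stand-in for predict_comparison: build a dict that lives in gr.State.
        return {i: f"item-{i}" for i in range(int(n_items))}

    def filter_items(store: dict | None, selected: list[int]) -> str:
        # Stand-in for filter_predictions: read the state plus the checkbox values.
        if not store:
            return ""
        return "\n".join(store[i] for i in selected if i in store)

    with gr.Blocks() as demo:
        n = gr.Slider(1, 5, value=3, step=1, label="How many items")
        run_btn = gr.Button("Run")
        choices = gr.CheckboxGroup(choices=[0, 1, 2, 3, 4], value=[0, 1, 2], label="Show")
        output = gr.Textbox(label="Filtered output")
        store = gr.State()  # one state object, like prediction_outputs above

        # Populate the state, then immediately filter it, mirroring
        # predict_btn.click(...).then(fn=filter_predictions, ...).
        run_btn.click(fn=run, inputs=[n], outputs=[store]).then(
            fn=filter_items, inputs=[store, choices], outputs=[output]
        )
        # Re-filter whenever a checkbox changes, like the checkbox.change loop.
        choices.change(fn=filter_items, inputs=[store, choices], outputs=[output])

    if __name__ == "__main__":
        demo.launch()
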
folding_studio_demo/correlate.py CHANGED
@@ -1,9 +1,10 @@
1
  import logging
2
- import pandas as pd
3
  from pathlib import Path
 
4
  import numpy as np
 
5
  import plotly.graph_objects as go
6
- from scipy.stats import spearmanr, pearsonr, linregress
7
 
8
  logger = logging.getLogger(__name__)
9
 
@@ -16,7 +17,7 @@ SCORE_COLUMN_NAMES = {
16
  "complex_pde_boltz": "Boltz Complex pDE",
17
  "complex_ipde_boltz": "Boltz Complex ipDE",
18
  "interchain_pae_monomer": "AlphaFold2 GapTrick Interchain PAE",
19
- "interface_pae_monomer": "AlphaFold2 GapTrick Interface PAE",
20
  "overall_pae_monomer": "AlphaFold2 GapTrick Overall PAE",
21
  "interface_plddt_monomer": "AlphaFold2 GapTrick Interface pLDDT",
22
  "average_plddt_monomer": "AlphaFold2 GapTrick Average pLDDT",
@@ -24,15 +25,16 @@ SCORE_COLUMN_NAMES = {
24
  "interface_ptm_monomer": "AlphaFold2 GapTrick Interface pTM",
25
  "interchain_pae_multimer": "AlphaFold2 Multimer Interchain PAE",
26
  "interface_pae_multimer": "AlphaFold2 Multimer Interface PAE",
27
- "overall_pae_multimer": "AlphaFold2 Multimer Overall PAE",
28
  "interface_plddt_multimer": "AlphaFold2 Multimer Interface pLDDT",
29
  "average_plddt_multimer": "AlphaFold2 Multimer Average pLDDT",
30
  "ptm_multimer": "AlphaFold2 Multimer pTM Score",
31
- "interface_ptm_multimer": "AlphaFold2 Multimer Interface pTM"
32
  }
33
 
34
  SCORE_COLUMNS = list(SCORE_COLUMN_NAMES.values())
35
 
 
36
  def get_score_description(score: str) -> str:
37
  descriptions = {
38
  "Boltz Confidence Score": "The Boltz model confidence score provides an overall assessment of prediction quality (0-1, higher is better).",
@@ -49,22 +51,24 @@ def get_score_description(score: str) -> str:
49
  "AlphaFold2 GapTrick Average pLDDT": "The AlphaFold2 GapTrick model average pLDDT provides the mean confidence across all residues in monomeric predictions (0-100, higher is better).",
50
  "AlphaFold2 GapTrick pTM Score": "The AlphaFold2 GapTrick model pTM score assesses overall fold accuracy in monomeric predictions (0-1, higher is better).",
51
  "AlphaFold2 GapTrick Interface pTM": "The AlphaFold2 GapTrick model interface pTM specifically evaluates accuracy of interface regions in monomeric predictions (0-1, higher is better).",
52
- "AlphaFold2 GapTrick Interchain PAE": "The AlphaFold2 GapTrick model interchain PAE estimates position errors between chains in multimeric predictions (lower is better).",
53
- "AlphaFold2 Multimer Interface PAE": "The AlphaFold2 Multimer model interface PAE estimates position errors specifically at interfaces in multimeric predictions (lower is better).",
54
  "AlphaFold2 Multimer Overall PAE": "The AlphaFold2 Multimer model overall PAE estimates position errors across the entire structure in multimeric predictions (lower is better).",
55
  "AlphaFold2 Multimer Interface pLDDT": "The AlphaFold2 Multimer model interface pLDDT measures confidence in interface region predictions for multimeric models (0-100, higher is better).",
56
  "AlphaFold2 Multimer Average pLDDT": "The AlphaFold2 Multimer model average pLDDT provides the mean confidence across all residues in multimeric predictions (0-100, higher is better).",
57
  "AlphaFold2 Multimer pTM Score": "The AlphaFold2 Multimer model pTM score assesses overall fold accuracy in multimeric predictions (0-1, higher is better).",
58
- "AlphaFold2 Multimer Interface pTM": "The AlphaFold2 Multimer model interface pTM specifically evaluates accuracy of interface regions in multimeric predictions (0-1, higher is better)."
59
  }
60
  return descriptions.get(score, "No description available for this score.")
61
 
62
- def compute_correlation_data(spr_data_with_scores: pd.DataFrame, score_cols: list[str]) -> pd.DataFrame:
63
  corr_data_file = Path("corr_data.csv")
64
  if corr_data_file.exists():
65
  logger.info(f"Loading correlation data from {corr_data_file}")
66
  return pd.read_csv(corr_data_file)
67
-
68
  corr_data = []
69
  spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores["KD (nM)"])
70
  kd_col = "KD (nM)"
@@ -74,53 +78,71 @@ def compute_correlation_data(spr_data_with_scores: pd.DataFrame, score_cols: lis
74
  corr_funcs["R²"] = linregress
75
  for correlation_type, corr_func in corr_funcs.items():
76
  for score_col in score_cols:
77
- logger.info(f"Computing {correlation_type} correlation between {score_col} and KD (nM)")
78
- res = corr_func(spr_data_with_scores[kd_col], spr_data_with_scores[score_col])
79
  logger.info(f"Correlation function: {corr_func}")
80
- correlation_value = res.rvalue**2 if correlation_type == "R²" else res.statistic
81
- corr_data.append({
82
- "correlation_type": correlation_type,
83
- "score": score_col,
84
- "correlation": correlation_value,
85
- "p-value": res.pvalue
86
- })
87
- logger.info(f"Correlation {correlation_type} between {score_col} and KD (nM): {correlation_value}")
88
 
89
  corr_data = pd.DataFrame(corr_data)
90
  # Find the lines in corr_data with NaN values and remove them
91
  corr_data = corr_data[corr_data["correlation"].notna()]
92
  # Sort correlation data by correlation value
93
- corr_data = corr_data.sort_values('correlation', ascending=True)
94
-
95
  corr_data.to_csv("corr_data.csv", index=False)
96
-
97
  return corr_data
98
 
99
- def plot_correlation_ranking(corr_data: pd.DataFrame, correlation_type: str) -> go.Figure:
100
  # Create bar plot of correlations
101
  data = corr_data[corr_data["correlation_type"] == correlation_type]
102
- corr_ranking_plot = go.Figure(data=[
103
- go.Bar(
104
- x=data["correlation"],
105
- y=data["score"],
106
- name=correlation_type,
107
- text=data["correlation"],
108
- orientation='h',
109
- hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>"
110
- )
111
- ])
 
 
112
  corr_ranking_plot.update_layout(
113
  title="Correlation with Binding Affinity",
114
  yaxis_title="Score",
115
  xaxis_title=correlation_type,
116
  template="simple_white",
117
- showlegend=False
118
  )
119
  return corr_ranking_plot
120
 
121
- def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]) -> tuple[pd.DataFrame, go.Figure]:
122
  """Fake predict structures of all complexes and correlate the results."""
123
-
124
  corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
125
  corr_ranking_plot = plot_correlation_ranking(corr_data, "Spearman")
126
 
@@ -131,17 +153,20 @@ def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: l
131
 
132
  return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
133
 
134
- def make_regression_plot(spr_data_with_scores: pd.DataFrame, score: str, use_log: bool) -> go.Figure:
135
  """Select the regression plot to display."""
136
  # corr_plot is a scatter plot of the regression between the binding affinity and each of the scores
137
- scatter = go.Scatter(
138
- x=spr_data_with_scores["KD (nM)"],
139
- y=spr_data_with_scores[score],
140
- name=f"Samples",
141
- mode='markers', # Only show markers/dots, no lines
142
- hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
143
- marker=dict(color='#1f77b4') # Set color to match default first color
144
- )
145
  corr_plot = go.Figure(data=scatter)
146
  corr_plot.update_layout(
147
  xaxis_title="KD (nM)",
@@ -154,7 +179,7 @@ def make_regression_plot(spr_data_with_scores: pd.DataFrame, score: str, use_log
154
  xanchor="right",
155
  x=1,
156
  ),
157
- xaxis_type="log" if use_log else "linear" # Set x-axis to logarithmic scale
158
  )
159
  # compute the regression line
160
  if use_log:
@@ -162,23 +187,25 @@ def make_regression_plot(spr_data_with_scores: pd.DataFrame, score: str, use_log
162
  x_vals = np.log10(spr_data_with_scores["KD (nM)"])
163
  else:
164
  x_vals = spr_data_with_scores["KD (nM)"]
165
-
166
  # Fit line to data
167
  corr_line = np.polyfit(x_vals, spr_data_with_scores[score], 1)
168
-
169
  # Generate x points for line
170
  corr_line_x = np.linspace(min(x_vals), max(x_vals), 100)
171
  corr_line_y = corr_line[0] * corr_line_x + corr_line[1]
172
-
173
  # Convert back from log space if needed
174
  if use_log:
175
  corr_line_x = 10**corr_line_x
176
  # add the regression line to the plot
177
- corr_plot.add_trace(go.Scatter(
178
- x=corr_line_x,
179
- y=corr_line_y,
180
- mode='lines',
181
- name=f"Regression line",
182
- line=dict(color='#1f77b4') # Set same color as scatter points
183
- ))
184
- return corr_plot
 
 
 
1
  import logging
 
2
  from pathlib import Path
3
+
4
  import numpy as np
5
+ import pandas as pd
6
  import plotly.graph_objects as go
7
+ from scipy.stats import linregress, pearsonr, spearmanr
8
 
9
  logger = logging.getLogger(__name__)
10
 
 
17
  "complex_pde_boltz": "Boltz Complex pDE",
18
  "complex_ipde_boltz": "Boltz Complex ipDE",
19
  "interchain_pae_monomer": "AlphaFold2 GapTrick Interchain PAE",
20
+ "interface_pae_monomer": "AlphaFold2 GapTrick Interface PAE",
21
  "overall_pae_monomer": "AlphaFold2 GapTrick Overall PAE",
22
  "interface_plddt_monomer": "AlphaFold2 GapTrick Interface pLDDT",
23
  "average_plddt_monomer": "AlphaFold2 GapTrick Average pLDDT",
 
25
  "interface_ptm_monomer": "AlphaFold2 GapTrick Interface pTM",
26
  "interchain_pae_multimer": "AlphaFold2 Multimer Interchain PAE",
27
  "interface_pae_multimer": "AlphaFold2 Multimer Interface PAE",
28
+ "overall_pae_multimer": "AlphaFold2 Multimer Overall PAE",
29
  "interface_plddt_multimer": "AlphaFold2 Multimer Interface pLDDT",
30
  "average_plddt_multimer": "AlphaFold2 Multimer Average pLDDT",
31
  "ptm_multimer": "AlphaFold2 Multimer pTM Score",
32
+ "interface_ptm_multimer": "AlphaFold2 Multimer Interface pTM",
33
  }
34
 
35
  SCORE_COLUMNS = list(SCORE_COLUMN_NAMES.values())
36
 
37
+
38
  def get_score_description(score: str) -> str:
39
  descriptions = {
40
  "Boltz Confidence Score": "The Boltz model confidence score provides an overall assessment of prediction quality (0-1, higher is better).",
 
51
  "AlphaFold2 GapTrick Average pLDDT": "The AlphaFold2 GapTrick model average pLDDT provides the mean confidence across all residues in monomeric predictions (0-100, higher is better).",
52
  "AlphaFold2 GapTrick pTM Score": "The AlphaFold2 GapTrick model pTM score assesses overall fold accuracy in monomeric predictions (0-1, higher is better).",
53
  "AlphaFold2 GapTrick Interface pTM": "The AlphaFold2 GapTrick model interface pTM specifically evaluates accuracy of interface regions in monomeric predictions (0-1, higher is better).",
54
+ "AlphaFold2 Multimer Interface PAE": "The AlphaFold2 Multimer model interface PAE estimates position errors specifically at interfaces in multimeric predictions (lower is better).",
 
55
  "AlphaFold2 Multimer Overall PAE": "The AlphaFold2 Multimer model overall PAE estimates position errors across the entire structure in multimeric predictions (lower is better).",
56
  "AlphaFold2 Multimer Interface pLDDT": "The AlphaFold2 Multimer model interface pLDDT measures confidence in interface region predictions for multimeric models (0-100, higher is better).",
57
  "AlphaFold2 Multimer Average pLDDT": "The AlphaFold2 Multimer model average pLDDT provides the mean confidence across all residues in multimeric predictions (0-100, higher is better).",
58
  "AlphaFold2 Multimer pTM Score": "The AlphaFold2 Multimer model pTM score assesses overall fold accuracy in multimeric predictions (0-1, higher is better).",
59
+ "AlphaFold2 Multimer Interface pTM": "The AlphaFold2 Multimer model interface pTM specifically evaluates accuracy of interface regions in multimeric predictions (0-1, higher is better).",
60
  }
61
  return descriptions.get(score, "No description available for this score.")
62
 
63
+
64
+ def compute_correlation_data(
65
+ spr_data_with_scores: pd.DataFrame, score_cols: list[str]
66
+ ) -> pd.DataFrame:
67
  corr_data_file = Path("corr_data.csv")
68
  if corr_data_file.exists():
69
  logger.info(f"Loading correlation data from {corr_data_file}")
70
  return pd.read_csv(corr_data_file)
71
+
72
  corr_data = []
73
  spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores["KD (nM)"])
74
  kd_col = "KD (nM)"
 
78
  corr_funcs["R²"] = linregress
79
  for correlation_type, corr_func in corr_funcs.items():
80
  for score_col in score_cols:
81
+ logger.info(
82
+ f"Computing {correlation_type} correlation between {score_col} and KD (nM)"
83
+ )
84
+ res = corr_func(
85
+ spr_data_with_scores[kd_col], spr_data_with_scores[score_col]
86
+ )
87
  logger.info(f"Correlation function: {corr_func}")
88
+ correlation_value = (
89
+ res.rvalue**2 if correlation_type == "R²" else res.statistic
90
+ )
91
+ corr_data.append(
92
+ {
93
+ "correlation_type": correlation_type,
94
+ "score": score_col,
95
+ "correlation": correlation_value,
96
+ "p-value": res.pvalue,
97
+ }
98
+ )
99
+ logger.info(
100
+ f"Correlation {correlation_type} between {score_col} and KD (nM): {correlation_value}"
101
+ )
102
 
103
  corr_data = pd.DataFrame(corr_data)
104
  # Find the lines in corr_data with NaN values and remove them
105
  corr_data = corr_data[corr_data["correlation"].notna()]
106
  # Sort correlation data by correlation value
107
+ corr_data = corr_data.sort_values("correlation", ascending=True)
108
+
109
  corr_data.to_csv("corr_data.csv", index=False)
110
+
111
  return corr_data
112
 
113
+
114
+ def plot_correlation_ranking(
115
+ corr_data: pd.DataFrame, correlation_type: str
116
+ ) -> go.Figure:
117
  # Create bar plot of correlations
118
  data = corr_data[corr_data["correlation_type"] == correlation_type]
119
+ corr_ranking_plot = go.Figure(
120
+ data=[
121
+ go.Bar(
122
+ x=data["correlation"],
123
+ y=data["score"],
124
+ name=correlation_type,
125
+ text=data["correlation"],
126
+ orientation="h",
127
+ hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
128
+ )
129
+ ]
130
+ )
131
  corr_ranking_plot.update_layout(
132
  title="Correlation with Binding Affinity",
133
  yaxis_title="Score",
134
  xaxis_title=correlation_type,
135
  template="simple_white",
136
+ showlegend=False,
137
  )
138
  return corr_ranking_plot
139
 
140
+
141
+ def fake_predict_and_correlate(
142
+ spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]
143
+ ) -> tuple[pd.DataFrame, go.Figure]:
144
  """Fake predict structures of all complexes and correlate the results."""
145
+
146
  corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
147
  corr_ranking_plot = plot_correlation_ranking(corr_data, "Spearman")
148
 
 
153
 
154
  return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
155
 
156
+
157
+ def make_regression_plot(
158
+ spr_data_with_scores: pd.DataFrame, score: str, use_log: bool
159
+ ) -> go.Figure:
160
  """Select the regression plot to display."""
161
  # corr_plot is a scatter plot of the regression between the binding affinity and each of the scores
162
+ scatter = go.Scatter(
163
+ x=spr_data_with_scores["KD (nM)"],
164
+ y=spr_data_with_scores[score],
165
+ name=f"Samples",
166
+ mode="markers", # Only show markers/dots, no lines
167
+ hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
168
+ marker=dict(color="#1f77b4"), # Set color to match default first color
169
+ )
170
  corr_plot = go.Figure(data=scatter)
171
  corr_plot.update_layout(
172
  xaxis_title="KD (nM)",
 
179
  xanchor="right",
180
  x=1,
181
  ),
182
+ xaxis_type="log" if use_log else "linear", # Set x-axis to logarithmic scale
183
  )
184
  # compute the regression line
185
  if use_log:
 
187
  x_vals = np.log10(spr_data_with_scores["KD (nM)"])
188
  else:
189
  x_vals = spr_data_with_scores["KD (nM)"]
190
+
191
  # Fit line to data
192
  corr_line = np.polyfit(x_vals, spr_data_with_scores[score], 1)
193
+
194
  # Generate x points for line
195
  corr_line_x = np.linspace(min(x_vals), max(x_vals), 100)
196
  corr_line_y = corr_line[0] * corr_line_x + corr_line[1]
197
+
198
  # Convert back from log space if needed
199
  if use_log:
200
  corr_line_x = 10**corr_line_x
201
  # add the regression line to the plot
202
+ corr_plot.add_trace(
203
+ go.Scatter(
204
+ x=corr_line_x,
205
+ y=corr_line_y,
206
+ mode="lines",
207
+ name=f"Regression line",
208
+ line=dict(color="#1f77b4"), # Set same color as scatter points
209
+ )
210
+ )
211
+ return corr_plot
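
compute_correlation_data treats scipy's linregress like the two correlation functions and squares its rvalue to obtain R², while the Spearman and Pearson values come straight from the result's statistic attribute. A standalone sketch of that dispatch on synthetic data (the kd and score columns are made up for illustration; the attribute names assume scipy >= 1.9):

    import numpy as np
    import pandas as pd
    from scipy.stats import linregress, pearsonr, spearmanr

    rng = np.random.default_rng(0)
    df = pd.DataFrame({"kd": rng.lognormal(size=50)})
    df["score"] = -np.log10(df["kd"]) + rng.normal(scale=0.2, size=50)

    corr_funcs = {"Spearman": spearmanr, "Pearson": pearsonr, "R²": linregress}
    rows = []
    for name, func in corr_funcs.items():
        res = func(df["kd"], df["score"])
        # linregress exposes rvalue/pvalue; spearmanr and pearsonr expose statistic/pvalue
        value = res.rvalue**2 if name == "R²" else res.statistic
        rows.append({"correlation_type": name, "correlation": value, "p-value": res.pvalue})

    print(pd.DataFrame(rows).sort_values("correlation", ascending=True))
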
folding_studio_demo/model_fasta_validators.py CHANGED
@@ -248,15 +248,15 @@ class ChaiFastaValidator(BaseFastaValidator):
  )
  seen_names.add(name)
  # validate sequence format
- # sequence = str(record.seq).strip()
- # if (
- # entity_type in {EntityType.PEPTIDE, EntityType.PROTEIN}
- # and not get_entity_type(sequence) == entity_type
- # ):
- # return (
- # False,
- # f"CHAI Validation Error: Sequence type mismatch. Expected '{entity_type}' but found '{get_entity_type(sequence)}'",
- # )

  return True, None
 
  )
  seen_names.add(name)
  # validate sequence format
+ sequence = str(record.seq).strip()
+ if (
+ entity_type in {EntityType.PEPTIDE, EntityType.PROTEIN}
+ and not get_entity_type(sequence) == entity_type
+ ):
+ return (
+ False,
+ f"CHAI Validation Error: Sequence type mismatch. Expected '{entity_type}' but found '{get_entity_type(sequence)}'",
+ )

  return True, None
 
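The block re-enabled above makes the Chai validator check each FASTA record's declared entity type against what the sequence itself looks like, via the demo's get_entity_type and EntityType helpers. Those helpers are not reproduced here, so the following rough approximation substitutes a simple amino-acid-alphabet heuristic (looks_like_protein and PROTEIN_ONLY are stand-ins, not the demo's logic):

    from io import StringIO

    from Bio import SeqIO

    PROTEIN_ONLY = set("EFILPQ")  # residues that never appear in nucleotide alphabets

    def looks_like_protein(sequence: str) -> bool:
        # Heuristic stand-in for get_entity_type(sequence) == EntityType.PROTEIN.
        return any(char in PROTEIN_ONLY for char in sequence.upper())

    def validate(fasta_text: str) -> tuple[bool, str | None]:
        for record in SeqIO.parse(StringIO(fasta_text), "fasta"):
            declared = record.description.split("|")[1] if "|" in record.description else "protein"
            sequence = str(record.seq).strip()
            if declared == "protein" and not looks_like_protein(sequence):
                return False, f"Sequence type mismatch for record '{record.id}'"
        return True, None

    print(validate(">A|protein\nMALWMRLLPLLALLALWGPDPAAA"))  # (True, None)
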
folding_studio_demo/models.py CHANGED
@@ -1,17 +1,26 @@
1
  """Models for the Folding Studio API."""
2
 
 
3
  import logging
4
  import os
5
  from pathlib import Path
6
  from typing import Any
7
 
8
  import gradio as gr
9
  import numpy as np
 
10
  from folding_studio.client import Client
 
 
11
  from folding_studio.query import Query
12
  from folding_studio.query.boltz import BoltzQuery
13
  from folding_studio.query.chai import ChaiQuery
14
  from folding_studio.query.protenix import ProtenixQuery
 
 
15
 
16
  from folding_studio_demo.model_fasta_validators import (
17
  BaseFastaValidator,
@@ -20,15 +29,29 @@ from folding_studio_demo.model_fasta_validators import (
20
  ProtenixFastaValidator,
21
  )
22
 
23
  logger = logging.getLogger(__name__)
24
 
25
 
26
  class AF3Model:
27
- def __init__(
28
- self, api_key: str, model_name: str, query: Query, validator: BaseFastaValidator
29
- ):
30
  self.api_key = api_key
31
- self.model_name = model_name
32
  self.query = query
33
  self.validator = validator
34
 
@@ -116,8 +139,10 @@ class AF3Model:
116
 
117
 
118
  class ChaiModel(AF3Model):
 
 
119
  def __init__(self, api_key: str):
120
- super().__init__(api_key, "Chai", ChaiQuery, ChaiFastaValidator())
121
 
122
  def call(
123
  self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
@@ -158,8 +183,10 @@ class ChaiModel(AF3Model):
158
 
159
 
160
  class ProtenixModel(AF3Model):
 
 
161
  def __init__(self, api_key: str):
162
- super().__init__(api_key, "Protenix", ProtenixQuery, ProtenixFastaValidator())
163
 
164
  def call(
165
  self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
@@ -179,8 +206,10 @@ class ProtenixModel(AF3Model):
179
 
180
 
181
  class BoltzModel(AF3Model):
 
 
182
  def __init__(self, api_key: str):
183
- super().__init__(api_key, "Boltz", BoltzQuery, BoltzFastaValidator())
184
 
185
  def call(
186
  self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
@@ -205,3 +234,113 @@ class BoltzModel(AF3Model):
205
  }
206
  for cif_path in prediction_paths
207
  }
1
  """Models for the Folding Studio API."""
2
 
3
+ import json
4
  import logging
5
  import os
6
+ import sys
7
+ import time
8
+ from io import StringIO
9
  from pathlib import Path
10
  from typing import Any
11
 
12
  import gradio as gr
13
  import numpy as np
14
+ from folding_studio import single_job_prediction
15
  from folding_studio.client import Client
16
+ from folding_studio.commands.experiment import results as get_results
17
+ from folding_studio.commands.experiment import status as get_status
18
  from folding_studio.query import Query
19
  from folding_studio.query.boltz import BoltzQuery
20
  from folding_studio.query.chai import ChaiQuery
21
  from folding_studio.query.protenix import ProtenixQuery
22
+ from folding_studio_data_models import AF2Parameters, OpenFoldParameters
23
+ from folding_studio_data_models.parameters.base import BaseFoldingParameters
24
 
25
  from folding_studio_demo.model_fasta_validators import (
26
  BaseFastaValidator,
 
29
  ProtenixFastaValidator,
30
  )
31
 
32
+
33
+ class Capturing(list):
34
+ """Capture stdout output."""
35
+
36
+ def __enter__(self):
37
+ self._stdout = sys.stdout
38
+ sys.stdout = self._stringio = StringIO()
39
+ return self
40
+
41
+ def __exit__(self, *args):
42
+ self.extend(self._stringio.getvalue().splitlines())
43
+ del self._stringio # free up some memory
44
+ sys.stdout = self._stdout
45
+
46
+
47
  logger = logging.getLogger(__name__)
48
 
49
 
50
  class AF3Model:
51
+ model_name = None
52
+
53
+ def __init__(self, api_key: str, query: Query, validator: BaseFastaValidator):
54
  self.api_key = api_key
 
55
  self.query = query
56
  self.validator = validator
57
 
 
139
 
140
 
141
  class ChaiModel(AF3Model):
142
+ model_name = "Chai"
143
+
144
  def __init__(self, api_key: str):
145
+ super().__init__(api_key, ChaiQuery, ChaiFastaValidator())
146
 
147
  def call(
148
  self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
 
183
 
184
 
185
  class ProtenixModel(AF3Model):
186
+ model_name = "Protenix"
187
+
188
  def __init__(self, api_key: str):
189
+ super().__init__(api_key, ProtenixQuery, ProtenixFastaValidator())
190
 
191
  def call(
192
  self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
 
206
 
207
 
208
  class BoltzModel(AF3Model):
209
+ model_name = "Boltz"
210
+
211
  def __init__(self, api_key: str):
212
+ super().__init__(api_key, BoltzQuery, BoltzFastaValidator())
213
 
214
  def call(
215
  self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
 
234
  }
235
  for cif_path in prediction_paths
236
  }
237
+
238
+
239
+ class OldModel:
240
+ model_name = None
241
+
242
+ def __init__(self, api_key: str):
243
+ self.api_key = api_key
244
+
245
+ def call(
246
+ self,
247
+ seq_file: Path | str,
248
+ output_dir: Path,
249
+ parameters: BaseFoldingParameters,
250
+ *args,
251
+ **kwargs,
252
+ ) -> None:
253
+ """Predict protein structure from amino acid sequence using AF2 model.
254
+
255
+ Args:
256
+ seq_file (Path | str): Path to FASTA file containing amino acid sequence
257
+ output_dir (Path): Path to output directory
258
+ """
259
+ output = single_job_prediction(
260
+ fasta_file=seq_file,
261
+ parameters=parameters,
262
+ )
263
+ experiment_id = output["message"]["experiment_id"]
264
+ done = False
265
+ while not done:
266
+ with Capturing() as output:
267
+ get_status(experiment_id)
268
+ status = output[0]
269
+ logger.info(f"Experiment {experiment_id} status: {status}")
270
+ if status == "Done":
271
+ done = True
272
+ logger.info("Downloading results")
273
+ get_results(
274
+ experiment_id,
275
+ force=True,
276
+ unzip=True,
277
+ output=output_dir / "results.zip",
278
+ )
279
+ logger.info("Results downloaded to %s", output_dir)
280
+ else:
281
+ logger.info("Sleeping for 10 seconds")
282
+ time.sleep(10)
283
+
284
+ def format_fasta(self, seq_file: Path | str) -> None:
285
+ """Format sequence to FASTA format.
286
+
287
+ Args:
288
+ seq_file (Path | str): Path to FASTA file
289
+ """
290
+ return
291
+
292
+ def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
293
+ """Get the path to the prediction.
294
+
295
+ Args:
296
+ output_dir (Path): Path to output directory
297
+
298
+ Returns:
299
+ dict[int, dict[str, Any]]: Dictionary mapping model indices to their prediction paths and metrics
300
+ """
301
+ prediction_paths = list(
302
+ (output_dir / "results").rglob("relaxed_model_[0-9]_ptm_pred_0.pdb")
303
+ )
304
+ metrics_path = output_dir / "results" / "metrics_per_model.json"
305
+ if not metrics_path.exists():
306
+ return {}
307
+ with open(metrics_path, "r") as f:
308
+ metrics = json.load(f)
309
+ output = {
310
+ int(pred_path.stem.split("_")[2]): {
311
+ "prediction_path": pred_path,
312
+ "metrics": metrics[f"model_{int(pred_path.stem.split('_')[2])}_ptm"],
313
+ }
314
+ for pred_path in prediction_paths
315
+ }
316
+ return output
317
+
318
+ def has_prediction(self, output_dir: Path) -> bool:
319
+ """Check if prediction exists in output directory."""
320
+ return len(self.predictions(output_dir)) > 0
321
+
322
+ def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
323
+ """Check if the file description is correct.
324
+
325
+ Args:
326
+ seq_file (Path | str): Path to FASTA file
327
+
328
+ Returns:
329
+ tuple[bool, str | None]: Tuple containing a boolean indicating if the format is correct and an error message if not
330
+ """
331
+
332
+ return True, None
333
+
334
+
335
+ class AF2Model(OldModel):
336
+ model_name = "AlphaFold2"
337
+
338
+ def call(self, seq_file: Path | str, output_dir: Path, *args, **kwargs) -> None:
339
+ super().call(seq_file, output_dir, AF2Parameters(), *args, **kwargs)
340
+
341
+
342
+ class OpenFoldModel(OldModel):
343
+ model_name = "OpenFold"
344
+
345
+ def call(self, seq_file: Path | str, output_dir: Path, *args, **kwargs) -> None:
346
+ super().call(seq_file, output_dir, OpenFoldParameters(), *args, **kwargs)
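
OldModel.call above polls the legacy experiment endpoints by wrapping the CLI-style status command in the new Capturing helper, because that command prints its result to stdout instead of returning it. A stripped-down version of the same capture-and-poll loop, with fake_status standing in for folding_studio.commands.experiment.status:

    import sys
    import time
    from io import StringIO

    class Capturing(list):
        """Capture anything printed to stdout inside the with-block."""

        def __enter__(self):
            self._stdout = sys.stdout
            sys.stdout = self._stringio = StringIO()
            return self

        def __exit__(self, *args):
            self.extend(self._stringio.getvalue().splitlines())
            del self._stringio  # free up some memory
            sys.stdout = self._stdout

    _calls = 0

    def fake_status(experiment_id: str) -> None:
        # Stand-in for get_status(experiment_id), which prints rather than returns.
        global _calls
        _calls += 1
        print("Done" if _calls >= 3 else "Running")

    def wait_until_done(experiment_id: str, poll_seconds: float = 0.1) -> None:
        while True:
            with Capturing() as output:
                fake_status(experiment_id)
            status = output[0]
            print(f"Experiment {experiment_id} status: {status}")
            if status == "Done":
                return
            time.sleep(poll_seconds)

    wait_until_done("exp-123")
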
folding_studio_demo/predict.py CHANGED
@@ -1,9 +1,11 @@
1
  """Predict protein structure using Folding Studio."""
2
 
 
3
  import hashlib
4
  import logging
5
  from io import StringIO
6
  from pathlib import Path
 
7
 
8
  import gradio as gr
9
  import numpy as np
@@ -12,7 +14,13 @@ from Bio import SeqIO
12
  from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
13
  from folding_studio_data_models import FoldingModel
14
 
15
- from folding_studio_demo.models import BoltzModel, ChaiModel, ProtenixModel
16
 
17
  logger = logging.getLogger(__name__)
18
 
@@ -85,20 +93,22 @@ def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
85
  def create_plddt_figure(
86
  plddt_vals: list[list[float]],
87
  model_name: str,
 
88
  residue_codes: list[list[str]] = None,
89
  ) -> go.Figure:
90
  """Create a plot of metrics."""
91
  plddt_traces = []
92
- for i, plddt_val in enumerate(plddt_vals):
 
93
  # Create hover text with residue codes if available
94
  if residue_codes and i < len(residue_codes):
95
  hover_text = [
96
- f"<i>pLDDT</i>: {plddt:.2f}<br><i>Residue:</i> {code} {idx}"
97
  for idx, (plddt, code) in enumerate(zip(plddt_val, residue_codes[i]))
98
  ]
99
  else:
100
  hover_text = [
101
- f"<i>pLDDT</i>: {plddt:.2f}<br><i>Residue index:</i> {idx}"
102
  for idx, plddt in enumerate(plddt_val)
103
  ]
104
 
@@ -108,7 +118,7 @@ def create_plddt_figure(
108
  y=plddt_val,
109
  hovertemplate="%{text}<extra></extra>",
110
  text=hover_text,
111
- name=f"{model_name} {i}",
112
  visible=True,
113
  )
114
  )
@@ -150,8 +160,19 @@ def _write_fasta_file(
150
  return seq_id, seq_file
151
 
152
 
153
- def extract_plddt_from_cif(cif_path):
154
155
 
156
  # Lists to store pLDDT values and residue codes
157
  plddt_values = []
@@ -206,6 +227,10 @@ def predict(
206
  model = ChaiModel(api_key)
207
  elif model_type == FoldingModel.PROTENIX:
208
  model = ProtenixModel(api_key)
209
  else:
210
  raise ValueError(f"Model {model_type} not supported")
211
 
@@ -235,22 +260,36 @@ def predict(
235
  progress(
236
  0.4 + (0.4 * i / total_predictions), desc=f"Converting model {model_idx}..."
237
  )
238
- cif_path = prediction["prediction_path"]
239
- logger.info(f"CIF file: {cif_path}")
240
-
241
- converted_pdb_path = str(
242
- output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
243
- )
244
- convert_cif_to_pdb(str(cif_path), str(converted_pdb_path))
245
- plddt_vals, residue_codes = extract_plddt_from_cif(cif_path)
246
- pdb_paths.append(converted_pdb_path)
 
 
247
  model_plddt_vals.append(plddt_vals)
248
  model_residue_codes.append(residue_codes)
249
 
250
  progress(0.8, desc="Generating plots...")
 
251
  plddt_fig = create_plddt_figure(
252
  plddt_vals=model_plddt_vals,
253
  model_name=model.model_name,
 
254
  residue_codes=model_residue_codes,
255
  )
256
 
@@ -258,11 +297,13 @@ def predict(
258
  return pdb_paths, plddt_fig
259
 
260
 
261
- def align_structures(pdb_paths: list[str]) -> list[str]:
 
 
262
  """Align multiple PDB structures to the first structure.
263
 
264
  Args:
265
- pdb_paths (list[str]): List of paths to PDB files to align
266
 
267
  Returns:
268
  list[str]: List of paths to aligned PDB files
@@ -271,39 +312,47 @@ def align_structures(pdb_paths: list[str]) -> list[str]:
271
  parser = PDBParser()
272
  io = PDBIO()
273
 
274
- # Parse the reference structure (first one)
275
- ref_structure = parser.get_structure("reference", pdb_paths[0])
276
  ref_atoms = [atom for atom in ref_structure.get_atoms() if atom.get_name() == "CA"]
277
 
278
- aligned_paths = [pdb_paths[0]] # First structure is already aligned
279
 
280
- # Align each subsequent structure to the reference
281
- for i, pdb_path in enumerate(pdb_paths[1:], start=1):
282
- # Parse the structure to align
283
- structure = parser.get_structure(f"model_{i}", pdb_path)
284
- atoms = [atom for atom in structure.get_atoms() if atom.get_name() == "CA"]
285
 
286
- # Create superimposer
287
- sup = Superimposer()
288
 
289
- # Set the reference and moving atoms
290
- sup.set_atoms(ref_atoms, atoms)
291
 
292
- # Apply the transformation to all atoms in the structure
293
- sup.apply(structure.get_atoms())
 
 
294
 
295
- # Save the aligned structure
296
- aligned_path = str(Path(pdb_path).parent / f"aligned_{Path(pdb_path).name}")
297
- io.set_structure(structure)
298
- io.save(aligned_path)
299
- aligned_paths.append(aligned_path)
300
 
301
- return aligned_paths
302
 
303
 
304
  def filter_predictions(
305
- aligned_paths: list[str],
306
- plddt_fig: go.Figure,
 
 
307
  chai_selected: list[int],
308
  boltz_selected: list[int],
309
  protenix_selected: list[int],
@@ -316,7 +365,7 @@ def filter_predictions(
316
  chai_selected (list[int]): Selected Chai model indices
317
  boltz_selected (list[int]): Selected Boltz model indices
318
  protenix_selected (list[int]): Selected Protenix model indices
319
- model_predictions (dict[FoldingModel, list[int]]): Dictionary mapping models to their prediction indices
320
 
321
  Returns:
322
  tuple[list[str], go.Figure]: Filtered PDB paths and updated pLDDT plot
@@ -325,26 +374,30 @@ def filter_predictions(
325
  filtered_fig = go.Figure()
326
 
327
  # Keep track of which traces to show
328
- visible_paths = []
329
 
330
  # Helper function to check if a trace should be visible
331
- def should_show_trace(trace_name: str) -> bool:
332
- model_name = trace_name.split()[0]
333
- model_idx = int(trace_name.split()[1])
334
-
335
- if model_name == "Chai" and model_idx in chai_selected:
336
  return True
337
- if model_name == "Boltz" and model_idx in boltz_selected:
338
  return True
339
- if model_name == "Protenix" and model_idx in protenix_selected:
340
  return True
341
  return False
342
 
343
  # Filter traces and paths
344
- for i, trace in enumerate(plddt_fig.data):
345
- if should_show_trace(trace.name):
346
- filtered_fig.add_trace(trace)
347
- visible_paths.append(aligned_paths[i])
 
348
 
349
  # Update layout
350
  filtered_fig.update_layout(
@@ -355,21 +408,58 @@ def filter_predictions(
355
  template="simple_white",
356
  legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
357
  )
 
358
 
359
- return visible_paths, filtered_fig
 
360
 
361
 
362
  def predict_comparison(
363
  sequence: str, api_key: str, model_types: list[FoldingModel], progress=gr.Progress()
364
  ) -> tuple[
365
- list[str],
366
- go.Figure,
 
 
367
  gr.CheckboxGroup,
368
  gr.CheckboxGroup,
369
  gr.CheckboxGroup,
370
- list[str],
371
- go.Figure,
372
- dict,
373
  ]:
374
  """Predict protein structure from amino acid sequence using multiple models.
375
 
@@ -381,68 +471,94 @@ def predict_comparison(
381
 
382
  Returns:
383
  tuple containing:
384
- - list[str]: Aligned PDB paths
385
- - go.Figure: pLDDT plot
 
 
386
  - gr.CheckboxGroup: Chai predictions checkbox group
387
  - gr.CheckboxGroup: Boltz predictions checkbox group
388
  - gr.CheckboxGroup: Protenix predictions checkbox group
389
- - list[str]: Original PDB paths
390
- - go.Figure: Original pLDDT plot
391
- - dict: Model predictions mapping
392
  """
393
  if not api_key:
394
  raise gr.Error("Missing API key, please enter a valid API key")
395
 
396
- # Set up unique output directory based on sequence hash
397
- pdb_paths = []
398
- plddt_traces = []
399
- total_models = len(model_types)
400
  model_predictions = {}
401
 
402
- for i, model_type in enumerate(model_types):
403
- progress(i / total_models, desc=f"Running {model_type} prediction...")
404
- model_pdb_paths, model_plddt_traces = predict(
405
- sequence, api_key, model_type, format_fasta=True
406
- )
407
- pdb_paths += model_pdb_paths
408
- plddt_traces += model_plddt_traces.data
409
- model_predictions[model_type] = [int(Path(p).stem[-1]) for p in model_pdb_paths]
410
 
411
  progress(0.9, desc="Aligning structures...")
412
- aligned_paths = align_structures(pdb_paths)
413
- plddt_fig = go.Figure(data=plddt_traces)
414
- plddt_fig.update_layout(
415
- title="pLDDT",
416
- xaxis_title="Residue index",
417
- yaxis_title="pLDDT",
418
- height=500,
419
- template="simple_white",
420
- legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
421
- )
422
 
423
  progress(1.0, desc="Done!")
424
 
425
  # Create checkbox groups for each model type
426
  chai_predictions = gr.CheckboxGroup(
427
  visible=model_predictions.get(FoldingModel.CHAI) is not None,
428
- choices=model_predictions.get(FoldingModel.CHAI, []),
429
- value=model_predictions.get(FoldingModel.CHAI, []),
430
  )
431
  boltz_predictions = gr.CheckboxGroup(
432
  visible=model_predictions.get(FoldingModel.BOLTZ) is not None,
433
- choices=model_predictions.get(FoldingModel.BOLTZ, []),
434
- value=model_predictions.get(FoldingModel.BOLTZ, []),
435
  )
436
  protenix_predictions = gr.CheckboxGroup(
437
  visible=model_predictions.get(FoldingModel.PROTENIX) is not None,
438
- choices=model_predictions.get(FoldingModel.PROTENIX, []),
439
- value=model_predictions.get(FoldingModel.PROTENIX, []),
440
  )
441
 
442
  return (
443
  chai_predictions,
444
  boltz_predictions,
445
  protenix_predictions,
446
- aligned_paths,
447
- plddt_fig,
448
  )
 
1
  """Predict protein structure using Folding Studio."""
2
 
3
+ import concurrent.futures
4
  import hashlib
5
  import logging
6
  from io import StringIO
7
  from pathlib import Path
8
+ from typing import Any
9
 
10
  import gradio as gr
11
  import numpy as np
 
14
  from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
15
  from folding_studio_data_models import FoldingModel
16
 
17
+ from folding_studio_demo.models import (
18
+ AF2Model,
19
+ BoltzModel,
20
+ ChaiModel,
21
+ OpenFoldModel,
22
+ ProtenixModel,
23
+ )
24
 
25
  logger = logging.getLogger(__name__)
26
 
 
93
  def create_plddt_figure(
94
  plddt_vals: list[list[float]],
95
  model_name: str,
96
+ indexes: list[int],
97
  residue_codes: list[list[str]] = None,
98
  ) -> go.Figure:
99
  """Create a plot of metrics."""
100
  plddt_traces = []
101
+
102
+ for i, (plddt_val, index) in enumerate(zip(plddt_vals, indexes)):
103
  # Create hover text with residue codes if available
104
  if residue_codes and i < len(residue_codes):
105
  hover_text = [
106
+ f"<i>{model_name} {index}</i><br><i>pLDDT</i>: {plddt:.2f}<br><i>Residue:</i> {code} {idx}"
107
  for idx, (plddt, code) in enumerate(zip(plddt_val, residue_codes[i]))
108
  ]
109
  else:
110
  hover_text = [
111
+ f"<i>{model_name} {index}</i><br><i>pLDDT</i>: {plddt:.2f}<br><i>Residue index:</i> {idx}"
112
  for idx, plddt in enumerate(plddt_val)
113
  ]
114
 
 
118
  y=plddt_val,
119
  hovertemplate="%{text}<extra></extra>",
120
  text=hover_text,
121
+ name=f"{model_name} {index}",
122
  visible=True,
123
  )
124
  )
 
160
  return seq_id, seq_file
161
 
162
 
163
+ def extract_plddt_from_structure(structure_path: str) -> tuple[list[float], list[str]]:
164
+ """Extract pLDDT values and residue codes from a structure file.
165
+
166
+ Args:
167
+ structure_path (Path): Path to structure file
168
+
169
+ Returns:
170
+ tuple[list[float], list[str]]: Tuple containing lists of pLDDT values and residue codes
171
+ """
172
+ if Path(structure_path).suffix == ".cif":
173
+ structure = MMCIFParser().get_structure("structure", structure_path)
174
+ else:
175
+ structure = PDBParser().get_structure("structure", structure_path)
176
 
177
  # Lists to store pLDDT values and residue codes
178
  plddt_values = []
 
227
  model = ChaiModel(api_key)
228
  elif model_type == FoldingModel.PROTENIX:
229
  model = ProtenixModel(api_key)
230
+ elif model_type == FoldingModel.AF2:
231
+ model = AF2Model(api_key)
232
+ elif model_type == FoldingModel.OPENFOLD:
233
+ model = OpenFoldModel(api_key)
234
  else:
235
  raise ValueError(f"Model {model_type} not supported")
236
 
 
260
  progress(
261
  0.4 + (0.4 * i / total_predictions), desc=f"Converting model {model_idx}..."
262
  )
263
+ prediction_path = prediction["prediction_path"]
264
+ logger.info(f"Prediction file: {prediction_path}")
265
+ if Path(prediction_path).suffix == ".cif":
266
+ converted_pdb_path = str(
267
+ output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
268
+ )
269
+ convert_cif_to_pdb(str(prediction_path), str(converted_pdb_path))
270
+ pdb_paths.append(converted_pdb_path)
271
+ else:
272
+ pdb_paths.append(str(prediction_path))
273
+ plddt_vals, residue_codes = extract_plddt_from_structure(prediction_path)
274
  model_plddt_vals.append(plddt_vals)
275
  model_residue_codes.append(residue_codes)
276
 
277
  progress(0.8, desc="Generating plots...")
278
+ indexes = []
279
+ for pdb_path in pdb_paths:
280
+ if model_type in [
281
+ FoldingModel.AF2,
282
+ FoldingModel.OPENFOLD,
283
+ FoldingModel.SOLOSEQ,
284
+ ]:
285
+ indexes.append(int(Path(pdb_path).stem.split("_")[2]))
286
+ else:
287
+ indexes.append(int(Path(pdb_path).stem[-1]))
288
+
289
  plddt_fig = create_plddt_figure(
290
  plddt_vals=model_plddt_vals,
291
  model_name=model.model_name,
292
+ indexes=indexes,
293
  residue_codes=model_residue_codes,
294
  )
295
 
 
297
  return pdb_paths, plddt_fig
298
 
299
 
300
+ def align_structures(
301
+ model_predictions: dict[FoldingModel, dict[int, dict[str, Any]]],
302
+ ) -> list[str]:
303
  """Align multiple PDB structures to the first structure.
304
 
305
  Args:
306
+ model_predictions (dict[FoldingModel, dict[int, dict[str, Any]]]): Dictionary mapping models to their prediction indices
307
 
308
  Returns:
309
  list[str]: List of paths to aligned PDB files
 
312
  parser = PDBParser()
313
  io = PDBIO()
314
 
315
+ # Get the first structure as reference
316
+ first_model = next(iter(model_predictions.keys()))
317
+ first_pred = next(iter(model_predictions[first_model].values()))
318
+ ref_pdb_path = first_pred["pdb_path"]
319
+
320
+ # Parse reference structure and get CA atoms
321
+ ref_structure = parser.get_structure("reference", ref_pdb_path)
322
  ref_atoms = [atom for atom in ref_structure.get_atoms() if atom.get_name() == "CA"]
323
 
324
+ for model_type in model_predictions.keys():
325
+ for index, prediction in model_predictions[model_type].items():
326
+ pdb_path = prediction["pdb_path"]
327
+
328
+ # Parse the structure to align
329
+ structure = parser.get_structure(f"{model_type}_{index}", pdb_path)
330
+ atoms = [atom for atom in structure.get_atoms() if atom.get_name() == "CA"]
331
 
332
+ # Create superimposer
333
+ sup = Superimposer()
334
 
335
+ # Set the reference and moving atoms
336
+ sup.set_atoms(ref_atoms, atoms)
337
 
338
+ # Apply the transformation to all atoms in the structure
339
+ sup.apply(structure.get_atoms())
340
 
341
+ # Save the aligned structure
342
+ aligned_path = str(Path(pdb_path).parent / f"aligned_{Path(pdb_path).name}")
343
+ io.set_structure(structure)
344
+ io.save(aligned_path)
345
 
346
+ model_predictions[model_type][index]["pdb_path"] = aligned_path
347
 
348
+ return model_predictions
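
align_structures now walks the whole model_predictions mapping and superimposes every structure onto the first prediction using Cα atoms only. A minimal Biopython sketch of that superposition step for a single pair of PDB files (the function name and file paths are illustrative, and, like the code above, it assumes both structures expose the same number of Cα atoms):

    from pathlib import Path

    from Bio.PDB import PDBIO, PDBParser, Superimposer

    def align_to_reference(ref_pdb: str, mobile_pdb: str) -> str:
        parser = PDBParser(QUIET=True)
        io = PDBIO()

        ref_structure = parser.get_structure("reference", ref_pdb)
        mobile_structure = parser.get_structure("mobile", mobile_pdb)

        # Superimpose on C-alpha atoms only, as align_structures does.
        ref_atoms = [a for a in ref_structure.get_atoms() if a.get_name() == "CA"]
        mobile_atoms = [a for a in mobile_structure.get_atoms() if a.get_name() == "CA"]

        sup = Superimposer()
        sup.set_atoms(ref_atoms, mobile_atoms)
        sup.apply(mobile_structure.get_atoms())
        print(f"RMSD after superposition: {sup.rms:.2f}")

        # Save next to the input, mirroring the aligned_ prefix used above.
        aligned_path = str(Path(mobile_pdb).parent / f"aligned_{Path(mobile_pdb).name}")
        io.set_structure(mobile_structure)
        io.save(aligned_path)
        return aligned_path
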
349
 
350
 
351
  def filter_predictions(
352
+ model_predictions: dict[FoldingModel, dict[int, dict[str, Any]]],
353
+ af2_selected: list[int],
354
+ openfold_selected: list[int],
355
+ solo_selected: list[int],
356
  chai_selected: list[int],
357
  boltz_selected: list[int],
358
  protenix_selected: list[int],
 
365
  chai_selected (list[int]): Selected Chai model indices
366
  boltz_selected (list[int]): Selected Boltz model indices
367
  protenix_selected (list[int]): Selected Protenix model indices
368
+ model_predictions (dict[FoldingModel, dict[int, dict[str, Any]]]): Dictionary mapping models to their prediction indices
369
 
370
  Returns:
371
  tuple[list[str], go.Figure]: Filtered PDB paths and updated pLDDT plot
 
374
  filtered_fig = go.Figure()
375
 
376
  # Keep track of which traces to show
377
+ filtered_paths = []
378
 
379
  # Helper function to check if a trace should be visible
380
+ def should_show_trace(model_name, pred_index: int) -> bool:
381
+ if model_name == FoldingModel.CHAI and pred_index in chai_selected:
382
+ return True
383
+ if model_name == FoldingModel.BOLTZ and pred_index in boltz_selected:
384
+ return True
385
+ if model_name == FoldingModel.PROTENIX and pred_index in protenix_selected:
386
+ return True
387
+ if model_name == FoldingModel.AF2 and pred_index in af2_selected:
388
  return True
389
+ if model_name == FoldingModel.OPENFOLD and pred_index in openfold_selected:
390
  return True
391
+ if model_name == FoldingModel.SOLOSEQ and pred_index in solo_selected:
392
  return True
393
  return False
394
 
395
  # Filter traces and paths
396
+ for model_type in model_predictions.keys():
397
+ for index, prediction in model_predictions[model_type].items():
398
+ if should_show_trace(model_type, index):
399
+ filtered_fig.add_trace(prediction["plddt_trace"])
400
+ filtered_paths.append(prediction["pdb_path"])
401
 
402
  # Update layout
403
  filtered_fig.update_layout(
 
408
  template="simple_white",
409
  legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
410
  )
411
+ return filtered_paths, filtered_fig
412
 
413
+
414
+ def run_prediction(
415
+ sequence: str,
416
+ api_key: str,
417
+ model_type: FoldingModel,
418
+ format_fasta: bool = False,
419
+ ) -> dict[FoldingModel, dict[int, dict[str, Any]]]:
420
+ """Run a single prediction.
421
+
422
+ Args:
423
+ sequence (str): Amino acid sequence to predict structure for
424
+ api_key (str): Folding API key
425
+ model_type (FoldingModel): Folding model to use
426
+ format_fasta (bool): Whether to format the FASTA file
427
+
428
+ Returns:
429
+ Tuple containing:
430
+ - List of PDB paths
431
+ - pLDDT plot
432
+ - Dictionary mapping model to prediction indices
433
+ """
434
+ model_pdb_paths, model_plddt_traces = predict(
435
+ sequence, api_key, model_type, format_fasta=format_fasta
436
+ )
437
+ model_pdb_paths = sorted(model_pdb_paths)
438
+ model_predictions = {}
439
+ for pdb_path, plddt_trace in zip(model_pdb_paths, model_plddt_traces.data):
440
+ if model_type in [
441
+ FoldingModel.AF2,
442
+ FoldingModel.OPENFOLD,
443
+ FoldingModel.SOLOSEQ,
444
+ ]:
445
+ index = int(Path(pdb_path).stem.split("_")[2])
446
+ else:
447
+ index = int(Path(pdb_path).stem[-1])
448
+
449
+ model_predictions[index] = {"pdb_path": pdb_path, "plddt_trace": plddt_trace}
450
+ return model_predictions
451
 
452
 
453
  def predict_comparison(
454
  sequence: str, api_key: str, model_types: list[FoldingModel], progress=gr.Progress()
455
  ) -> tuple[
456
+ dict[FoldingModel, dict[int, dict[str, Any]]],
457
+ gr.CheckboxGroup,
458
+ gr.CheckboxGroup,
459
+ gr.CheckboxGroup,
460
  gr.CheckboxGroup,
461
  gr.CheckboxGroup,
462
  gr.CheckboxGroup,
463
  ]:
464
  """Predict protein structure from amino acid sequence using multiple models.
465
 
 
471
 
472
  Returns:
473
  tuple containing:
474
+ - dict[FoldingModel, dict[int, dict[str, Any]]]: Model predictions mapping
475
+ - gr.CheckboxGroup: AF2 predictions checkbox group
476
+ - gr.CheckboxGroup: OpenFold predictions checkbox group
477
+ - gr.CheckboxGroup: SoloSeq predictions checkbox group
478
  - gr.CheckboxGroup: Chai predictions checkbox group
479
  - gr.CheckboxGroup: Boltz predictions checkbox group
480
  - gr.CheckboxGroup: Protenix predictions checkbox group
 
  """
482
  if not api_key:
483
  raise gr.Error("Missing API key, please enter a valid API key")
484
 
485
+ progress(0, desc="Starting parallel predictions...")
486
+
487
+ # Run predictions in parallel
 
488
  model_predictions = {}
489
 
490
+ with concurrent.futures.ThreadPoolExecutor() as executor:
491
+ # Create a future for each model prediction
492
+ future_to_model = {
493
+ executor.submit(
494
+ run_prediction, sequence, api_key, model_type, True
495
+ ): model_type
496
+ for model_type in model_types
497
+ }
498
+
499
+ # Process results as they complete
500
+ total_models = len(model_types)
501
+ completed = 0
502
+
503
+ for future in concurrent.futures.as_completed(future_to_model):
504
+ model_type = future_to_model[future]
505
+ try:
506
+ model_preds = future.result()
507
+ model_predictions[model_type] = model_preds
508
+
509
+ completed += 1
510
+ progress(
511
+ completed / total_models,
512
+ desc=f"Completed {model_type} prediction...",
513
+ )
514
+ except Exception as e:
515
+ logger.error(f"Prediction failed for {model_type}: {str(e)}")
516
+ raise gr.Error(f"Prediction failed for {model_type}: {str(e)}")
517
 
518
  progress(0.9, desc="Aligning structures...")
519
+
520
+ model_predictions = align_structures(model_predictions)
521
 
522
  progress(1.0, desc="Done!")
523
 
524
  # Create checkbox groups for each model type
525
+ af2_predictions = gr.CheckboxGroup(
526
+ visible=model_predictions.get(FoldingModel.AF2) is not None,
527
+ choices=list(model_predictions.get(FoldingModel.AF2, {}).keys()),
528
+ value=list(model_predictions.get(FoldingModel.AF2, {}).keys()),
529
+ )
530
+ openfold_predictions = gr.CheckboxGroup(
531
+ visible=model_predictions.get(FoldingModel.OPENFOLD) is not None,
532
+ choices=list(model_predictions.get(FoldingModel.OPENFOLD, {}).keys()),
533
+ value=list(model_predictions.get(FoldingModel.OPENFOLD, {}).keys()),
534
+ )
535
+ solo_predictions = gr.CheckboxGroup(
536
+ visible=model_predictions.get(FoldingModel.SOLOSEQ) is not None,
537
+ choices=list(model_predictions.get(FoldingModel.SOLOSEQ, {}).keys()),
538
+ value=list(model_predictions.get(FoldingModel.SOLOSEQ, {}).keys()),
539
+ )
540
  chai_predictions = gr.CheckboxGroup(
541
  visible=model_predictions.get(FoldingModel.CHAI) is not None,
542
+ choices=list(model_predictions.get(FoldingModel.CHAI, {}).keys()),
543
+ value=list(model_predictions.get(FoldingModel.CHAI, {}).keys()),
544
  )
545
  boltz_predictions = gr.CheckboxGroup(
546
  visible=model_predictions.get(FoldingModel.BOLTZ) is not None,
547
+ choices=list(model_predictions.get(FoldingModel.BOLTZ, {}).keys()),
548
+ value=list(model_predictions.get(FoldingModel.BOLTZ, {}).keys()),
549
  )
550
  protenix_predictions = gr.CheckboxGroup(
551
  visible=model_predictions.get(FoldingModel.PROTENIX) is not None,
552
+ choices=list(model_predictions.get(FoldingModel.PROTENIX, {}).keys()),
553
+ value=list(model_predictions.get(FoldingModel.PROTENIX, {}).keys()),
554
  )
555
 
556
  return (
557
+ model_predictions,
558
+ af2_predictions,
559
+ openfold_predictions,
560
+ solo_predictions,
561
  chai_predictions,
562
  boltz_predictions,
563
  protenix_predictions,
 
 
564
  )
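
predict_comparison now fans the selected models out over a ThreadPoolExecutor and gathers their results with as_completed, reporting progress per finished model and surfacing the first failure as a gr.Error. The shape of that pattern in isolation, with fake_run_prediction standing in for run_prediction(sequence, api_key, model_type, True):

    import concurrent.futures
    import random
    import time

    def fake_run_prediction(model_type: str) -> dict[int, str]:
        # Stand-in worker: pretend to fold and return index -> pdb path.
        time.sleep(random.uniform(0.1, 0.3))
        return {0: f"{model_type}_prediction_0.pdb", 1: f"{model_type}_prediction_1.pdb"}

    def predict_all(model_types: list[str]) -> dict[str, dict[int, str]]:
        model_predictions: dict[str, dict[int, str]] = {}
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_model = {
                executor.submit(fake_run_prediction, model_type): model_type
                for model_type in model_types
            }
            completed = 0
            for future in concurrent.futures.as_completed(future_to_model):
                model_type = future_to_model[future]
                try:
                    model_predictions[model_type] = future.result()
                except Exception as exc:
                    # The demo re-raises this as gr.Error to show it in the UI.
                    raise RuntimeError(f"Prediction failed for {model_type}: {exc}") from exc
                completed += 1
                print(f"Completed {model_type} ({completed}/{len(model_types)})")
        return model_predictions

    print(predict_all(["boltz", "chai", "af2"]))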