Add model comparison (#3)
Browse files- add model comparison (485bd69b80145fab090b2e9d368082247e7a4d93)
Co-authored-by: Achille Soulie <[email protected]>
- folding_studio_demo/app.py +39 -22
- folding_studio_demo/predict.py +173 -28
folding_studio_demo/app.py
CHANGED
@@ -3,12 +3,16 @@
|
|
3 |
import logging
|
4 |
|
5 |
import gradio as gr
|
|
|
6 |
from folding_studio_data_models import FoldingModel
|
7 |
from gradio_molecule3d import Molecule3D
|
8 |
-
import pandas as pd
|
9 |
|
10 |
-
from folding_studio_demo.
|
11 |
-
|
|
|
|
|
|
|
|
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
@@ -119,7 +123,7 @@ def model_comparison(api_key: str) -> None:
|
|
119 |
"""
|
120 |
|
121 |
with gr.Row():
|
122 |
-
|
123 |
label="Model",
|
124 |
choices=MODEL_CHOICES,
|
125 |
multiselect=True,
|
@@ -133,13 +137,18 @@ def model_comparison(api_key: str) -> None:
|
|
133 |
predict_btn = gr.Button("Compare Models")
|
134 |
|
135 |
with gr.Row():
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
138 |
|
139 |
predict_btn.click(
|
140 |
-
fn=
|
141 |
-
inputs=[sequence, api_key,
|
142 |
-
outputs=[
|
143 |
)
|
144 |
|
145 |
|
@@ -147,12 +156,12 @@ def create_correlation_tab():
|
|
147 |
gr.Markdown("# Correlation with experimental binding affinity data")
|
148 |
spr_data_with_scores = pd.read_csv("spr_af_scores_mapped.csv")
|
149 |
prettified_columns = {
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
spr_data_with_scores = spr_data_with_scores.rename(columns=prettified_columns)
|
157 |
with gr.Row():
|
158 |
columns = [
|
@@ -160,10 +169,13 @@ def create_correlation_tab():
|
|
160 |
"KD (nM)",
|
161 |
"Antibody VH Sequence",
|
162 |
"Antibody VL Sequence",
|
163 |
-
"Antigen Sequence"
|
164 |
]
|
165 |
# Display dataframe with floating point values rounded to 2 decimal places
|
166 |
-
spr_data = gr.DataFrame(
|
|
|
|
|
|
|
167 |
|
168 |
gr.Markdown("# Prediction and correlation")
|
169 |
with gr.Row():
|
@@ -174,22 +186,27 @@ def create_correlation_tab():
|
|
174 |
correlation_ranking_plot = gr.Plot(label="Correlation ranking")
|
175 |
with gr.Row():
|
176 |
# User can select the columns to display in the correlation plot
|
177 |
-
correlation_column = gr.Dropdown(
|
|
|
|
|
178 |
correlation_plot = gr.Plot(label="Correlation with binding affinity")
|
179 |
|
180 |
fake_predict_btn.click(
|
181 |
-
fn=lambda x: fake_predict_and_correlate(
|
|
|
|
|
182 |
inputs=None,
|
183 |
-
outputs=[prediction_dataframe, correlation_ranking_plot]
|
184 |
)
|
185 |
|
186 |
# Call function to update the correlation plot when the user selects the columns
|
187 |
correlation_column.change(
|
188 |
fn=lambda score: select_correlation_plot(spr_data_with_scores, score),
|
189 |
inputs=correlation_column,
|
190 |
-
outputs=correlation_plot
|
191 |
)
|
192 |
-
|
|
|
193 |
def __main__():
|
194 |
with gr.Blocks(title="Folding Studio Demo") as demo:
|
195 |
gr.Markdown(
|
|
|
3 |
import logging
|
4 |
|
5 |
import gradio as gr
|
6 |
+
import pandas as pd
|
7 |
from folding_studio_data_models import FoldingModel
|
8 |
from gradio_molecule3d import Molecule3D
|
|
|
9 |
|
10 |
+
from folding_studio_demo.correlate import (
|
11 |
+
SCORE_COLUMNS,
|
12 |
+
fake_predict_and_correlate,
|
13 |
+
select_correlation_plot,
|
14 |
+
)
|
15 |
+
from folding_studio_demo.predict import predict, predict_comparison
|
16 |
|
17 |
logger = logging.getLogger(__name__)
|
18 |
|
|
|
123 |
"""
|
124 |
|
125 |
with gr.Row():
|
126 |
+
models = gr.Dropdown(
|
127 |
label="Model",
|
128 |
choices=MODEL_CHOICES,
|
129 |
multiselect=True,
|
|
|
137 |
predict_btn = gr.Button("Compare Models")
|
138 |
|
139 |
with gr.Row():
|
140 |
+
mol_outputs = Molecule3D(
|
141 |
+
label="Protein Structure",
|
142 |
+
reps=MOLECULE_REPS,
|
143 |
+
file_count="multiple",
|
144 |
+
)
|
145 |
+
|
146 |
+
# metrics_plot = gr.Plot(label="pLDDT")
|
147 |
|
148 |
predict_btn.click(
|
149 |
+
fn=predict_comparison,
|
150 |
+
inputs=[sequence, api_key, models],
|
151 |
+
outputs=[mol_outputs],
|
152 |
)
|
153 |
|
154 |
|
|
|
156 |
gr.Markdown("# Correlation with experimental binding affinity data")
|
157 |
spr_data_with_scores = pd.read_csv("spr_af_scores_mapped.csv")
|
158 |
prettified_columns = {
|
159 |
+
"antibody_name": "Antibody Name",
|
160 |
+
"KD (nM)": "KD (nM)",
|
161 |
+
"antibody_vh_sequence": "Antibody VH Sequence",
|
162 |
+
"antibody_vl_sequence": "Antibody VL Sequence",
|
163 |
+
"antigen_sequence": "Antigen Sequence",
|
164 |
+
}
|
165 |
spr_data_with_scores = spr_data_with_scores.rename(columns=prettified_columns)
|
166 |
with gr.Row():
|
167 |
columns = [
|
|
|
169 |
"KD (nM)",
|
170 |
"Antibody VH Sequence",
|
171 |
"Antibody VL Sequence",
|
172 |
+
"Antigen Sequence",
|
173 |
]
|
174 |
# Display dataframe with floating point values rounded to 2 decimal places
|
175 |
+
spr_data = gr.DataFrame(
|
176 |
+
value=spr_data_with_scores[columns].round(2),
|
177 |
+
label="Experimental Antibody-Antigen Binding Affinity Data",
|
178 |
+
)
|
179 |
|
180 |
gr.Markdown("# Prediction and correlation")
|
181 |
with gr.Row():
|
|
|
186 |
correlation_ranking_plot = gr.Plot(label="Correlation ranking")
|
187 |
with gr.Row():
|
188 |
# User can select the columns to display in the correlation plot
|
189 |
+
correlation_column = gr.Dropdown(
|
190 |
+
label="Score data to display", choices=SCORE_COLUMNS, multiselect=False
|
191 |
+
)
|
192 |
correlation_plot = gr.Plot(label="Correlation with binding affinity")
|
193 |
|
194 |
fake_predict_btn.click(
|
195 |
+
fn=lambda x: fake_predict_and_correlate(
|
196 |
+
spr_data_with_scores, SCORE_COLUMNS, ["Antibody Name", "KD (nM)"]
|
197 |
+
),
|
198 |
inputs=None,
|
199 |
+
outputs=[prediction_dataframe, correlation_ranking_plot],
|
200 |
)
|
201 |
|
202 |
# Call function to update the correlation plot when the user selects the columns
|
203 |
correlation_column.change(
|
204 |
fn=lambda score: select_correlation_plot(spr_data_with_scores, score),
|
205 |
inputs=correlation_column,
|
206 |
+
outputs=correlation_plot,
|
207 |
)
|
208 |
+
|
209 |
+
|
210 |
def __main__():
|
211 |
with gr.Blocks(title="Folding Studio Demo") as demo:
|
212 |
gr.Markdown(
|
folding_studio_demo/predict.py
CHANGED
@@ -3,13 +3,15 @@
|
|
3 |
import hashlib
|
4 |
import logging
|
5 |
import os
|
|
|
6 |
from pathlib import Path
|
|
|
7 |
|
8 |
import gradio as gr
|
9 |
import numpy as np
|
10 |
import plotly.graph_objects as go
|
11 |
from Bio import SeqIO
|
12 |
-
from Bio.PDB import PDBIO, MMCIFParser
|
13 |
from folding_studio.client import Client
|
14 |
from folding_studio.query import Query
|
15 |
from folding_studio.query.boltz import BoltzQuery
|
@@ -50,18 +52,21 @@ def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
|
|
50 |
io.save(pdb_path)
|
51 |
|
52 |
|
53 |
-
def add_plddt_plot(plddt_vals: list[float]) ->
|
54 |
"""Create a plot of metrics."""
|
55 |
visible = True
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
63 |
|
64 |
-
plddt_fig = go.Figure(data=
|
65 |
plddt_fig.update_layout(
|
66 |
title="pLDDT",
|
67 |
xaxis_title="Residue index",
|
@@ -85,7 +90,13 @@ def _write_fasta_file(
|
|
85 |
Returns:
|
86 |
tuple[str, Path]: Tuple containing the sequence ID and the path to the FASTA file
|
87 |
"""
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
seq_file = directory / f"sequence_{seq_id}.fasta"
|
90 |
with open(seq_file, "w") as f:
|
91 |
f.write(sequence)
|
@@ -146,7 +157,7 @@ class AF3Model:
|
|
146 |
|
147 |
def has_prediction(self, output_dir: Path) -> bool:
|
148 |
"""Check if prediction exists in output directory."""
|
149 |
-
return
|
150 |
|
151 |
def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
|
152 |
"""Check if the file description is correct.
|
@@ -157,10 +168,6 @@ class AF3Model:
|
|
157 |
Returns:
|
158 |
tuple[bool, str | None]: Tuple containing a boolean indicating if the format is correct and an error message if not
|
159 |
"""
|
160 |
-
input_rep = list(SeqIO.parse(seq_file, "fasta"))
|
161 |
-
if not input_rep:
|
162 |
-
error_msg = f"{self.model_name.upper()} Validation Error: No sequence found"
|
163 |
-
return False, error_msg
|
164 |
|
165 |
is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
|
166 |
if not is_valid:
|
@@ -182,9 +189,41 @@ class ChaiModel(AF3Model):
|
|
182 |
"""
|
183 |
super().call(seq_file, output_dir)
|
184 |
|
185 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
"""Get the path to the prediction."""
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
|
190 |
class ProtenixModel(AF3Model):
|
@@ -221,7 +260,33 @@ class BoltzModel(AF3Model):
|
|
221 |
|
222 |
def predictions(self, output_dir: Path) -> list[Path]:
|
223 |
"""Get the path to the prediction."""
|
224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
|
226 |
|
227 |
def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str, str]:
|
@@ -235,6 +300,8 @@ def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str,
|
|
235 |
Returns:
|
236 |
tuple[str, str]: Tuple containing the path to the PDB file and the pLDDT plot
|
237 |
"""
|
|
|
|
|
238 |
|
239 |
# Set up unique output directory based on sequence hash
|
240 |
seq_id, seq_file = _write_fasta_file(sequence)
|
@@ -265,15 +332,93 @@ def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str,
|
|
265 |
if not model.has_prediction(output_dir):
|
266 |
raise gr.Error("No prediction found")
|
267 |
|
268 |
-
|
269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
|
271 |
-
|
272 |
-
|
273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
|
275 |
-
|
276 |
-
logger.info("plddt file: %s", plddt_file)
|
277 |
-
plddt_vals = np.load(plddt_file)["plddt"]
|
278 |
|
279 |
-
return
|
|
|
3 |
import hashlib
|
4 |
import logging
|
5 |
import os
|
6 |
+
from io import StringIO
|
7 |
from pathlib import Path
|
8 |
+
from typing import Any
|
9 |
|
10 |
import gradio as gr
|
11 |
import numpy as np
|
12 |
import plotly.graph_objects as go
|
13 |
from Bio import SeqIO
|
14 |
+
from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
|
15 |
from folding_studio.client import Client
|
16 |
from folding_studio.query import Query
|
17 |
from folding_studio.query.boltz import BoltzQuery
|
|
|
52 |
io.save(pdb_path)
|
53 |
|
54 |
|
55 |
+
def add_plddt_plot(plddt_vals: list[list[float]], model_name: str) -> go.Figure:
|
56 |
"""Create a plot of metrics."""
|
57 |
visible = True
|
58 |
+
plddt_traces = [
|
59 |
+
go.Scatter(
|
60 |
+
x=np.arange(len(plddt_val)),
|
61 |
+
y=plddt_val,
|
62 |
+
hovertemplate="<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}<br>",
|
63 |
+
name=f"{model_name} {i}",
|
64 |
+
visible=visible,
|
65 |
+
)
|
66 |
+
for i, plddt_val in enumerate(plddt_vals)
|
67 |
+
]
|
68 |
|
69 |
+
plddt_fig = go.Figure(data=plddt_traces)
|
70 |
plddt_fig.update_layout(
|
71 |
title="pLDDT",
|
72 |
xaxis_title="Residue index",
|
|
|
90 |
Returns:
|
91 |
tuple[str, Path]: Tuple containing the sequence ID and the path to the FASTA file
|
92 |
"""
|
93 |
+
input_rep = list(SeqIO.parse(StringIO(sequence), "fasta"))
|
94 |
+
if not input_rep:
|
95 |
+
raise gr.Error("No sequence found")
|
96 |
+
|
97 |
+
seq_id = hashlib.sha256(
|
98 |
+
"_".join([str(records.seq) for records in input_rep]).encode()
|
99 |
+
).hexdigest()
|
100 |
seq_file = directory / f"sequence_{seq_id}.fasta"
|
101 |
with open(seq_file, "w") as f:
|
102 |
f.write(sequence)
|
|
|
157 |
|
158 |
def has_prediction(self, output_dir: Path) -> bool:
    """Report whether ``self.predictions(output_dir)`` yields at least one entry."""
    # An empty collection of predictions is falsy, a populated one truthy.
    return bool(self.predictions(output_dir))
|
161 |
|
162 |
def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
|
163 |
"""Check if the file description is correct.
|
|
|
168 |
Returns:
|
169 |
tuple[bool, str | None]: Tuple containing a boolean indicating if the format is correct and an error message if not
|
170 |
"""
|
|
|
|
|
|
|
|
|
171 |
|
172 |
is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
|
173 |
if not is_valid:
|
|
|
189 |
"""
|
190 |
super().call(seq_file, output_dir)
|
191 |
|
192 |
+
def _get_chai_paired_files(self, directory: Path) -> list[tuple[Path, Path]]:
    """Get pairs of .cif and .npz files with matching model indices.

    Files are looked up directly inside ``directory`` using the
    ``pred.model_idx_<N>.cif`` and ``scores.model_idx_<N>.npz`` naming
    patterns. A file whose index is not a decimal number is ignored, and an
    index present on only one side produces no pair.

    Args:
        directory (Path): Directory containing the prediction files

    Returns:
        list[tuple[Path, Path]]: List of tuples containing (cif_path, npz_path)
            pairs, ordered by ascending model index. Empty when nothing matches.
    """

    def _index_files(pattern: str) -> dict[int, Path]:
        # Map the numeric suffix after "model_idx_" to the file carrying it.
        indexed: dict[int, Path] = {}
        for path in directory.glob(pattern):
            suffix = path.stem.split("model_idx_")[1]
            if suffix.isdecimal():
                indexed[int(suffix)] = path
        return indexed

    # Get all cif files and extract their indices
    cif_files = _index_files("pred.model_idx_*.cif")
    # Same for the npz score files
    npz_files = _index_files("scores.model_idx_*.npz")

    # Only indices that have both a structure and a score file form a pair
    return [
        (cif_files[idx], npz_files[idx])
        for idx in sorted(cif_files.keys() & npz_files.keys())
    ]
|
202 |
+
|
203 |
+
def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
    """Collect the predictions found under ``output_dir``, keyed by model index.

    The first ``pred.model_idx_<digit>.cif`` found anywhere below
    ``output_dir`` determines the directory that is scanned. Inside that
    directory every ``pred.model_idx_<N>.cif`` is paired with the
    ``scores.model_idx_<N>.npz`` carrying the same index; indices present on
    only one side are dropped, and files whose index is not a decimal number
    are ignored instead of raising ``ValueError``.

    Args:
        output_dir (Path): Root directory to search for prediction files

    Returns:
        dict[int, dict[str, Any]]: Mapping from model index to a dict with
            ``"prediction_path"`` (the .cif path) and ``"metrics"`` (the
            object returned by ``np.load`` for the matching .npz file).
            Empty when no prediction file is found.
    """
    prediction = next(output_dir.rglob("pred.model_idx_[0-9].cif"), None)
    if prediction is None:
        return {}
    run_dir = prediction.parent

    def _index_files(pattern: str) -> dict[int, Path]:
        # Map the numeric suffix after "model_idx_" to the file carrying it.
        indexed: dict[int, Path] = {}
        for path in run_dir.glob(pattern):
            suffix = path.stem.split("model_idx_")[1]
            if suffix.isdecimal():
                indexed[int(suffix)] = path
        return indexed

    cif_files = _index_files("pred.model_idx_*.cif")
    npz_files = _index_files("scores.model_idx_*.npz")

    # Find common indices and create pairs
    common_indices = sorted(cif_files.keys() & npz_files.keys())

    return {
        idx: {"prediction_path": cif_files[idx], "metrics": np.load(npz_files[idx])}
        for idx in common_indices
    }
|
227 |
|
228 |
|
229 |
class ProtenixModel(AF3Model):
|
|
|
260 |
|
261 |
def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
    """Collect the predictions found under ``output_dir``, keyed by model index.

    Every ``*_model_<digit>.cif`` below ``output_dir`` is a prediction; its
    index is the trailing digit of the file stem. Each prediction is paired
    with the ``plddt_<cif stem>.npz`` file next to it, so that different
    models do not all receive the same metrics. When no file with that exact
    name exists, the first ``plddt_*.npz`` in the directory (sorted by name)
    is used, and a prediction with no plddt file at all is skipped rather
    than raising ``IndexError``.

    Args:
        output_dir (Path): Root directory to search for prediction files

    Returns:
        dict[int, dict[str, Any]]: Mapping from model index to a dict with
            ``"prediction_path"`` (the .cif path) and ``"metrics"`` (the
            object returned by ``np.load`` for the plddt .npz file).
            Empty when no prediction file is found.
    """
    found: dict[int, dict[str, Any]] = {}
    for cif_path in sorted(output_dir.rglob("*_model_[0-9].cif")):
        # NOTE(review): assumes the per-model file is named plddt_<cif stem>.npz
        # — confirm against the Boltz output layout.
        plddt_path = cif_path.parent / f"plddt_{cif_path.stem}.npz"
        if not plddt_path.is_file():
            candidates = sorted(cif_path.parent.glob("plddt_*.npz"))
            if not candidates:
                # No metrics available for this structure: leave it out
                continue
            plddt_path = candidates[0]
        found[int(cif_path.stem[-1])] = {
            "prediction_path": cif_path,
            "metrics": np.load(plddt_path),
        }
    return found
|
271 |
+
|
272 |
+
|
273 |
+
def extract_plddt_from_cif(cif_path):
    """Read the B-factor of every CA atom in an mmCIF file.

    The file is parsed with ``MMCIFParser`` and every model, chain and
    residue is visited in order; residues without a ``"CA"`` atom are
    skipped. The B-factor column is treated as holding the pLDDT value.

    Args:
        cif_path: Path to the mmCIF file to read

    Returns:
        list: One B-factor (pLDDT) value per residue that has a CA atom
    """
    parsed = MMCIFParser().get_structure("structure", cif_path)
    return [
        residue["CA"].get_bfactor()
        for model in parsed
        for chain in model
        for residue in chain
        if "CA" in residue
    ]
|
290 |
|
291 |
|
292 |
def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str, str]:
|
|
|
300 |
Returns:
|
301 |
tuple[str, str]: Tuple containing the path to the PDB file and the pLDDT plot
|
302 |
"""
|
303 |
+
if not api_key:
|
304 |
+
raise gr.Error("Missing API key, please enter a valid API key")
|
305 |
|
306 |
# Set up unique output directory based on sequence hash
|
307 |
seq_id, seq_file = _write_fasta_file(sequence)
|
|
|
332 |
if not model.has_prediction(output_dir):
|
333 |
raise gr.Error("No prediction found")
|
334 |
|
335 |
+
predictions = model.predictions(output_dir)
|
336 |
+
pdb_paths = []
|
337 |
+
model_plddt_vals = []
|
338 |
+
for model_idx, prediction in predictions.items():
|
339 |
+
cif_path = prediction["prediction_path"]
|
340 |
+
logger.info(
|
341 |
+
"CIF file: %s",
|
342 |
+
)
|
343 |
+
|
344 |
+
converted_pdb_path = str(
|
345 |
+
output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
|
346 |
+
)
|
347 |
+
convert_cif_to_pdb(str(cif_path), str(converted_pdb_path))
|
348 |
+
plddt_vals = extract_plddt_from_cif(cif_path)
|
349 |
+
pdb_paths.append(converted_pdb_path)
|
350 |
+
model_plddt_vals.append(plddt_vals)
|
351 |
+
plddt_plot = add_plddt_plot(
|
352 |
+
plddt_vals=model_plddt_vals, model_name=model.model_name
|
353 |
+
)
|
354 |
+
return pdb_paths, plddt_plot
|
355 |
+
|
356 |
+
|
357 |
+
def align_structures(pdb_paths: list[str]) -> list[str]:
    """Align multiple PDB structures to the first structure.

    The first path is the reference and is returned unchanged. Every other
    structure is superimposed onto the reference using its atoms named
    "CA", and the transformed structure is written next to the original file
    as ``aligned_<original name>``.

    Args:
        pdb_paths (list[str]): List of paths to PDB files to align

    Returns:
        list[str]: List of paths to aligned PDB files, in input order. Empty
            when ``pdb_paths`` is empty.
    """
    # Nothing to align: avoids an IndexError on pdb_paths[0] below
    if not pdb_paths:
        return []

    parser = PDBParser()
    io = PDBIO()

    # Parse the reference structure (first one)
    ref_structure = parser.get_structure("reference", pdb_paths[0])
    ref_atoms = [atom for atom in ref_structure.get_atoms() if atom.get_name() == "CA"]

    aligned_paths = [pdb_paths[0]]  # First structure is already aligned

    # Align each subsequent structure to the reference
    for i, pdb_path in enumerate(pdb_paths[1:], start=1):
        structure = parser.get_structure(f"model_{i}", pdb_path)
        atoms = [atom for atom in structure.get_atoms() if atom.get_name() == "CA"]

        # NOTE(review): set_atoms presumably requires both lists to have the
        # same length — confirm how structures with differing CA counts
        # should be handled.
        sup = Superimposer()
        sup.set_atoms(ref_atoms, atoms)

        # Apply the transformation to all atoms, not only the CA atoms
        sup.apply(structure.get_atoms())

        # Save the aligned structure alongside the original file
        source = Path(pdb_path)
        aligned_path = str(source.parent / f"aligned_{source.name}")
        io.set_structure(structure)
        io.save(aligned_path)
        aligned_paths.append(aligned_path)

    return aligned_paths
|
398 |
+
|
399 |
+
|
400 |
+
def predict_comparison(
    sequence: str, api_key: str, model_types: list[FoldingModel]
) -> list[str]:
    """Predict a structure with each selected model and align the results.

    ``predict`` is called once per entry of ``model_types``; the PDB paths it
    returns are concatenated in selection order and passed to
    ``align_structures``. The pLDDT plot returned by ``predict`` is discarded.

    Args:
        sequence (str): Amino acid sequence to predict structure for
        api_key (str): Folding API key
        model_types (list[FoldingModel]): Folding models to compare

    Returns:
        list[str]: Paths to the aligned PDB files, as returned by
            ``align_structures``

    Raises:
        gr.Error: If the API key is missing or no model is selected
    """
    if not api_key:
        raise gr.Error("Missing API key, please enter a valid API key")
    # Without this guard an empty selection reaches align_structures with no paths
    if not model_types:
        raise gr.Error("No model selected, please select at least one model")

    pdb_paths = []
    for model_type in model_types:
        model_pdb_paths, _ = predict(sequence, api_key, model_type)
        pdb_paths += model_pdb_paths

    return align_structures(pdb_paths)
|