|
"""Predict protein structure using Folding Studio.""" |
|
|
|
import concurrent.futures |
|
import hashlib |
|
import logging |
|
from io import StringIO |
|
from pathlib import Path |
|
from typing import Any |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import plotly.graph_objects as go |
|
from Bio import SeqIO |
|
from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer |
|
from folding_studio_data_models import FoldingModel |
|
|
|
from folding_studio_demo.models import ( |
|
AF2Model, |
|
BoltzModel, |
|
ChaiModel, |
|
OpenFoldModel, |
|
ProtenixModel, |
|
) |
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Directory (relative to the working directory) where submitted sequences are
# stored as FASTA files. It is created at import time if it does not exist.
SEQUENCE_DIR = Path("sequences")
SEQUENCE_DIR.mkdir(parents=True, exist_ok=True)

# Root directory (relative to the working directory) for prediction outputs.
# It is created at import time if it does not exist.
OUTPUT_DIR = Path("output")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
# Lookup table from upper-case three-letter residue names to one-letter codes.
# Besides the 20 standard amino acids it covers selenocysteine (SEC) and
# pyrrolysine (PYL), the ambiguity codes ASX, GLX and XLE, and maps both XAA
# and UNK to "X".
THREE_TO_ONE_LETTER = {
    "ALA": "A",
    "ARG": "R",
    "ASN": "N",
    "ASP": "D",
    "CYS": "C",
    "GLN": "Q",
    "GLU": "E",
    "GLY": "G",
    "HIS": "H",
    "ILE": "I",
    "LEU": "L",
    "LYS": "K",
    "MET": "M",
    "PHE": "F",
    "PRO": "P",
    "SER": "S",
    "THR": "T",
    "TRP": "W",
    "TYR": "Y",
    "VAL": "V",
    "SEC": "U",
    "PYL": "O",
    "ASX": "B",
    "GLX": "Z",
    "XAA": "X",
    "XLE": "J",
    "UNK": "X",
}
|
|
|
|
|
def convert_to_one_letter(resname: str) -> str:
    """Translate a three-letter residue name into its one-letter code.

    Args:
        resname (str): Three-letter amino acid code, looked up as-is in
            THREE_TO_ONE_LETTER

    Returns:
        str: The matching one-letter code, or "X" when the name is not a key
            of THREE_TO_ONE_LETTER
    """
    try:
        return THREE_TO_ONE_LETTER[resname]
    except KeyError:
        # Names missing from the table are reported as "X".
        return "X"
|
|
|
|
|
def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
    """Read an mmCIF structure file and write it back out in PDB format.

    Args:
        cif_path (str): Path to input .cif file
        pdb_path (str): Path to output .pdb file
    """
    # The input is parsed before the writer is created, so nothing is written
    # when parsing fails.
    cif_structure = MMCIFParser().get_structure("structure", cif_path)

    pdb_writer = PDBIO()
    pdb_writer.set_structure(cif_structure)
    pdb_writer.save(pdb_path)
|
|
|
|
|
def create_plddt_figure(
    plddt_vals: list[dict[str, dict[str, list[float]]]],
    model_name: str,
    indexes: list[int],
) -> go.Figure:
    """Build a pLDDT plot with one trace per prediction.

    Args:
        plddt_vals (list[dict[str, dict[str, list[float]]]]): One entry per
            prediction, mapping a chain ID to a dict with the keys "values"
            (pLDDT scores) and "residue_codes" (residue labels shown on hover)
        model_name (str): Model name used in the trace names and hover labels
        indexes (list[int]): Prediction index paired with each entry of
            plddt_vals

    Returns:
        go.Figure: Figure holding one Scatter trace per prediction
    """
    traces = []

    for chains, index in zip(plddt_vals, indexes):
        scores = []
        labels = []
        # All chains of a prediction are concatenated along the x axis; the
        # residue number shown on hover restarts at 0 for every chain.
        for chain_id, chain_plddt in chains.items():
            chain_scores = chain_plddt["values"]
            scores.extend(chain_scores)
            labels.extend(
                f"<i>{model_name} {index} - Chain {chain_id}</i><br><i>pLDDT</i>: {plddt:.2f}<br><i>Residue:</i> {code} {idx}"
                for idx, (plddt, code) in enumerate(
                    zip(chain_scores, chain_plddt["residue_codes"])
                )
            )

        trace = go.Scatter(
            x=np.arange(len(scores)),
            y=scores,
            hovertemplate="%{text}<extra></extra>",
            text=labels,
            name=f"{model_name} {index}",
            visible=True,
        )
        traces.append(trace)

    figure = go.Figure(data=traces)
    figure.update_layout(
        title="pLDDT",
        xaxis_title="Residue",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend={"yanchor": "bottom", "y": 0.01, "xanchor": "left", "x": 0.99},
    )
    return figure
|
|
|
|
|
def _write_fasta_file(
    sequence: str, directory: Path = SEQUENCE_DIR
) -> tuple[str, Path]:
    """Store FASTA text on disk under a name derived from its sequences.

    Args:
        sequence (str): FASTA-formatted text to store
        directory (Path): Directory to write FASTA file to (default: SEQUENCE_DIR)

    Returns:
        tuple[str, Path]: The sequence ID and the path of the written FASTA file

    Raises:
        gr.Error: If no FASTA record can be parsed from the text
    """
    records = list(SeqIO.parse(StringIO(sequence), "fasta"))
    if not records:
        raise gr.Error("No sequence found")

    # The ID is the SHA-256 of the record sequences joined with "_"; record
    # headers do not take part in it.
    joined_sequences = "_".join(str(record.seq) for record in records)
    seq_id = hashlib.sha256(joined_sequences.encode()).hexdigest()

    seq_file = directory / f"sequence_{seq_id}.fasta"
    # The text is written exactly as received, not re-serialised from the
    # parsed records.
    with open(seq_file, "w") as fasta_file:
        fasta_file.write(sequence)
    return seq_id, seq_file
|
|
|
|
|
def extract_plddt_from_structure(
    structure_path: str,
) -> dict[str, dict[str, list[float]]]:
    """Collect per-residue pLDDT values and residue codes for every chain.

    The pLDDT of a residue is read from the B-factor of its CA atom; residues
    without a CA atom are skipped.

    Args:
        structure_path (str): Path to the structure file. A ".cif" suffix is
            parsed as mmCIF, anything else as PDB

    Returns:
        dict[str, dict[str, list[float]]]: Mapping from chain ID to a dict with
            the keys "values" (pLDDT per residue) and "residue_codes"
            (one-letter residue codes, in the same order)
    """
    is_cif = Path(structure_path).suffix == ".cif"
    structure_parser = MMCIFParser() if is_cif else PDBParser()
    structure = structure_parser.get_structure("structure", structure_path)

    per_chain = {}
    for model in structure:
        for chain in model:
            values = []
            codes = []
            # Registered before the residues are scanned, so a chain without
            # CA atoms still gets an (empty) entry. A chain ID that shows up
            # again in a later model replaces the earlier entry.
            per_chain[chain.id] = {"values": values, "residue_codes": codes}
            for residue in chain:
                if "CA" not in residue:
                    continue
                values.append(residue["CA"].get_bfactor())
                codes.append(convert_to_one_letter(residue.get_resname()))

    return per_chain
|
|
|
|
|
def _prediction_index(pdb_path: str, model_type: FoldingModel) -> int:
    """Parse the prediction index encoded in a prediction file name."""
    stem = Path(pdb_path).stem
    if model_type in (
        FoldingModel.AF2,
        FoldingModel.OPENFOLD,
        FoldingModel.SOLOSEQ,
    ):
        # For these models the index is the third "_"-separated field.
        return int(stem.split("_")[2])

    # For the other models the index is the trailing run of digits. Reading the
    # whole run (instead of only the last character) keeps indices >= 10
    # intact.
    digits = stem[len(stem.rstrip("0123456789")):]
    if not digits:
        raise ValueError(f"No prediction index found in file name: {pdb_path}")
    return int(digits)


def predict(
    sequence: str,
    api_key: str,
    model_type: FoldingModel,
    format_fasta: bool = False,
    progress=gr.Progress(),
) -> tuple[list[str], go.Figure]:
    """Predict protein structures for a sequence with the selected model.

    An existing prediction in the output directory is reused; otherwise the
    model is called. Predictions delivered as .cif files are converted to .pdb.

    Args:
        sequence (str): FASTA-formatted amino acid sequence(s) to predict
        api_key (str): Folding API key
        model_type (FoldingModel): Folding model to use
        format_fasta (bool): Forwarded to the model call
        progress (gr.Progress): Gradio progress tracker

    Returns:
        tuple[list[str], go.Figure]: Paths of the predicted PDB files and the
            pLDDT figure with one trace per prediction

    Raises:
        gr.Error: If the API key is missing or no prediction is available
            after the model ran
        ValueError: If model_type is not supported, or a prediction file name
            does not contain a prediction index
    """
    if not api_key:
        raise gr.Error("Missing API key, please enter a valid API key")

    progress(0, desc="Setting up prediction...")

    seq_id, seq_file = _write_fasta_file(sequence)
    output_dir = OUTPUT_DIR / seq_id / model_type
    output_dir.mkdir(parents=True, exist_ok=True)

    model_classes = {
        FoldingModel.BOLTZ: BoltzModel,
        FoldingModel.CHAI: ChaiModel,
        FoldingModel.PROTENIX: ProtenixModel,
        FoldingModel.AF2: AF2Model,
        FoldingModel.OPENFOLD: OpenFoldModel,
    }
    model_cls = model_classes.get(model_type)
    if model_cls is None:
        raise ValueError(f"Model {model_type} not supported")
    model = model_cls(api_key)

    if model.has_prediction(output_dir):
        progress(0.2, desc="Using existing prediction...")
        logger.info("Prediction already exists. Output directory: %s", output_dir)
    else:
        progress(0.2, desc="Running prediction...")
        logger.info("Predicting %s", seq_id)
        model.call(seq_file=seq_file, output_dir=output_dir, format_fasta=format_fasta)
        logger.info("Prediction done. Output directory: %s", output_dir)

    progress(0.4, desc="Processing results...")

    # Checked again after the call: the run may not have produced any output.
    if not model.has_prediction(output_dir):
        raise gr.Error("No prediction found")

    predictions = model.predictions(output_dir)
    pdb_paths = []
    model_plddt_vals = []

    total_predictions = len(predictions)
    for i, (model_idx, prediction) in enumerate(predictions.items()):
        progress(
            0.4 + (0.4 * i / total_predictions), desc=f"Converting model {model_idx}..."
        )
        prediction_path = prediction["prediction_path"]
        logger.info("Prediction file: %s", prediction_path)
        if Path(prediction_path).suffix == ".cif":
            converted_pdb_path = str(
                output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
            )
            convert_cif_to_pdb(str(prediction_path), converted_pdb_path)
            pdb_paths.append(converted_pdb_path)
        else:
            pdb_paths.append(str(prediction_path))
        # pLDDT is read from the original prediction file, not the converted one.
        model_plddt_vals.append(extract_plddt_from_structure(prediction_path))

    progress(0.8, desc="Generating plots...")
    indexes = [_prediction_index(pdb_path, model_type) for pdb_path in pdb_paths]

    plddt_fig = create_plddt_figure(
        plddt_vals=model_plddt_vals,
        model_name=model.model_name,
        indexes=indexes,
    )

    progress(1.0, desc="Done!")
    return pdb_paths, plddt_fig
|
|
|
|
|
def align_structures(
    model_predictions: dict[FoldingModel, dict[int, dict[str, Any]]],
) -> dict[FoldingModel, dict[int, dict[str, Any]]]:
    """Superimpose every predicted structure onto the first prediction.

    Structures are superimposed on their CA atoms. Each aligned structure is
    saved next to its source file as "aligned_<file name>", and the "pdb_path"
    entry of the prediction is updated to point to it.

    Bio.PDB's Superimposer rejects atom lists of different length, so every
    structure is expected to contain as many CA atoms as the reference.

    Args:
        model_predictions (dict[FoldingModel, dict[int, dict[str, Any]]]):
            Mapping from model to prediction index to prediction data; each
            prediction must provide a "pdb_path" entry

    Returns:
        dict[FoldingModel, dict[int, dict[str, Any]]]: The same mapping object,
            modified in place so that "pdb_path" points to the aligned files.
            It is returned untouched when it holds no prediction at all.
    """
    # Reference: the first prediction found when scanning models in mapping
    # order. Skipping models without predictions, and returning early when
    # there is none, avoids leaking a bare StopIteration on empty input.
    ref_pdb_path = next(
        (
            prediction["pdb_path"]
            for predictions in model_predictions.values()
            for prediction in predictions.values()
        ),
        None,
    )
    if ref_pdb_path is None:
        return model_predictions

    parser = PDBParser()
    pdb_writer = PDBIO()

    ref_structure = parser.get_structure("reference", ref_pdb_path)
    ref_atoms = [atom for atom in ref_structure.get_atoms() if atom.get_name() == "CA"]

    for model_type, predictions in model_predictions.items():
        for index, prediction in predictions.items():
            pdb_path = prediction["pdb_path"]

            structure = parser.get_structure(f"{model_type}_{index}", pdb_path)
            atoms = [atom for atom in structure.get_atoms() if atom.get_name() == "CA"]

            # Fit on the CA atoms only, then move every atom of the structure.
            superimposer = Superimposer()
            superimposer.set_atoms(ref_atoms, atoms)
            superimposer.apply(structure.get_atoms())

            source = Path(pdb_path)
            aligned_path = str(source.parent / f"aligned_{source.name}")
            pdb_writer.set_structure(structure)
            pdb_writer.save(aligned_path)

            prediction["pdb_path"] = aligned_path

    return model_predictions
|
|
|
|
|
def filter_predictions(
    model_predictions: dict[FoldingModel, dict[int, dict[str, Any]]],
    af2_selected: list[int],
    openfold_selected: list[int],
    solo_selected: list[int],
    chai_selected: list[int],
    boltz_selected: list[int],
    protenix_selected: list[int],
) -> tuple[list[str], go.Figure]:
    """Keep only the predictions whose index is selected for their model.

    Args:
        model_predictions (dict[FoldingModel, dict[int, dict[str, Any]]]):
            Mapping from model to prediction index to prediction data; each
            prediction provides "pdb_path" and "plddt_trace"
        af2_selected (list[int]): Selected AF2 prediction indices
        openfold_selected (list[int]): Selected OpenFold prediction indices
        solo_selected (list[int]): Selected SoloSeq prediction indices
        chai_selected (list[int]): Selected Chai prediction indices
        boltz_selected (list[int]): Selected Boltz prediction indices
        protenix_selected (list[int]): Selected Protenix prediction indices

    Returns:
        tuple[list[str], go.Figure]: PDB paths of the selected predictions and
            a pLDDT figure holding only their traces
    """

    def should_show_trace(model_name, pred_index: int) -> bool:
        # Each model is paired with its own selection; a prediction is shown
        # when its model matches and its index is in that selection.
        selections = (
            (FoldingModel.CHAI, chai_selected),
            (FoldingModel.BOLTZ, boltz_selected),
            (FoldingModel.PROTENIX, protenix_selected),
            (FoldingModel.AF2, af2_selected),
            (FoldingModel.OPENFOLD, openfold_selected),
            (FoldingModel.SOLOSEQ, solo_selected),
        )
        return any(
            model_name == candidate and pred_index in selected
            for candidate, selected in selections
        )

    shown = [
        prediction
        for model_type, predictions in model_predictions.items()
        for index, prediction in predictions.items()
        if should_show_trace(model_type, index)
    ]

    filtered_fig = go.Figure()
    filtered_paths = []
    for prediction in shown:
        filtered_fig.add_trace(prediction["plddt_trace"])
        filtered_paths.append(prediction["pdb_path"])

    filtered_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend={"yanchor": "bottom", "y": 0.01, "xanchor": "left", "x": 0.99},
    )
    return filtered_paths, filtered_fig
|
|
|
|
|
def run_prediction(
    sequence: str,
    api_key: str,
    model_type: FoldingModel,
    format_fasta: bool = False,
) -> dict[int, dict[str, Any]]:
    """Run one model and key its predictions by prediction index.

    Args:
        sequence (str): Amino acid sequence to predict structure for
        api_key (str): Folding API key
        model_type (FoldingModel): Folding model to use
        format_fasta (bool): Whether to format the FASTA file

    Returns:
        dict[int, dict[str, Any]]: Mapping from prediction index to a dict with
            the keys "pdb_path" (path of the PDB file) and "plddt_trace" (the
            matching trace of the pLDDT figure)

    Raises:
        ValueError: If a prediction file name does not contain a prediction
            index
    """
    pdb_paths, plddt_fig = predict(
        sequence, api_key, model_type, format_fasta=format_fasta
    )

    # For these models the index is the third "_"-separated field of the file
    # name; for all others it is the trailing run of digits.
    index_in_third_field = model_type in (
        FoldingModel.AF2,
        FoldingModel.OPENFOLD,
        FoldingModel.SOLOSEQ,
    )

    model_predictions = {}
    # The figure holds one trace per PDB path, in the same order.
    for pdb_path, plddt_trace in zip(pdb_paths, plddt_fig.data):
        stem = Path(pdb_path).stem
        if index_in_third_field:
            index = int(stem.split("_")[2])
        else:
            # Read the whole digit run: taking only the last character would
            # map indices 0 and 10 to the same key and drop a prediction.
            digits = stem[len(stem.rstrip("0123456789")):]
            if not digits:
                raise ValueError(
                    f"No prediction index found in file name: {pdb_path}"
                )
            index = int(digits)

        model_predictions[index] = {"pdb_path": pdb_path, "plddt_trace": plddt_trace}

    return model_predictions
|
|
|
|
|
def _prediction_checkbox_group(
    model_predictions: dict[FoldingModel, dict[int, dict[str, Any]]],
    model_type: FoldingModel,
) -> gr.CheckboxGroup:
    """Build the checkbox group listing the prediction indices of one model."""
    indices = list(model_predictions.get(model_type, {}).keys())
    return gr.CheckboxGroup(
        # Hidden when the model was not part of the comparison.
        visible=model_predictions.get(model_type) is not None,
        choices=indices,
        # Every prediction starts selected.
        value=list(indices),
    )


def predict_comparison(
    sequence: str, api_key: str, model_types: list[FoldingModel], progress=gr.Progress()
) -> tuple[
    dict[FoldingModel, dict[int, dict[str, Any]]],
    gr.CheckboxGroup,
    gr.CheckboxGroup,
    gr.CheckboxGroup,
    gr.CheckboxGroup,
    gr.CheckboxGroup,
    gr.CheckboxGroup,
]:
    """Predict protein structure from amino acid sequence using multiple models.

    The models run in parallel threads; their predictions are then aligned
    with align_structures.

    Args:
        sequence (str): Amino acid sequence to predict structure for
        api_key (str): Folding API key
        model_types (list[FoldingModel]): List of folding models to use
        progress (gr.Progress): Gradio progress tracker

    Returns:
        tuple containing:
            - dict[FoldingModel, dict[int, dict[str, Any]]]: Model predictions
              mapping, keyed in the order the models were requested
            - gr.CheckboxGroup: AF2 predictions checkbox group
            - gr.CheckboxGroup: OpenFold predictions checkbox group
            - gr.CheckboxGroup: SoloSeq predictions checkbox group
            - gr.CheckboxGroup: Chai predictions checkbox group
            - gr.CheckboxGroup: Boltz predictions checkbox group
            - gr.CheckboxGroup: Protenix predictions checkbox group

    Raises:
        gr.Error: If the API key is missing or any model prediction fails
    """
    if not api_key:
        raise gr.Error("Missing API key, please enter a valid API key")

    progress(0, desc="Starting parallel predictions...")

    results = {}

    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_model = {
            executor.submit(
                run_prediction, sequence, api_key, model_type, format_fasta=True
            ): model_type
            for model_type in model_types
        }

        total_models = len(model_types)
        finished = concurrent.futures.as_completed(future_to_model)
        for completed, future in enumerate(finished, start=1):
            model_type = future_to_model[future]
            try:
                results[model_type] = future.result()
            except Exception as e:
                logger.exception("Prediction failed for %s", model_type)
                raise gr.Error(f"Prediction failed for {model_type}: {str(e)}") from e

            # Predictions fill the first 90% of the bar so that progress never
            # moves backwards when the alignment step starts.
            progress(
                0.9 * completed / total_models,
                desc=f"Completed {model_type} prediction...",
            )

    # Futures complete in arbitrary order. Re-key the results in the requested
    # order so the alignment reference (the first entry) and the display order
    # do not change from run to run.
    model_predictions = {model_type: results[model_type] for model_type in model_types}

    progress(0.9, desc="Aligning structures...")

    model_predictions = align_structures(model_predictions)

    progress(1.0, desc="Done!")

    return (
        model_predictions,
        _prediction_checkbox_group(model_predictions, FoldingModel.AF2),
        _prediction_checkbox_group(model_predictions, FoldingModel.OPENFOLD),
        _prediction_checkbox_group(model_predictions, FoldingModel.SOLOSEQ),
        _prediction_checkbox_group(model_predictions, FoldingModel.CHAI),
        _prediction_checkbox_group(model_predictions, FoldingModel.BOLTZ),
        _prediction_checkbox_group(model_predictions, FoldingModel.PROTENIX),
    )
|
|