"""Predict protein structure using Folding Studio."""
import hashlib
import logging
from io import StringIO
from pathlib import Path
import gradio as gr
import numpy as np
import plotly.graph_objects as go
from Bio import SeqIO
from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
from folding_studio_data_models import FoldingModel
from folding_studio_demo.models import BoltzModel, ChaiModel, ProtenixModel
logger = logging.getLogger(__name__)
SEQUENCE_DIR = Path("sequences")
SEQUENCE_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR = Path("output")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
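# Mapping from three-letter to one-letter residue codes, including non-standard
# (SEC, PYL) and ambiguous (ASX, GLX, XLE, XAA) codes; unknown residues map to "X".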
THREE_TO_ONE_LETTER = {
"ALA": "A",
"ARG": "R",
"ASN": "N",
"ASP": "D",
"CYS": "C",
"GLN": "Q",
"GLU": "E",
"GLY": "G",
"HIS": "H",
"ILE": "I",
"LEU": "L",
"LYS": "K",
"MET": "M",
"PHE": "F",
"PRO": "P",
"SER": "S",
"THR": "T",
"TRP": "W",
"TYR": "Y",
"VAL": "V",
"SEC": "U",
"PYL": "O",
"ASX": "B",
"GLX": "Z",
"XAA": "X",
"XLE": "J",
"UNK": "X",
}
def convert_to_one_letter(resname: str) -> str:
"""Convert three-letter amino acid code to one-letter code.
Args:
resname (str): Three-letter amino acid code
Returns:
        str: One-letter amino acid code, or "X" if the code is not recognized
"""
return THREE_TO_ONE_LETTER.get(resname, "X")
def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
"""Convert a .cif file to .pdb format using Biopython.
Args:
cif_path (str): Path to input .cif file
pdb_path (str): Path to output .pdb file
"""
# Parse the CIF file
parser = MMCIFParser()
structure = parser.get_structure("structure", cif_path)
# Save as PDB
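    # Note: structures that exceed the legacy PDB format limits (e.g. more than
    # 99,999 atoms or more chains than single-character chain IDs allow) may not
    # convert cleanly.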
io = PDBIO()
io.set_structure(structure)
io.save(pdb_path)
def create_plddt_figure(
plddt_vals: list[list[float]],
model_name: str,
    residue_codes: list[list[str]] | None = None,
) -> go.Figure:
"""Create a plot of metrics."""
plddt_traces = []
for i, plddt_val in enumerate(plddt_vals):
# Create hover text with residue codes if available
if residue_codes and i < len(residue_codes):
            hover_text = [
                f"pLDDT: {plddt:.2f}<br>Residue: {code} {idx}"
                for idx, (plddt, code) in enumerate(zip(plddt_val, residue_codes[i]))
            ]
        else:
            hover_text = [
                f"pLDDT: {plddt:.2f}<br>Residue index: {idx}"
                for idx, plddt in enumerate(plddt_val)
            ]
plddt_traces.append(
go.Scatter(
x=np.arange(len(plddt_val)),
y=plddt_val,
hovertemplate="%{text}",
text=hover_text,
name=f"{model_name} {i}",
visible=True,
)
)
plddt_fig = go.Figure(data=plddt_traces)
plddt_fig.update_layout(
title="pLDDT",
        xaxis_title="Residue index",
yaxis_title="pLDDT",
height=500,
template="simple_white",
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
)
return plddt_fig
def _write_fasta_file(
sequence: str, directory: Path = SEQUENCE_DIR
) -> tuple[str, Path]:
"""Write sequence to FASTA file.
Args:
sequence (str): Sequence to write to FASTA file
directory (Path): Directory to write FASTA file to (default: SEQUENCE_DIR)
Returns:
tuple[str, Path]: Tuple containing the sequence ID and the path to the FASTA file
"""
input_rep = list(SeqIO.parse(StringIO(sequence), "fasta"))
if not input_rep:
raise gr.Error("No sequence found")
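    # Derive a deterministic ID from the concatenated record sequences so identical
    # inputs map to the same FASTA file and cached output directory.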
    seq_id = hashlib.sha256(
        "_".join([str(record.seq) for record in input_rep]).encode()
    ).hexdigest()
seq_file = directory / f"sequence_{seq_id}.fasta"
with open(seq_file, "w") as f:
f.write(sequence)
return seq_id, seq_file
def extract_plddt_from_cif(cif_path: str) -> tuple[list[float], list[str]]:
    """Extract per-residue pLDDT values and one-letter residue codes from a CIF file."""
    structure = MMCIFParser().get_structure("structure", cif_path)
# Lists to store pLDDT values and residue codes
plddt_values = []
residue_codes = []
# Iterate through all atoms
for model in structure:
for chain in model:
for residue in chain:
                # Use the C-alpha atom as the per-residue representative
if "CA" in residue:
# The B-factor contains the pLDDT value
plddt = residue["CA"].get_bfactor()
plddt_values.append(plddt)
# Get residue code and convert to one-letter code
residue_codes.append(convert_to_one_letter(residue.get_resname()))
return plddt_values, residue_codes
def predict(
sequence: str,
api_key: str,
model_type: FoldingModel,
format_fasta: bool = False,
progress=gr.Progress(),
) -> tuple[list[str], go.Figure]:
    """Predict protein structure from an amino acid sequence using the selected folding model.
Args:
sequence (str): Amino acid sequence to predict structure for
api_key (str): Folding API key
        model_type (FoldingModel): Folding model to use
format_fasta (bool): Whether to format the FASTA file
progress (gr.Progress): Gradio progress tracker
Returns:
        tuple[list[str], go.Figure]: Paths to the converted PDB files and the pLDDT plot
"""
if not api_key:
raise gr.Error("Missing API key, please enter a valid API key")
progress(0, desc="Setting up prediction...")
# Set up unique output directory based on sequence hash
seq_id, seq_file = _write_fasta_file(sequence)
output_dir = OUTPUT_DIR / seq_id / model_type
output_dir.mkdir(parents=True, exist_ok=True)
if model_type == FoldingModel.BOLTZ:
model = BoltzModel(api_key)
elif model_type == FoldingModel.CHAI:
model = ChaiModel(api_key)
elif model_type == FoldingModel.PROTENIX:
model = ProtenixModel(api_key)
else:
raise ValueError(f"Model {model_type} not supported")
# Check if prediction already exists
if not model.has_prediction(output_dir):
progress(0.2, desc="Running prediction...")
# Run prediction
logger.info(f"Predicting {seq_id}")
model.call(seq_file=seq_file, output_dir=output_dir, format_fasta=format_fasta)
logger.info("Prediction done. Output directory: %s", output_dir)
else:
progress(0.2, desc="Using existing prediction...")
logger.info("Prediction already exists. Output directory: %s", output_dir)
progress(0.4, desc="Processing results...")
# Convert output CIF to PDB
if not model.has_prediction(output_dir):
raise gr.Error("No prediction found")
predictions = model.predictions(output_dir)
pdb_paths = []
model_plddt_vals = []
model_residue_codes = []
total_predictions = len(predictions)
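    # Each entry maps a prediction index to its metadata; the CIF file is expected
    # under the "prediction_path" key.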
for i, (model_idx, prediction) in enumerate(predictions.items()):
progress(
0.4 + (0.4 * i / total_predictions), desc=f"Converting model {model_idx}..."
)
cif_path = prediction["prediction_path"]
logger.info(f"CIF file: {cif_path}")
converted_pdb_path = str(
output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
)
convert_cif_to_pdb(str(cif_path), str(converted_pdb_path))
plddt_vals, residue_codes = extract_plddt_from_cif(cif_path)
pdb_paths.append(converted_pdb_path)
model_plddt_vals.append(plddt_vals)
model_residue_codes.append(residue_codes)
progress(0.8, desc="Generating plots...")
plddt_fig = create_plddt_figure(
plddt_vals=model_plddt_vals,
model_name=model.model_name,
residue_codes=model_residue_codes,
)
progress(1.0, desc="Done!")
return pdb_paths, plddt_fig
def align_structures(pdb_paths: list[str]) -> list[str]:
"""Align multiple PDB structures to the first structure.
Args:
pdb_paths (list[str]): List of paths to PDB files to align
Returns:
list[str]: List of paths to aligned PDB files
"""
parser = PDBParser()
io = PDBIO()
# Parse the reference structure (first one)
ref_structure = parser.get_structure("reference", pdb_paths[0])
ref_atoms = [atom for atom in ref_structure.get_atoms() if atom.get_name() == "CA"]
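    # Superimpose on C-alpha atoms only; Superimposer.set_atoms requires the fixed
    # and moving atom lists to have the same length, which holds when every model
    # was folded from the same input sequence.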
aligned_paths = [pdb_paths[0]] # First structure is already aligned
# Align each subsequent structure to the reference
for i, pdb_path in enumerate(pdb_paths[1:], start=1):
# Parse the structure to align
structure = parser.get_structure(f"model_{i}", pdb_path)
atoms = [atom for atom in structure.get_atoms() if atom.get_name() == "CA"]
# Create superimposer
sup = Superimposer()
# Set the reference and moving atoms
sup.set_atoms(ref_atoms, atoms)
# Apply the transformation to all atoms in the structure
sup.apply(structure.get_atoms())
# Save the aligned structure
aligned_path = str(Path(pdb_path).parent / f"aligned_{Path(pdb_path).name}")
io.set_structure(structure)
io.save(aligned_path)
aligned_paths.append(aligned_path)
return aligned_paths
def filter_predictions(
aligned_paths: list[str],
plddt_fig: go.Figure,
chai_selected: list[int],
boltz_selected: list[int],
protenix_selected: list[int],
) -> tuple[list[str], go.Figure]:
"""Filter predictions based on selected checkboxes.
Args:
aligned_paths (list[str]): List of aligned PDB paths
plddt_fig (go.Figure): Original pLDDT plot
chai_selected (list[int]): Selected Chai model indices
boltz_selected (list[int]): Selected Boltz model indices
protenix_selected (list[int]): Selected Protenix model indices
Returns:
tuple[list[str], go.Figure]: Filtered PDB paths and updated pLDDT plot
"""
# Create a new figure with only selected traces
filtered_fig = go.Figure()
# Keep track of which traces to show
visible_paths = []
# Helper function to check if a trace should be visible
def should_show_trace(trace_name: str) -> bool:
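        # Trace names follow the "<model_name> <index>" format set in create_plddt_figure.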
model_name = trace_name.split()[0]
model_idx = int(trace_name.split()[1])
if model_name == "Chai" and model_idx in chai_selected:
return True
if model_name == "Boltz" and model_idx in boltz_selected:
return True
if model_name == "Protenix" and model_idx in protenix_selected:
return True
return False
# Filter traces and paths
for i, trace in enumerate(plddt_fig.data):
if should_show_trace(trace.name):
filtered_fig.add_trace(trace)
visible_paths.append(aligned_paths[i])
# Update layout
filtered_fig.update_layout(
title="pLDDT",
xaxis_title="Residue index",
yaxis_title="pLDDT",
height=500,
template="simple_white",
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
)
return visible_paths, filtered_fig
def predict_comparison(
sequence: str, api_key: str, model_types: list[FoldingModel], progress=gr.Progress()
) -> tuple[
    gr.CheckboxGroup,
    gr.CheckboxGroup,
    gr.CheckboxGroup,
    list[str],
    go.Figure,
]:
"""Predict protein structure from amino acid sequence using multiple models.
Args:
sequence (str): Amino acid sequence to predict structure for
api_key (str): Folding API key
model_types (list[FoldingModel]): List of folding models to use
progress (gr.Progress): Gradio progress tracker
Returns:
        tuple containing:
            - gr.CheckboxGroup: Chai predictions checkbox group
            - gr.CheckboxGroup: Boltz predictions checkbox group
            - gr.CheckboxGroup: Protenix predictions checkbox group
            - list[str]: Aligned PDB paths
            - go.Figure: Combined pLDDT plot
"""
if not api_key:
raise gr.Error("Missing API key, please enter a valid API key")
    # Run each selected model and collect its PDB paths and pLDDT traces
pdb_paths = []
plddt_traces = []
total_models = len(model_types)
model_predictions = {}
for i, model_type in enumerate(model_types):
progress(i / total_models, desc=f"Running {model_type} prediction...")
model_pdb_paths, model_plddt_traces = predict(
sequence, api_key, model_type, format_fasta=True
)
pdb_paths += model_pdb_paths
plddt_traces += model_plddt_traces.data
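        # The prediction index is taken from the last character of the file stem
        # (e.g. "..._prediction_0" -> 0), which assumes fewer than 10 predictions per model.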
model_predictions[model_type] = [int(Path(p).stem[-1]) for p in model_pdb_paths]
progress(0.9, desc="Aligning structures...")
aligned_paths = align_structures(pdb_paths)
plddt_fig = go.Figure(data=plddt_traces)
plddt_fig.update_layout(
title="pLDDT",
xaxis_title="Residue index",
yaxis_title="pLDDT",
height=500,
template="simple_white",
legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
)
progress(1.0, desc="Done!")
# Create checkbox groups for each model type
chai_predictions = gr.CheckboxGroup(
visible=model_predictions.get(FoldingModel.CHAI) is not None,
choices=model_predictions.get(FoldingModel.CHAI, []),
value=model_predictions.get(FoldingModel.CHAI, []),
)
boltz_predictions = gr.CheckboxGroup(
visible=model_predictions.get(FoldingModel.BOLTZ) is not None,
choices=model_predictions.get(FoldingModel.BOLTZ, []),
value=model_predictions.get(FoldingModel.BOLTZ, []),
)
protenix_predictions = gr.CheckboxGroup(
visible=model_predictions.get(FoldingModel.PROTENIX) is not None,
choices=model_predictions.get(FoldingModel.PROTENIX, []),
value=model_predictions.get(FoldingModel.PROTENIX, []),
)
return (
chai_predictions,
boltz_predictions,
protenix_predictions,
aligned_paths,
plddt_fig,
)