"""Predict protein structure using Folding Studio."""
import hashlib
import logging
import os
from pathlib import Path
import gradio as gr
import numpy as np
import plotly.graph_objects as go
from Bio import SeqIO
from Bio.PDB import PDBIO, MMCIFParser
from folding_studio.client import Client
from folding_studio.query import Query
from folding_studio.query.boltz import BoltzQuery
from folding_studio.query.chai import ChaiQuery
from folding_studio.query.protenix import ProtenixQuery
from folding_studio_data_models import FoldingModel
from folding_studio_demo.model_fasta_validators import (
BaseFastaValidator,
BoltzFastaValidator,
ChaiFastaValidator,
ProtenixFastaValidator,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Directory where input FASTA files are written (default target of
# _write_fasta_file). The path is relative, so it resolves against the current
# working directory; it is created when this module is imported.
SEQUENCE_DIR = Path("sequences")
SEQUENCE_DIR.mkdir(parents=True, exist_ok=True)
# Root directory for prediction outputs (predict() creates sub-directories
# below it). Also relative to the working directory and created on import.
OUTPUT_DIR = Path("output")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
    """Re-save an mmCIF structure file in PDB format with Biopython.

    Args:
        cif_path (str): Location of the .cif file to read
        pdb_path (str): Location the .pdb file is written to
    """
    writer = PDBIO()
    # Read the mmCIF input and hand the parsed structure straight to the writer
    writer.set_structure(MMCIFParser().get_structure("structure", cif_path))
    # Emit the structure in PDB format
    writer.save(pdb_path)
def add_plddt_plot(plddt_vals: list[float]) -> "go.Figure":
    """Build a plot of pLDDT values against residue index.

    Args:
        plddt_vals (list[float]): pLDDT values; the position of each value in
            the sequence is used as its residue index on the x axis

    Returns:
        go.Figure: Plotly figure containing a single pLDDT trace
    """
    plddt_trace = go.Scatter(
        x=np.arange(len(plddt_vals)),
        y=plddt_vals,
        # Plotly hover text uses <br> for line breaks; a raw newline inside a
        # single-quoted string literal is a syntax error.
        hovertemplate="pLDDT: %{y:.2f}<br>Residue index: %{x}<br>",
        name="seq",
        visible=True,
    )

    plddt_fig = go.Figure(data=[plddt_trace])
    plddt_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
    )
    return plddt_fig
def _write_fasta_file(
    sequence: str, directory: Path = SEQUENCE_DIR
) -> tuple[str, Path]:
    """Write sequence to a FASTA file named after the hash of its content.

    The text is written exactly as given (no header is added). Identical input
    always maps to the same file name, so repeated calls overwrite the same
    file with the same content.

    Args:
        sequence (str): Sequence to write to FASTA file
        directory (Path): Directory to write FASTA file to (default: SEQUENCE_DIR)

    Returns:
        tuple[str, Path]: Tuple containing the sequence ID (SHA-1 hex digest of
            the sequence text) and the path to the FASTA file
    """
    # The digest is only used as a content-derived identifier.
    seq_id = hashlib.sha1(sequence.encode("utf-8")).hexdigest()
    seq_file = directory / f"sequence_{seq_id}.fasta"
    # Write with the same encoding the hash was computed over, instead of the
    # platform's locale-dependent default.
    seq_file.write_text(sequence, encoding="utf-8")
    return seq_id, seq_file
class AF3Model:
    """Base class pairing a Folding Studio query type with a FASTA validator.

    Subclasses pass their model name, query class and validator to
    ``__init__`` and implement ``predictions`` to locate downloaded results.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str,
        query: type[Query],
        validator: BaseFastaValidator,
    ):
        """Store the configuration used by ``call``.

        Args:
            api_key (str): Folding API key used to create the client
            model_name (str): Name used in log messages, error messages and
                the FASTA header produced by ``format_fasta``
            query (type[Query]): Query class (not an instance); its
                ``from_file`` constructor builds the request
            validator (BaseFastaValidator): Validator applied to the FASTA
                file before a request is sent
        """
        self.api_key = api_key
        self.model_name = model_name
        self.query = query
        self.validator = validator

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using AF3 model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory

        Raises:
            gr.Error: If the FASTA file fails validation
            KeyError: If the FOLDING_PROJECT_CODE environment variable is not set
        """
        # Validate FASTA format before calling
        is_valid, error_msg = self.check_file_description(seq_file)
        if not is_valid:
            logger.error(error_msg)
            raise gr.Error(error_msg)

        # Create a client using API key
        logger.info("Authenticating client with API key")
        client = Client.from_api_key(api_key=self.api_key)

        # Define query
        query: Query = self.query.from_file(path=seq_file, query_name="gradio")
        query.save_parameters(output_dir)
        logger.info("Payload: %s", query.payload)

        # Send a request
        logger.info("Sending %s request to Folding Studio API", self.model_name)
        response = client.send_request(
            query, project_code=os.environ["FOLDING_PROJECT_CODE"]
        )

        # Access confidence data
        logger.info("Confidence data: %s", response.confidence_data)

        response.download_results(output_dir=output_dir, force=True, unzip=True)
        logger.info("Results downloaded to %s", output_dir)

    def format_fasta(self, sequence: str) -> str:
        """Format sequence to FASTA format, using the model name as header."""
        return f">{self.model_name}\n{sequence}"

    def predictions(self, output_dir: Path) -> list[Path]:
        """Get the paths to the predictions; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class
        """
        raise NotImplementedError("Not implemented")

    def has_prediction(self, output_dir: Path) -> bool:
        """Check if prediction exists in output directory."""
        return any(self.predictions(output_dir))

    def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
        """Check if the file description is correct.

        The file must contain at least one FASTA record and must be accepted
        by this model's validator.

        Args:
            seq_file (Path | str): Path to FASTA file

        Returns:
            tuple[bool, str | None]: Tuple containing a boolean indicating if the
                format is correct and an error message if not
        """
        input_rep = list(SeqIO.parse(seq_file, "fasta"))
        if not input_rep:
            error_msg = f"{self.model_name.upper()} Validation Error: No sequence found"
            return False, error_msg

        is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
        if not is_valid:
            return False, error_msg

        return True, None
class ChaiModel(AF3Model):
    """AF3Model configured with the Chai query class and FASTA validator."""

    def __init__(self, api_key: str):
        """Configure the model.

        Args:
            api_key (str): Folding API key
        """
        super().__init__(api_key, "Chai", ChaiQuery, ChaiFastaValidator())

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using Chai model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
        """
        super().call(seq_file, output_dir)

    def predictions(self, output_dir: Path) -> list[Path]:
        """Get the paths to the predictions.

        Searches ``output_dir`` recursively for files matching
        ``*_model_[0-9].cif``. The result is sorted because directory
        traversal order is not guaranteed, and callers that take the first
        entry should get the same file on every run.
        """
        return sorted(output_dir.rglob("*_model_[0-9].cif"))
class ProtenixModel(AF3Model):
    """AF3Model configured with the Protenix query class and FASTA validator."""

    def __init__(self, api_key: str):
        """Configure the model.

        Args:
            api_key (str): Folding API key
        """
        super().__init__(api_key, "Protenix", ProtenixQuery, ProtenixFastaValidator())

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using Protenix model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
        """
        super().call(seq_file, output_dir)

    def predictions(self, output_dir: Path) -> list[Path]:
        """Get the paths to the predictions.

        Searches ``output_dir`` recursively for files matching
        ``*_model_[0-9].cif``. The result is sorted because directory
        traversal order is not guaranteed, and callers that take the first
        entry should get the same file on every run.
        """
        return sorted(output_dir.rglob("*_model_[0-9].cif"))
class BoltzModel(AF3Model):
    """AF3Model configured with the Boltz query class and FASTA validator."""

    def __init__(self, api_key: str):
        """Configure the model.

        Args:
            api_key (str): Folding API key
        """
        super().__init__(api_key, "Boltz", BoltzQuery, BoltzFastaValidator())

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using Boltz model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
        """
        super().call(seq_file, output_dir)

    def predictions(self, output_dir: Path) -> list[Path]:
        """Get the paths to the predictions.

        Searches ``output_dir`` recursively for files matching
        ``*_model_[0-9].cif``. The result is sorted because directory
        traversal order is not guaranteed, and callers that take the first
        entry should get the same file on every run.
        """
        return sorted(output_dir.rglob("*_model_[0-9].cif"))
def predict(
    sequence: str, api_key: str, model_type: FoldingModel
) -> tuple[str, "go.Figure"]:
    """Predict protein structure from amino acid sequence using the selected model.

    Results are stored under ``OUTPUT_DIR / <sequence hash> / <model_type>``;
    when a prediction is already present there, it is reused instead of
    sending a new request.

    Args:
        sequence (str): Amino acid sequence to predict structure for
        api_key (str): Folding API key
        model_type (FoldingModel): Folding model to use

    Returns:
        tuple[str, go.Figure]: Tuple containing the path to the PDB file and
            the pLDDT plot

    Raises:
        ValueError: If ``model_type`` is not one of the supported models
        gr.Error: If no prediction file or no pLDDT file is found in the
            output directory (or if the model's FASTA validation fails)
    """
    # Resolve the model first so an unsupported type fails before any file or
    # directory is created.
    if model_type == FoldingModel.BOLTZ:
        model = BoltzModel(api_key)
    elif model_type == FoldingModel.CHAI:
        model = ChaiModel(api_key)
    elif model_type == FoldingModel.PROTENIX:
        model = ProtenixModel(api_key)
    else:
        raise ValueError(f"Model {model_type} not supported")

    # Set up unique output directory based on sequence hash
    seq_id, seq_file = _write_fasta_file(sequence)
    # NOTE(review): assumes FoldingModel members are str-like so they can be
    # used as a path component - confirm against folding_studio_data_models.
    output_dir = OUTPUT_DIR / seq_id / model_type
    output_dir.mkdir(parents=True, exist_ok=True)

    # Check if prediction already exists
    if not model.has_prediction(output_dir):
        logger.info("Predicting %s", seq_id)
        model.call(seq_file=seq_file, output_dir=output_dir)
        logger.info("Prediction done. Output directory: %s", output_dir)
    else:
        logger.info("Prediction already exists. Output directory: %s", output_dir)

    # Scan the output directory once and reuse the result for both the
    # existence check and the file selection.
    predictions = model.predictions(output_dir)
    if not predictions:
        raise gr.Error("No prediction found")
    pred_cif = predictions[0]
    logger.info("Output file: %s", pred_cif)

    # Convert output CIF to PDB
    converted_pdb_path = str(output_dir / f"pred_{seq_id}.pdb")
    convert_cif_to_pdb(str(pred_cif), converted_pdb_path)
    logger.info("Converted PDB file: %s", converted_pdb_path)

    # Sorted so the selected file does not depend on directory listing order.
    plddt_files = sorted(pred_cif.parent.glob("plddt_*.npz"))
    if not plddt_files:
        # Surface a readable error instead of a bare IndexError.
        raise gr.Error("No pLDDT file found")
    plddt_file = plddt_files[0]
    logger.info("plddt file: %s", plddt_file)

    # np.load on an .npz keeps the archive open; close it once the array has
    # been read into memory.
    with np.load(plddt_file) as plddt_data:
        plddt_vals = plddt_data["plddt"]

    return converted_pdb_path, add_plddt_plot(plddt_vals=plddt_vals)