jfaustin's picture
Add model comparison (#3)
f601557 verified
raw
history blame
14.2 kB
"""Predict protein structure using Folding Studio."""
import hashlib
import logging
import os
from io import StringIO
from pathlib import Path
from typing import Any
import gradio as gr
import numpy as np
import plotly.graph_objects as go
from Bio import SeqIO
from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
from folding_studio.client import Client
from folding_studio.query import Query
from folding_studio.query.boltz import BoltzQuery
from folding_studio.query.chai import ChaiQuery
from folding_studio.query.protenix import ProtenixQuery
from folding_studio_data_models import FoldingModel
from folding_studio_demo.model_fasta_validators import (
BaseFastaValidator,
BoltzFastaValidator,
ChaiFastaValidator,
ProtenixFastaValidator,
)
logger = logging.getLogger(__name__)

# Working directories, created at import time: FASTA inputs are written to
# SEQUENCE_DIR and prediction outputs are stored under OUTPUT_DIR.
SEQUENCE_DIR = Path("sequences")
OUTPUT_DIR = Path("output")
for _workdir in (SEQUENCE_DIR, OUTPUT_DIR):
    _workdir.mkdir(parents=True, exist_ok=True)
del _workdir
def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
    """Rewrite an mmCIF structure file as a PDB file.

    The structure is read with Biopython's ``MMCIFParser`` and written back
    out unchanged with ``PDBIO``.

    Args:
        cif_path (str): Location of the source .cif file.
        pdb_path (str): Destination path for the generated .pdb file.
    """
    writer = PDBIO()
    writer.set_structure(MMCIFParser().get_structure("structure", cif_path))
    writer.save(pdb_path)
def add_plddt_plot(plddt_vals: list[list[float]], model_name: str) -> go.Figure:
    """Build a line chart of per-residue pLDDT scores.

    One trace is drawn per entry of ``plddt_vals``. Each trace is labelled
    ``"<model_name> <position in plddt_vals>"`` and plotted against its
    residue index.

    Args:
        plddt_vals (list[list[float]]): One list of pLDDT values per model.
        model_name (str): Prefix used for the trace names.

    Returns:
        go.Figure: The configured plotly figure.
    """
    hover = "<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}<br>"
    figure = go.Figure(
        data=[
            go.Scatter(
                x=np.arange(len(scores)),
                y=scores,
                hovertemplate=hover,
                name=f"{model_name} {idx}",
                visible=True,
            )
            for idx, scores in enumerate(plddt_vals)
        ]
    )
    figure.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend={"yanchor": "bottom", "y": 0.01, "xanchor": "left", "x": 0.99},
    )
    return figure
def _write_fasta_file(
    sequence: str, directory: Path = SEQUENCE_DIR
) -> tuple[str, Path]:
    """Write a FASTA input to disk under a content-derived file name.

    The returned ID is a SHA-256 digest of every record's header line *and*
    sequence. The ID is used by callers as the cache key for prediction
    outputs. Headers are therefore included: the model-specific FASTA formats
    may encode per-chain options in the header, and inputs that differ only
    there must not share a cached prediction.

    Args:
        sequence (str): FASTA-formatted text to write.
        directory (Path): Directory to write the FASTA file to
            (default: SEQUENCE_DIR).

    Returns:
        tuple[str, Path]: The sequence ID and the path of the written file.

    Raises:
        gr.Error: If ``sequence`` contains no FASTA record.
    """
    records = list(SeqIO.parse(StringIO(sequence), "fasta"))
    if not records:
        raise gr.Error("No sequence found")

    # Hash header + residues of each record, in input order.
    fingerprint = "_".join(f"{record.description}\n{record.seq}" for record in records)
    seq_id = hashlib.sha256(fingerprint.encode("utf-8")).hexdigest()

    seq_file = directory / f"sequence_{seq_id}.fasta"
    seq_file.write_text(sequence, encoding="utf-8")
    return seq_id, seq_file
class AF3Model:
    """Base wrapper around a Folding Studio model.

    Subclasses provide the concrete query class, the FASTA validator and the
    logic that locates prediction files inside an output directory.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str,
        query: type[Query],
        validator: BaseFastaValidator,
    ):
        """Store the model configuration.

        Args:
            api_key (str): Folding Studio API key used to authenticate.
            model_name (str): Human-readable model name, used in logs and
                output file names.
            query (type[Query]): Query class whose ``from_file`` builds the
                request from a FASTA file.
            validator (BaseFastaValidator): Validator applied to the FASTA
                file before any request is sent.
        """
        self.api_key = api_key
        self.model_name = model_name
        self.query = query
        self.validator = validator

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Submit a FASTA file to Folding Studio and download the results.

        Args:
            seq_file (Path | str): Path to FASTA file containing the input.
            output_dir (Path): Directory where query parameters are saved and
                results are downloaded (and unzipped).

        Raises:
            gr.Error: If the FASTA file fails validation, or if the
                ``FOLDING_PROJECT_CODE`` environment variable is not set.
        """
        # Validate FASTA format before contacting the API.
        is_valid, error_msg = self.check_file_description(seq_file)
        if not is_valid:
            logger.error(error_msg)
            raise gr.Error(error_msg)

        # Read the project code up front so a missing variable is reported as
        # a readable UI error instead of a bare KeyError mid-request.
        project_code = os.environ.get("FOLDING_PROJECT_CODE")
        if project_code is None:
            raise gr.Error("Missing FOLDING_PROJECT_CODE environment variable")

        logger.info("Authenticating client with API key")
        client = Client.from_api_key(api_key=self.api_key)

        query: Query = self.query.from_file(path=seq_file, query_name="gradio")
        query.save_parameters(output_dir)
        logger.info("Payload: %s", query.payload)

        logger.info("Sending %s request to Folding Studio API", self.model_name)
        response = client.send_request(query, project_code=project_code)

        logger.info("Confidence data: %s", response.confidence_data)
        response.download_results(output_dir=output_dir, force=True, unzip=True)
        logger.info("Results downloaded to %s", output_dir)

    def format_fasta(self, sequence: str) -> str:
        """Wrap a raw sequence in a FASTA record headed by the model name."""
        return f">{self.model_name}\n{sequence}"

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Map model index to prediction info found under ``output_dir``.

        Subclasses must override this. Callers read the ``"prediction_path"``
        entry of each value; an empty dict means no prediction was found.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError("Not implemented")

    def has_prediction(self, output_dir: Path) -> bool:
        """Return True if ``predictions`` finds at least one result."""
        return len(self.predictions(output_dir)) > 0

    def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
        """Validate a FASTA file with this model's validator.

        Args:
            seq_file (Path | str): Path to FASTA file.

        Returns:
            tuple[bool, str | None]: ``(True, None)`` when valid, otherwise
            ``(False, error_message)`` as reported by the validator.
        """
        is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
        if not is_valid:
            return False, error_msg
        return True, None
class ChaiModel(AF3Model):
    """Folding Studio wrapper for the Chai model."""

    def __init__(self, api_key: str):
        super().__init__(api_key, "Chai", ChaiQuery, ChaiFastaValidator())

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using Chai model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
        """
        super().call(seq_file, output_dir)

    @staticmethod
    def _model_index(path: Path) -> int:
        """Return N from a file named ``<prefix>.model_idx_N.<ext>``."""
        return int(path.stem.split("model_idx_")[1])

    def _get_chai_paired_files(self, directory: Path) -> list[tuple[Path, Path]]:
        """Pair Chai structure and score files that share a model index.

        Args:
            directory (Path): Directory containing ``pred.model_idx_N.cif``
                and ``scores.model_idx_N.npz`` files.

        Returns:
            list[tuple[Path, Path]]: ``(cif_path, npz_path)`` pairs sorted by
            model index. Indices present in only one file set are dropped.
        """
        cif_files = {
            self._model_index(f): f for f in directory.glob("pred.model_idx_*.cif")
        }
        npz_files = {
            self._model_index(f): f for f in directory.glob("scores.model_idx_*.npz")
        }
        common_indices = sorted(cif_files.keys() & npz_files.keys())
        return [(cif_files[idx], npz_files[idx]) for idx in common_indices]

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Collect Chai predictions found under ``output_dir``.

        Only the directory holding the first ``pred.model_idx_<digit>.cif``
        found recursively is inspected.

        Returns:
            dict[int, dict[str, Any]]: Model index mapped to
            ``{"prediction_path": cif_path, "metrics": np.load(npz_path)}``;
            empty if no prediction file exists.
        """
        prediction = next(output_dir.rglob("pred.model_idx_[0-9].cif"), None)
        if prediction is None:
            return {}
        return {
            self._model_index(cif_path): {
                "prediction_path": cif_path,
                "metrics": np.load(npz_path),
            }
            for cif_path, npz_path in self._get_chai_paired_files(prediction.parent)
        }
class ProtenixModel(AF3Model):
    """Folding Studio wrapper for the Protenix model."""

    def __init__(self, api_key: str):
        super().__init__(api_key, "Protenix", ProtenixQuery, ProtenixFastaValidator())

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using Protenix model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
        """
        super().call(seq_file, output_dir)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Collect Protenix predictions found under ``output_dir``.

        Returns the same mapping shape as the other models so callers can
        iterate ``.items()`` uniformly.

        Returns:
            dict[int, dict[str, Any]]: Sequential index (over the sorted
            ``*_model_<digit>.cif`` paths) mapped to
            ``{"prediction_path": cif_path, "metrics": None}``. No metrics file
            is located for Protenix here.
        """
        # Enumerate the sorted paths rather than keying on the digit in the
        # file name, so files from different subdirectories never collide.
        cif_paths = sorted(output_dir.rglob("*_model_[0-9].cif"))
        return {
            idx: {"prediction_path": cif_path, "metrics": None}
            for idx, cif_path in enumerate(cif_paths)
        }
class BoltzModel(AF3Model):
    """Folding Studio wrapper for the Boltz model."""

    def __init__(self, api_key: str):
        super().__init__(api_key, "Boltz", BoltzQuery, BoltzFastaValidator())

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using Boltz model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
        """
        super().call(seq_file, output_dir)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Collect Boltz predictions found under ``output_dir``.

        Returns:
            dict[int, dict[str, Any]]: Model index (the digit ending the .cif
            stem) mapped to ``{"prediction_path": cif_path, "metrics": ...}``.
            ``metrics`` is that model's own ``plddt_<cif stem>.npz`` loaded
            with ``np.load``, or None if the file is absent.
        """
        results: dict[int, dict[str, Any]] = {}
        for cif_path in sorted(output_dir.rglob("*_model_[0-9].cif")):
            # Boltz's documented layout stores per-model pLDDT next to the
            # structure as plddt_<structure stem>.npz. Match on the stem so
            # each model gets its own scores, not whichever file globbed first.
            plddt_path = cif_path.with_name(f"plddt_{cif_path.stem}.npz")
            results[int(cif_path.stem[-1])] = {
                "prediction_path": cif_path,
                "metrics": np.load(plddt_path) if plddt_path.exists() else None,
            }
        return results
def extract_plddt_from_cif(cif_path):
    """Return the B-factor of every CA atom in an mmCIF file.

    The code reads pLDDT from the B-factor column. Residues are visited in
    model -> chain -> residue order, and residues without a ``"CA"`` atom are
    skipped. The result is a flat list spanning all models and chains.
    """
    parsed = MMCIFParser().get_structure("structure", cif_path)
    return [
        res["CA"].get_bfactor()
        for mdl in parsed
        for chn in mdl
        for res in chn
        if "CA" in res
    ]
def predict(
    sequence: str, api_key: str, model_type: FoldingModel
) -> tuple[list[str], go.Figure]:
    """Predict structures for a FASTA input with one Folding Studio model.

    Outputs are stored under ``OUTPUT_DIR/<sequence id>/<model_type>``. If a
    prediction already exists there, it is reused and no API call is made.

    Args:
        sequence (str): FASTA-formatted input to predict structures for.
        api_key (str): Folding Studio API key.
        model_type (FoldingModel): Model to use (Boltz, Chai or Protenix).

    Returns:
        tuple[list[str], go.Figure]: Paths of the converted PDB files (one per
        predicted model) and a pLDDT plot with one trace per model.

    Raises:
        gr.Error: If the API key is missing, the input has no FASTA record, or
            no prediction is found after running the model.
        ValueError: If ``model_type`` is not supported.
    """
    if not api_key:
        raise gr.Error("Missing API key, please enter a valid API key")

    seq_id, seq_file = _write_fasta_file(sequence)

    # Plain if/elif on purpose: the value may arrive as a raw string from the
    # UI, which compares equal to a str-based enum member but would not
    # necessarily hash like one in a dict lookup.
    if model_type == FoldingModel.BOLTZ:
        model = BoltzModel(api_key)
    elif model_type == FoldingModel.CHAI:
        model = ChaiModel(api_key)
    elif model_type == FoldingModel.PROTENIX:
        model = ProtenixModel(api_key)
    else:
        raise ValueError(f"Model {model_type} not supported")

    output_dir = OUTPUT_DIR / seq_id / model_type
    output_dir.mkdir(parents=True, exist_ok=True)

    if model.has_prediction(output_dir):
        logger.info("Prediction already exists. Output directory: %s", output_dir)
    else:
        logger.info("Predicting %s", seq_id)
        model.call(seq_file=seq_file, output_dir=output_dir)
        logger.info("Prediction done. Output directory: %s", output_dir)

    # Compute the prediction map once and reuse it for the existence check.
    predictions = model.predictions(output_dir)
    if not predictions:
        raise gr.Error("No prediction found")

    pdb_paths = []
    model_plddt_vals = []
    for model_idx, prediction in predictions.items():
        cif_path = prediction["prediction_path"]
        logger.info("CIF file: %s", cif_path)
        # Convert each CIF to PDB; the PDB paths are what callers receive
        # (align_structures parses them with PDBParser).
        converted_pdb_path = str(
            output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
        )
        convert_cif_to_pdb(str(cif_path), converted_pdb_path)
        pdb_paths.append(converted_pdb_path)
        model_plddt_vals.append(extract_plddt_from_cif(cif_path))

    plddt_plot = add_plddt_plot(
        plddt_vals=model_plddt_vals, model_name=model.model_name
    )
    return pdb_paths, plddt_plot
def align_structures(pdb_paths: list[str]) -> list[str]:
    """Superimpose every PDB structure onto the first one, using CA atoms.

    The first path is the fixed reference and is returned unchanged. Each
    other structure is transformed as a whole (all atoms) and written next
    to its source as ``aligned_<original file name>``. CA atoms are paired
    positionally, in file order.

    Args:
        pdb_paths (list[str]): Paths to PDB files; the first is the reference.

    Returns:
        list[str]: The reference path followed by the aligned file paths, in
        input order. Empty if ``pdb_paths`` is empty.

    Raises:
        Bio.PDB.PDBExceptions.PDBException: From ``Superimposer.set_atoms``
            when a structure's CA count differs from the reference's.
    """
    if not pdb_paths:
        return []

    parser = PDBParser()
    io = PDBIO()

    ref_structure = parser.get_structure("reference", pdb_paths[0])
    ref_atoms = [atom for atom in ref_structure.get_atoms() if atom.get_name() == "CA"]
    aligned_paths = [pdb_paths[0]]  # the reference defines the frame

    for i, pdb_path in enumerate(pdb_paths[1:], start=1):
        structure = parser.get_structure(f"model_{i}", pdb_path)
        atoms = [atom for atom in structure.get_atoms() if atom.get_name() == "CA"]

        # Fit on CA atoms only, then move every atom of the structure.
        sup = Superimposer()
        sup.set_atoms(ref_atoms, atoms)
        sup.apply(list(structure.get_atoms()))

        source = Path(pdb_path)
        aligned_path = str(source.parent / f"aligned_{source.name}")
        io.set_structure(structure)
        io.save(aligned_path)
        aligned_paths.append(aligned_path)

    return aligned_paths
def predict_comparison(
    sequence: str, api_key: str, model_types: list[FoldingModel]
) -> list[str]:
    """Predict with several models and superimpose all results on the first.

    Each model in ``model_types`` is run through ``predict``. All resulting
    PDB files are then aligned onto the first one with ``align_structures``.

    Args:
        sequence (str): FASTA-formatted input to predict structures for.
        api_key (str): Folding Studio API key.
        model_types (list[FoldingModel]): Models to run, in order; the first
            prediction of the first model is the alignment reference.

    Returns:
        list[str]: The reference PDB path followed by the ``aligned_*`` PDB
        paths of every other prediction.

    Raises:
        gr.Error: If the API key is missing or no model is selected (other
            errors from ``predict`` propagate unchanged).
    """
    if not api_key:
        raise gr.Error("Missing API key, please enter a valid API key")
    if not model_types:
        raise gr.Error("Please select at least one model to compare")

    pdb_paths: list[str] = []
    for model_type in model_types:
        model_pdb_paths, _ = predict(sequence, api_key, model_type)
        pdb_paths.extend(model_pdb_paths)

    return align_structures(pdb_paths)