"""Predict protein structure using Folding Studio."""

import hashlib
import logging
import os
from io import StringIO
from pathlib import Path
from typing import Any

import gradio as gr
import numpy as np
import plotly.graph_objects as go
from Bio import SeqIO
from Bio.PDB import PDBIO, MMCIFParser, PDBParser, Superimposer
from folding_studio.client import Client
from folding_studio.query import Query
from folding_studio.query.boltz import BoltzQuery
from folding_studio.query.chai import ChaiQuery
from folding_studio.query.protenix import ProtenixQuery
from folding_studio_data_models import FoldingModel
from folding_studio_demo.model_fasta_validators import (
    BaseFastaValidator,
    BoltzFastaValidator,
    ChaiFastaValidator,
    ProtenixFastaValidator,
)

logger = logging.getLogger(__name__)

# Input FASTA files are written here, named after the hash of their sequences.
SEQUENCE_DIR = Path("sequences")
SEQUENCE_DIR.mkdir(parents=True, exist_ok=True)

# Downloaded predictions are cached here as OUTPUT_DIR/<sequence hash>/<model type>.
OUTPUT_DIR = Path("output")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None:
    """Convert a .cif file to .pdb format using Biopython.

    Args:
        cif_path (str): Path to input .cif file
        pdb_path (str): Path to output .pdb file
    """
    # Parse the CIF file
    parser = MMCIFParser()
    structure = parser.get_structure("structure", cif_path)

    # Save as PDB
    io = PDBIO()
    io.set_structure(structure)
    io.save(pdb_path)


def add_plddt_plot(plddt_vals: list[list[float]], model_name: str) -> go.Figure:
    """Create a line plot of per-residue pLDDT values, one trace per model.

    Args:
        plddt_vals (list[list[float]]): One list of per-residue pLDDT values per
            predicted model
        model_name (str): Prefix of each trace's legend label ("<model_name> <i>")

    Returns:
        go.Figure: Plotly figure with one scatter trace per entry of ``plddt_vals``
    """
    plddt_traces = [
        go.Scatter(
            x=np.arange(len(plddt_val)),
            y=plddt_val,
            # Plotly hover labels are HTML-like: <br> is the line break.
            hovertemplate="pLDDT: %{y:.2f}<br>Residue index: %{x}<br>",
            name=f"{model_name} {i}",
            visible=True,
        )
        for i, plddt_val in enumerate(plddt_vals)
    ]
    plddt_fig = go.Figure(data=plddt_traces)
    plddt_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
    )
    return plddt_fig


def _write_fasta_file(
    sequence: str, directory: Path = SEQUENCE_DIR
) -> tuple[str, Path]:
    """Write sequence to a FASTA file named after the hash of its sequences.

    Args:
        sequence (str): FASTA-formatted text to write
        directory (Path): Directory to write FASTA file to (default: SEQUENCE_DIR)

    Returns:
        tuple[str, Path]: Tuple containing the sequence ID (SHA-256 hex digest of
            the "_"-joined record sequences) and the path to the FASTA file

    Raises:
        gr.Error: If ``sequence`` contains no FASTA record.
    """
    input_rep = list(SeqIO.parse(StringIO(sequence), "fasta"))
    if not input_rep:
        raise gr.Error("No sequence found")

    # Only residue strings are hashed, so the same sequences always map to the
    # same ID (and hence the same cached output directory in predict()).
    # NOTE(review): FASTA headers are not part of the hash; confirm that two
    # inputs differing only in headers should share cached results.
    seq_id = hashlib.sha256(
        "_".join(str(record.seq) for record in input_rep).encode()
    ).hexdigest()
    seq_file = directory / f"sequence_{seq_id}.fasta"
    with open(seq_file, "w") as f:
        f.write(sequence)
    return seq_id, seq_file


class AF3Model:
    """Base wrapper for a structure-prediction model served by Folding Studio.

    Subclasses bind a model name, a Folding Studio ``Query`` class and a FASTA
    validator, and implement :meth:`predictions` to locate downloaded results.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str,
        query: type[Query],
        validator: BaseFastaValidator,
    ):
        self.api_key = api_key
        self.model_name = model_name
        # Query *class*; instantiated per call via ``from_file``.
        self.query = query
        self.validator = validator

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using AF3 model.

        Validates the FASTA file, sends the query to the Folding Studio API and
        downloads (and unzips) the results into ``output_dir``.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory

        Raises:
            gr.Error: If the FASTA file fails this model's validator.
            KeyError: If the FOLDING_PROJECT_CODE environment variable is unset.
        """
        # Validate FASTA format before calling
        is_valid, error_msg = self.check_file_description(seq_file)
        if not is_valid:
            logger.error(error_msg)
            raise gr.Error(error_msg)

        # Create a client using API key
        logger.info("Authenticating client with API key")
        client = Client.from_api_key(api_key=self.api_key)

        # Define query
        query_obj: Query = self.query.from_file(path=seq_file, query_name="gradio")
        query_obj.save_parameters(output_dir)

        logger.info("Payload: %s", query_obj.payload)

        # Send a request
        logger.info("Sending %s request to Folding Studio API", self.model_name)
        response = client.send_request(
            query_obj, project_code=os.environ["FOLDING_PROJECT_CODE"]
        )

        # Access confidence data
        logger.info("Confidence data: %s", response.confidence_data)

        response.download_results(output_dir=output_dir, force=True, unzip=True)
        logger.info("Results downloaded to %s", output_dir)

    def format_fasta(self, sequence: str) -> str:
        """Format sequence to FASTA format."""
        return f">{self.model_name}\n{sequence}"

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Locate predictions in ``output_dir``.

        Returns:
            dict[int, dict[str, Any]]: Mapping of model index to a dict holding at
                least ``"prediction_path"`` (path to the .cif file); empty if none.
        """
        raise NotImplementedError("Not implemented")

    def has_prediction(self, output_dir: Path) -> bool:
        """Check if prediction exists in output directory."""
        return len(self.predictions(output_dir)) > 0

    def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
        """Check if the file description is correct.

        Args:
            seq_file (Path | str): Path to FASTA file

        Returns:
            tuple[bool, str | None]: Tuple containing a boolean indicating if the format is correct and an error message if not
        """
        is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
        if not is_valid:
            return False, error_msg

        return True, None


class ChaiModel(AF3Model):
    """Chai model wrapper."""

    def __init__(self, api_key: str):
        super().__init__(api_key, "Chai", ChaiQuery, ChaiFastaValidator())

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using Chai model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
        """
        super().call(seq_file, output_dir)

    @staticmethod
    def _model_idx(path: Path) -> int:
        """Return the integer N encoded as ``...model_idx_<N>`` in a file stem."""
        return int(path.stem.split("model_idx_")[1])

    def _get_chai_paired_files(self, directory: Path) -> list[tuple[Path, Path]]:
        """Get pairs of .cif and .npz files with matching model indices.

        Args:
            directory (Path): Directory containing the prediction files

        Returns:
            list[tuple[Path, Path]]: List of tuples containing (cif_path, npz_path)
                pairs, sorted by model index
        """
        # Get all cif files and extract their indices
        cif_files = {
            self._model_idx(f): f for f in directory.glob("pred.model_idx_*.cif")
        }
        # Get all npz files and extract their indices
        npz_files = {
            self._model_idx(f): f for f in directory.glob("scores.model_idx_*.npz")
        }
        # Only keep models for which both the structure and the scores exist
        common_indices = sorted(cif_files.keys() & npz_files.keys())
        return [(cif_files[idx], npz_files[idx]) for idx in common_indices]

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the Chai predictions found under ``output_dir``.

        Returns:
            dict[int, dict[str, Any]]: Model index -> {"prediction_path": cif path,
                "metrics": loaded scores .npz}; empty if no prediction is found.
        """
        prediction = next(output_dir.rglob("pred.model_idx_[0-9].cif"), None)
        if prediction is None:
            return {}

        return {
            self._model_idx(cif_path): {
                "prediction_path": cif_path,
                "metrics": np.load(npz_path),
            }
            for cif_path, npz_path in self._get_chai_paired_files(prediction.parent)
        }


class ProtenixModel(AF3Model):
    """Protenix model wrapper."""

    def __init__(self, api_key: str):
        super().__init__(api_key, "Protenix", ProtenixQuery, ProtenixFastaValidator())

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using Protenix model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
        """
        super().call(seq_file, output_dir)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the Protenix predictions found under ``output_dir``.

        Returns the same mapping shape as the other models so predict() can
        iterate it with ``.items()``. No metrics file is located for Protenix;
        only the structure path is provided.

        Returns:
            dict[int, dict[str, Any]]: Model index -> {"prediction_path": cif path}
        """
        # The glob guarantees the stem ends in a single digit: the model index.
        return {
            int(cif_path.stem[-1]): {"prediction_path": cif_path}
            for cif_path in sorted(output_dir.rglob("*_model_[0-9].cif"))
        }


class BoltzModel(AF3Model):
    """Boltz model wrapper."""

    def __init__(self, api_key: str):
        super().__init__(api_key, "Boltz", BoltzQuery, BoltzFastaValidator())

    def call(self, seq_file: Path | str, output_dir: Path) -> None:
        """Predict protein structure from amino acid sequence using Boltz model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid sequence
            output_dir (Path): Path to output directory
        """
        super().call(seq_file, output_dir)

    @staticmethod
    def _load_plddt(cif_path: Path) -> Any:
        """Load the pLDDT .npz file for ``cif_path``; None if none is present.

        Prefers the per-model file ``plddt_<cif stem>.npz`` and falls back to any
        ``plddt_*.npz`` in the same directory.
        """
        matching = cif_path.with_name(f"plddt_{cif_path.stem}.npz")
        if matching.exists():
            return np.load(matching)
        fallback = next(cif_path.parent.glob("plddt_*.npz"), None)
        return None if fallback is None else np.load(fallback)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the Boltz predictions found under ``output_dir``.

        Returns:
            dict[int, dict[str, Any]]: Model index -> {"prediction_path": cif path,
                "metrics": loaded pLDDT .npz or None}
        """
        # The glob guarantees the stem ends in a single digit: the model index.
        return {
            int(cif_path.stem[-1]): {
                "prediction_path": cif_path,
                "metrics": self._load_plddt(cif_path),
            }
            for cif_path in sorted(output_dir.rglob("*_model_[0-9].cif"))
        }


def extract_plddt_from_cif(cif_path: Path | str) -> list[float]:
    """Read per-residue pLDDT values from a CIF file.

    The predicted structures store pLDDT in the B-factor column; for every
    residue (across all models and chains) that has a ``CA`` atom, the CA
    atom's B-factor is returned. Residues without a CA atom are skipped.

    Args:
        cif_path (Path | str): Path to the predicted .cif file

    Returns:
        list[float]: pLDDT values in residue order
    """
    structure = MMCIFParser().get_structure("structure", cif_path)
    return [
        residue["CA"].get_bfactor()
        for residue in structure.get_residues()
        if "CA" in residue
    ]


def predict(
    sequence: str, api_key: str, model_type: FoldingModel
) -> tuple[list[str], go.Figure]:
    """Predict protein structure from a FASTA sequence with the selected model.

    Results are cached on disk under ``OUTPUT_DIR/<sequence hash>/<model_type>``.
    The Folding Studio API is only called when no prediction exists there.

    Args:
        sequence (str): FASTA-formatted sequence(s) to predict the structure for
        api_key (str): Folding API key
        model_type (FoldingModel): Folding model to use

    Returns:
        tuple[list[str], go.Figure]: Paths to the converted PDB files (one per
            predicted model) and the pLDDT plot

    Raises:
        gr.Error: If the API key is missing or no prediction is found after the run.
        ValueError: If ``model_type`` is not supported.
    """
    if not api_key:
        raise gr.Error("Missing API key, please enter a valid API key")

    # Set up unique output directory based on sequence hash
    seq_id, seq_file = _write_fasta_file(sequence)
    output_dir = OUTPUT_DIR / seq_id / model_type
    output_dir.mkdir(parents=True, exist_ok=True)

    model: AF3Model
    if model_type == FoldingModel.BOLTZ:
        model = BoltzModel(api_key)
    elif model_type == FoldingModel.CHAI:
        model = ChaiModel(api_key)
    elif model_type == FoldingModel.PROTENIX:
        model = ProtenixModel(api_key)
    else:
        raise ValueError(f"Model {model_type} not supported")

    # Only call the API when no cached prediction exists
    if model.has_prediction(output_dir):
        logger.info("Prediction already exists. Output directory: %s", output_dir)
    else:
        logger.info("Predicting %s", seq_id)
        model.call(seq_file=seq_file, output_dir=output_dir)
        logger.info("Prediction done. Output directory: %s", output_dir)

    # Scan the output directory once (predictions() may load .npz files)
    predictions = model.predictions(output_dir)
    if not predictions:
        raise gr.Error("No prediction found")

    # Convert each output CIF to PDB and collect its pLDDT values
    pdb_paths = []
    model_plddt_vals = []
    for model_idx, prediction in predictions.items():
        cif_path = prediction["prediction_path"]
        logger.info("CIF file: %s", cif_path)

        converted_pdb_path = str(
            output_dir / f"{model.model_name}_prediction_{model_idx}.pdb"
        )
        convert_cif_to_pdb(str(cif_path), converted_pdb_path)

        pdb_paths.append(converted_pdb_path)
        model_plddt_vals.append(extract_plddt_from_cif(cif_path))

    plddt_plot = add_plddt_plot(
        plddt_vals=model_plddt_vals, model_name=model.model_name
    )
    return pdb_paths, plddt_plot


def align_structures(pdb_paths: list[str]) -> list[str]:
    """Align multiple PDB structures to the first structure.

    Alignment superimposes CA atoms pairwise in file order. Every structure must
    therefore have the same number of CA atoms as the reference. Each aligned
    structure is saved next to its input as ``aligned_<name>``; the reference
    path is returned unchanged.

    Args:
        pdb_paths (list[str]): List of paths to PDB files to align

    Returns:
        list[str]: List of paths to aligned PDB files (empty if ``pdb_paths`` is empty)

    Raises:
        gr.Error: If a structure's CA atom count differs from the reference's.
    """
    if not pdb_paths:
        return []

    parser = PDBParser()
    io = PDBIO()

    # Parse the reference structure (first one)
    ref_structure = parser.get_structure("reference", pdb_paths[0])
    ref_atoms = [atom for atom in ref_structure.get_atoms() if atom.get_name() == "CA"]

    aligned_paths = [pdb_paths[0]]  # First structure is already aligned

    # Align each subsequent structure to the reference
    for i, pdb_path in enumerate(pdb_paths[1:], start=1):
        # Parse the structure to align
        structure = parser.get_structure(f"model_{i}", pdb_path)
        atoms = [atom for atom in structure.get_atoms() if atom.get_name() == "CA"]

        # Superimposer requires equally sized, pairwise-corresponding atom lists
        if len(atoms) != len(ref_atoms):
            raise gr.Error(
                f"Cannot align {pdb_path}: it has {len(atoms)} CA atoms but the "
                f"reference {pdb_paths[0]} has {len(ref_atoms)}"
            )

        # Set the reference and moving atoms, then apply the transformation
        # to all atoms in the structure
        sup = Superimposer()
        sup.set_atoms(ref_atoms, atoms)
        sup.apply(structure.get_atoms())

        # Save the aligned structure
        aligned_path = str(Path(pdb_path).parent / f"aligned_{Path(pdb_path).name}")
        io.set_structure(structure)
        io.save(aligned_path)
        aligned_paths.append(aligned_path)

    return aligned_paths


def predict_comparison(
    sequence: str, api_key: str, model_types: list[FoldingModel]
) -> list[str]:
    """Predict a structure with several models and align all predictions.

    Args:
        sequence (str): FASTA-formatted sequence(s) to predict the structure for
        api_key (str): Folding API key
        model_types (list[FoldingModel]): Folding models to run, in order

    Returns:
        list[str]: Paths to the PDB files of every predicted model, aligned onto
            the first one

    Raises:
        gr.Error: If the API key is missing (or propagated from predict/align).
    """
    if not api_key:
        raise gr.Error("Missing API key, please enter a valid API key")

    pdb_paths = []
    for model_type in model_types:
        model_pdb_paths, _ = predict(sequence, api_key, model_type)
        pdb_paths += model_pdb_paths

    return align_structures(pdb_paths)