"""Predict protein structure using Folding Studio.""" import hashlib import logging import os from pathlib import Path import gradio as gr import numpy as np import plotly.graph_objects as go from Bio import SeqIO from Bio.PDB import PDBIO, MMCIFParser from folding_studio.client import Client from folding_studio.query import Query from folding_studio.query.boltz import BoltzQuery from folding_studio.query.chai import ChaiQuery from folding_studio.query.protenix import ProtenixQuery from folding_studio_data_models import FoldingModel from folding_studio_demo.model_fasta_validators import ( BaseFastaValidator, BoltzFastaValidator, ChaiFastaValidator, ProtenixFastaValidator, ) logger = logging.getLogger(__name__) SEQUENCE_DIR = Path("sequences") SEQUENCE_DIR.mkdir(parents=True, exist_ok=True) OUTPUT_DIR = Path("output") OUTPUT_DIR.mkdir(parents=True, exist_ok=True) def convert_cif_to_pdb(cif_path: str, pdb_path: str) -> None: """Convert a .cif file to .pdb format using Biopython. Args: cif_path (str): Path to input .cif file pdb_path (str): Path to output .pdb file """ # Parse the CIF file parser = MMCIFParser() structure = parser.get_structure("structure", cif_path) # Save as PDB io = PDBIO() io.set_structure(structure) io.save(pdb_path) def add_plddt_plot(plddt_vals: list[float]) -> str: """Create a plot of metrics.""" visible = True plddt_trace = go.Scatter( x=np.arange(len(plddt_vals)), y=plddt_vals, hovertemplate="pLDDT: %{y:.2f}
Residue index: %{x}
", name="seq", visible=visible, ) plddt_fig = go.Figure(data=[plddt_trace]) plddt_fig.update_layout( title="pLDDT", xaxis_title="Residue index", yaxis_title="pLDDT", height=500, template="simple_white", legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99), ) return plddt_fig def _write_fasta_file( sequence: str, directory: Path = SEQUENCE_DIR ) -> tuple[str, Path]: """Write sequence to FASTA file. Args: sequence (str): Sequence to write to FASTA file directory (Path): Directory to write FASTA file to (default: SEQUENCE_DIR) Returns: tuple[str, Path]: Tuple containing the sequence ID and the path to the FASTA file """ seq_id = hashlib.sha1(sequence.encode()).hexdigest() seq_file = directory / f"sequence_{seq_id}.fasta" with open(seq_file, "w") as f: f.write(sequence) return seq_id, seq_file class AF3Model: def __init__( self, api_key: str, model_name: str, query: Query, validator: BaseFastaValidator ): self.api_key = api_key self.model_name = model_name self.query = query self.validator = validator def call(self, seq_file: Path | str, output_dir: Path) -> None: """Predict protein structure from amino acid sequence using AF3 model. Args: seq_file (Path | str): Path to FASTA file containing amino acid sequence output_dir (Path): Path to output directory """ # Validate FASTA format before calling is_valid, error_msg = self.check_file_description(seq_file) if not is_valid: logger.error(error_msg) raise gr.Error(error_msg) # Create a client using API key logger.info("Authenticating client with API key") client = Client.from_api_key(api_key=self.api_key) # Define query query: Query = self.query.from_file(path=seq_file, query_name="gradio") query.save_parameters(output_dir) logger.info("Payload: %s", query.payload) # Send a request logger.info(f"Sending {self.model_name} request to Folding Studio API") response = client.send_request( query, project_code=os.environ["FOLDING_PROJECT_CODE"] ) # Access confidence data logger.info("Confidence data: %s", response.confidence_data) response.download_results(output_dir=output_dir, force=True, unzip=True) logger.info("Results downloaded to %s", output_dir) def format_fasta(self, sequence: str) -> str: """Format sequence to FASTA format.""" return f">{self.model_name}\n{sequence}" def predictions(self, output_dir: Path) -> list[Path]: """Get the path to the prediction.""" raise NotImplementedError("Not implemented") def has_prediction(self, output_dir: Path) -> bool: """Check if prediction exists in output directory.""" return any(self.predictions(output_dir)) def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]: """Check if the file description is correct. Args: seq_file (Path | str): Path to FASTA file Returns: tuple[bool, str | None]: Tuple containing a boolean indicating if the format is correct and an error message if not """ input_rep = list(SeqIO.parse(seq_file, "fasta")) if not input_rep: error_msg = f"{self.model_name.upper()} Validation Error: No sequence found" return False, error_msg is_valid, error_msg = self.validator.is_valid_fasta(seq_file) if not is_valid: return False, error_msg return True, None class ChaiModel(AF3Model): def __init__(self, api_key: str): super().__init__(api_key, "Chai", ChaiQuery, ChaiFastaValidator()) def call(self, seq_file: Path | str, output_dir: Path) -> None: """Predict protein structure from amino acid sequence using Chai model. Args: seq_file (Path | str): Path to FASTA file containing amino acid sequence output_dir (Path): Path to output directory """ super().call(seq_file, output_dir) def predictions(self, output_dir: Path) -> list[Path]: """Get the path to the prediction.""" return list(output_dir.rglob("*_model_[0-9].cif")) class ProtenixModel(AF3Model): def __init__(self, api_key: str): super().__init__(api_key, "Protenix", ProtenixQuery, ProtenixFastaValidator()) def call(self, seq_file: Path | str, output_dir: Path) -> None: """Predict protein structure from amino acid sequence using Protenix model. Args: seq_file (Path | str): Path to FASTA file containing amino acid sequence output_dir (Path): Path to output directory """ super().call(seq_file, output_dir) def predictions(self, output_dir: Path) -> list[Path]: """Get the path to the prediction.""" return list(output_dir.rglob("*_model_[0-9].cif")) class BoltzModel(AF3Model): def __init__(self, api_key: str): super().__init__(api_key, "Boltz", BoltzQuery, BoltzFastaValidator()) def call(self, seq_file: Path | str, output_dir: Path) -> None: """Predict protein structure from amino acid sequence using Boltz model. Args: seq_file (Path | str): Path to FASTA file containing amino acid sequence output_dir (Path): Path to output directory """ super().call(seq_file, output_dir) def predictions(self, output_dir: Path) -> list[Path]: """Get the path to the prediction.""" return list(output_dir.rglob("*_model_[0-9].cif")) def predict(sequence: str, api_key: str, model_type: FoldingModel) -> tuple[str, str]: """Predict protein structure from amino acid sequence using Boltz model. Args: sequence (str): Amino acid sequence to predict structure for api_key (str): Folding API key model (FoldingModel): Folding model to use Returns: tuple[str, str]: Tuple containing the path to the PDB file and the pLDDT plot """ # Set up unique output directory based on sequence hash seq_id, seq_file = _write_fasta_file(sequence) output_dir = OUTPUT_DIR / seq_id / model_type output_dir.mkdir(parents=True, exist_ok=True) if model_type == FoldingModel.BOLTZ: model = BoltzModel(api_key) elif model_type == FoldingModel.CHAI: model = ChaiModel(api_key) elif model_type == FoldingModel.PROTENIX: model = ProtenixModel(api_key) else: raise ValueError(f"Model {model_type} not supported") # Check if prediction already exists if not model.has_prediction(output_dir): # Run Boltz prediction logger.info(f"Predicting {seq_id}") model.call(seq_file=seq_file, output_dir=output_dir) logger.info("Prediction done. Output directory: %s", output_dir) else: logger.info("Prediction already exists. Output directory: %s", output_dir) # output_dir = Path("boltz_results") # debug # Convert output CIF to PDB if not model.has_prediction(output_dir): raise gr.Error("No prediction found") pred_cif = model.predictions(output_dir)[0] logger.info("Output file: %s", pred_cif) converted_pdb_path = str(output_dir / f"pred_{seq_id}.pdb") convert_cif_to_pdb(str(pred_cif), str(converted_pdb_path)) logger.info("Converted PDB file: %s", converted_pdb_path) plddt_file = list(pred_cif.parent.glob("plddt_*.npz"))[0] logger.info("plddt file: %s", plddt_file) plddt_vals = np.load(plddt_file)["plddt"] return converted_pdb_path, add_plddt_plot(plddt_vals=plddt_vals)