"""Models for the Folding Studio API."""

import json
import logging
import os
import sys
import time
from io import StringIO
from pathlib import Path
from typing import Any

import folding_studio
import gradio as gr
import numpy as np
from folding_studio import single_job_prediction
from folding_studio.client import Client
from folding_studio.commands.experiment import results as get_results
from folding_studio.commands.experiment import status as get_status
from folding_studio.query import Query
from folding_studio.query.boltz import BoltzQuery
from folding_studio.query.chai import ChaiQuery
from folding_studio.query.protenix import ProtenixQuery
from folding_studio_data_models import AF2Parameters, OpenFoldParameters
from folding_studio_data_models.parameters.base import BaseFoldingParameters
from folding_studio_demo.model_fasta_validators import (
    BaseFastaValidator,
    BoltzFastaValidator,
    ChaiFastaValidator,
    ProtenixFastaValidator,
)

logger = logging.getLogger(__name__)


class Capturing(list):
    """Capture stdout output.

    Used as a context manager: while the ``with`` block runs, ``sys.stdout``
    is replaced by an in-memory buffer. On exit the buffered text is split
    into lines and appended to this list, and the previous ``sys.stdout`` is
    restored.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        return self

    def __exit__(self, *args):
        self.extend(self._stringio.getvalue().splitlines())
        del self._stringio  # free up some memory
        sys.stdout = self._stdout


class AF3Model:
    """Base wrapper for models queried through the Folding Studio ``Client``.

    Subclasses set ``model_name``, provide the query class and FASTA validator
    and implement ``predictions``.
    """

    model_name: str | None = None

    def __init__(
        self, api_key: str, query: type[Query], validator: BaseFastaValidator
    ):
        """Store the credentials and the model-specific query/validator.

        Args:
            api_key (str): Folding Studio API key
            query (type[Query]): Query class; ``from_file`` is called on it
                to build the request
            validator (BaseFastaValidator): Validator used to check and
                reformat the FASTA file
        """
        self.api_key = api_key
        self.query = query
        self.validator = validator

    def call(
        self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
    ) -> None:
        """Predict protein structure from amino acid sequence using AF3 model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid
                sequence
            output_dir (Path): Path to output directory
            format_fasta (bool): Whether to rewrite the FASTA file in place
                when it does not pass validation

        Raises:
            gr.Error: If the FASTA file is invalid and ``format_fasta`` is
                False
            KeyError: If the ``FOLDING_PROJECT_CODE`` environment variable is
                not set
        """
        # Validate FASTA format before calling
        is_valid, error_msg = self.check_file_description(seq_file)
        if not is_valid:
            if not format_fasta:
                logger.error(error_msg)
                raise gr.Error(error_msg)
            logger.info("Invalid FASTA file format, forcing formatting...")
            self.format_fasta(seq_file)

        # Create a client using API key
        logger.info("Authenticating client with API key")
        client = Client.from_api_key(api_key=self.api_key)

        # Define query
        query: Query = self.query.from_file(path=seq_file, query_name="gradio")
        query.save_parameters(output_dir)
        logger.info("Payload: %s", query.payload)

        # Send a request
        logger.info("Sending %s request to Folding Studio API", self.model_name)
        response = client.send_request(
            query, project_code=os.environ["FOLDING_PROJECT_CODE"]
        )

        # Access confidence data
        logger.info("Confidence data: %s", response.confidence_data)

        response.download_results(output_dir=output_dir, force=True, unzip=True)
        logger.info("Results downloaded to %s", output_dir)

    def format_fasta(self, seq_file: Path | str) -> None:
        """Format sequence to FASTA format.

        The file is overwritten in place with the validator's output.

        Args:
            seq_file (Path | str): Path to FASTA file
        """
        formatted_fasta = self.validator.transform_fasta(seq_file)
        with open(seq_file, "w") as f:
            f.write(formatted_fasta)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the predictions found in the output directory.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            dict[int, dict[str, Any]]: Mapping from model index to a dict with
                the keys ``"prediction_path"`` and ``"metrics"``

        Raises:
            NotImplementedError: Always; subclasses must override this method
        """
        raise NotImplementedError()

    def has_prediction(self, output_dir: Path) -> bool:
        """Check if prediction exists in output directory."""
        return len(self.predictions(output_dir)) > 0

    def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
        """Check if the file description is correct.

        Args:
            seq_file (Path | str): Path to FASTA file

        Returns:
            tuple[bool, str | None]: Tuple containing a boolean indicating if
                the format is correct and an error message if not
        """
        is_valid, error_msg = self.validator.is_valid_fasta(seq_file)
        if not is_valid:
            return False, error_msg
        return True, None


class ChaiModel(AF3Model):
    """Chai model served through the Folding Studio API."""

    model_name = "Chai"

    def __init__(self, api_key: str):
        super().__init__(api_key, ChaiQuery, ChaiFastaValidator())

    def call(
        self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
    ) -> None:
        """Predict protein structure from amino acid sequence using Chai model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid
                sequence
            output_dir (Path): Path to output directory
            format_fasta (bool): Whether to format the FASTA file
        """
        super().call(seq_file, output_dir, format_fasta)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the predictions and their scores.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            dict[int, dict[str, Any]]: Mapping from model index to its CIF
                path (``"prediction_path"``) and the object returned by
                ``np.load`` on its scores file (``"metrics"``). Only indices
                that have both files are included; empty if no prediction
                is found.
        """
        prediction = next(output_dir.rglob("pred.model_idx_[0-9].cif"), None)
        if prediction is None:
            return {}

        # Index the structure files located next to the first match
        cif_files = {
            int(f.stem.split("model_idx_")[1]): f
            for f in prediction.parent.glob("pred.model_idx_*.cif")
        }

        # Get all npz files and extract their indices
        npz_files = {
            int(f.stem.split("model_idx_")[1]): f
            for f in prediction.parent.glob("scores.model_idx_*.npz")
        }

        # Find common indices and create pairs
        common_indices = sorted(set(cif_files) & set(npz_files))

        return {
            idx: {"prediction_path": cif_files[idx], "metrics": np.load(npz_files[idx])}
            for idx in common_indices
        }


class ProtenixModel(AF3Model):
    """Protenix model served through the Folding Studio API."""

    model_name = "Protenix"

    def __init__(self, api_key: str):
        super().__init__(api_key, ProtenixQuery, ProtenixFastaValidator())

    def call(
        self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
    ) -> None:
        """Predict protein structure from amino acid sequence using Protenix model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid
                sequence
            output_dir (Path): Path to output directory
            format_fasta (bool): Whether to format the FASTA file
        """
        super().call(seq_file, output_dir, format_fasta)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the predictions and their confidence summaries.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            dict[int, dict[str, Any]]: Mapping from sample index to its CIF
                path (``"prediction_path"``) and the parsed confidence JSON
                (``"metrics"``). Only indices that have both files are
                included; empty if no prediction is found.
        """
        prediction = next(output_dir.rglob("sequence_*_sample_[0-9].cif"), None)
        if prediction is None:
            return {}

        # The sample index is the last character of the file stem
        cif_files = {
            int(f.stem[-1]): f
            for f in prediction.parent.glob("sequence_*_sample_[0-9].cif")
        }

        # Get all confidence JSON files and extract their indices
        json_files = {
            int(f.stem[-1]): f
            for f in prediction.parent.glob(
                "sequence_*_summary_confidence_sample_[0-9].json"
            )
        }

        # Find common indices and create pairs
        output: dict[int, dict[str, Any]] = {}
        for idx in sorted(set(cif_files) & set(json_files)):
            # Close each metrics file as soon as it has been parsed
            with open(json_files[idx]) as f:
                metrics = json.load(f)
            output[idx] = {"prediction_path": cif_files[idx], "metrics": metrics}
        return output


class BoltzModel(AF3Model):
    """Boltz model served through the Folding Studio API."""

    model_name = "Boltz"

    def __init__(self, api_key: str):
        super().__init__(api_key, BoltzQuery, BoltzFastaValidator())

    def call(
        self, seq_file: Path | str, output_dir: Path, format_fasta: bool = False
    ) -> None:
        """Predict protein structure from amino acid sequence using Boltz model.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid
                sequence
            output_dir (Path): Path to output directory
            format_fasta (bool): Whether to format the FASTA file
        """
        super().call(seq_file, output_dir, format_fasta)

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the predictions and their pLDDT files.

        Each structure is paired with ``plddt_<structure stem>.npz`` from the
        same directory when that file exists. Otherwise the first
        ``plddt_*.npz`` file (in sorted order) of that directory is used.
        Structures without any pLDDT file are skipped.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            dict[int, dict[str, Any]]: Mapping from model index to its CIF
                path (``"prediction_path"``) and the object returned by
                ``np.load`` on its pLDDT file (``"metrics"``)
        """
        output: dict[int, dict[str, Any]] = {}
        for cif_path in output_dir.rglob("*_model_[0-9].cif"):
            # Prefer the pLDDT file that belongs to this exact model
            plddt_path = cif_path.parent / f"plddt_{cif_path.stem}.npz"
            if not plddt_path.exists():
                candidates = sorted(cif_path.parent.glob("plddt_*.npz"))
                if not candidates:
                    logger.warning("No pLDDT file found for %s, skipping", cif_path)
                    continue
                plddt_path = candidates[0]

            # The model index is the last character of the file stem
            output[int(cif_path.stem[-1])] = {
                "prediction_path": cif_path,
                "metrics": np.load(plddt_path),
            }
        return output


class OldModel:
    """Base wrapper for models run as Folding Studio experiments.

    A job is submitted with ``single_job_prediction`` and its status is polled
    until the results can be downloaded.
    """

    model_name: str | None = None

    def __init__(self, api_key: str):
        self.api_key = api_key

    def call(
        self,
        seq_file: Path | str,
        output_dir: Path,
        parameters: BaseFoldingParameters,
        *args,
        **kwargs,
    ) -> None:
        """Predict protein structure from amino acid sequence.

        Submits the job, then polls its status every 10 seconds and downloads
        the results to ``output_dir / "results.zip"`` (unzipped) once the
        status is ``"Done"``.

        Args:
            seq_file (Path | str): Path to FASTA file containing amino acid
                sequence
            output_dir (Path): Path to output directory
            parameters (BaseFoldingParameters): Folding parameters sent with
                the job
            *args: Unused
            **kwargs: Unused
        """
        submission = single_job_prediction(
            fasta_file=seq_file,
            parameters=parameters,
            api_key=self.api_key,
        )
        experiment_id = submission["message"]["experiment_id"]

        # NOTE(review): the loop only exits on "Done"; if the API can report a
        # terminal failure status this polls forever - confirm and handle.
        done = False
        while not done:
            # get_status reports through stdout, so capture what it prints
            with Capturing() as captured:
                get_status(experiment_id, api_key=self.api_key)
            status = captured[0]
            logger.info("Experiment %s status: %s", experiment_id, status)
            if status == "Done":
                done = True
                logger.info("Downloading results")
                get_results(
                    experiment_id,
                    force=True,
                    unzip=True,
                    output=output_dir / "results.zip",
                    api_key=self.api_key,
                )
                logger.info("Results downloaded to %s", output_dir)
            else:
                logger.info("Sleeping for 10 seconds")
                time.sleep(10)

    def format_fasta(self, seq_file: Path | str) -> None:
        """Format sequence to FASTA format.

        This is a no-op for these models.

        Args:
            seq_file (Path | str): Path to FASTA file
        """
        return

    def predictions(self, output_dir: Path) -> dict[int, dict[str, Any]]:
        """Get the path to the prediction.

        Args:
            output_dir (Path): Path to output directory

        Returns:
            dict[int, dict[str, Any]]: Dictionary mapping model indices to
                their prediction paths and metrics; empty if
                ``results/metrics_per_model.json`` does not exist
        """
        prediction_paths = list(
            (output_dir / "results").rglob("relaxed_model_[0-9]_*_pred_0.pdb")
        )
        metrics_path = output_dir / "results" / "metrics_per_model.json"
        if not metrics_path.exists():
            return {}

        with open(metrics_path, "r") as f:
            metrics = json.load(f)

        output = {}
        for pred_path in prediction_paths:
            stem_parts = pred_path.stem.split("_")
            # Third token of the stem is the model index
            model_id = int(stem_parts[2])
            # Metrics key: the stem without its first token and last two tokens
            model_name = "_".join(stem_parts[1:-2])
            output[model_id] = {
                "prediction_path": pred_path,
                "metrics": metrics[model_name],
            }
        return output

    def has_prediction(self, output_dir: Path) -> bool:
        """Check if prediction exists in output directory."""
        return len(self.predictions(output_dir)) > 0

    def check_file_description(self, seq_file: Path | str) -> tuple[bool, str | None]:
        """Check if the file description is correct.

        No validation is performed for these models.

        Args:
            seq_file (Path | str): Path to FASTA file

        Returns:
            tuple[bool, str | None]: Always ``(True, None)``
        """
        return True, None


class AF2Model(OldModel):
    """AlphaFold2 model, run with default ``AF2Parameters``."""

    model_name = "AlphaFold2"

    def call(self, seq_file: Path | str, output_dir: Path, *args, **kwargs) -> None:
        """Run the prediction with default ``AF2Parameters``."""
        super().call(seq_file, output_dir, AF2Parameters(), *args, **kwargs)


class OpenFoldModel(OldModel):
    """OpenFold model, run with default ``OpenFoldParameters``."""

    model_name = "OpenFold"

    def call(self, seq_file: Path | str, output_dir: Path, *args, **kwargs) -> None:
        """Run the prediction with default ``OpenFoldParameters``."""
        super().call(seq_file, output_dir, OpenFoldParameters(), *args, **kwargs)