"""Boltz-1 query to prediction endpoint."""

from __future__ import annotations

from io import StringIO
from pathlib import Path
from typing import Any

import yaml
from folding_studio_data_models import FoldingModel
from pydantic import BaseModel

from folding_studio.commands.utils import process_uploaded_msas
from folding_studio.query import Query
from folding_studio.utils.fasta import validate_fasta
from folding_studio.utils.headers import get_auth_headers
from folding_studio.utils.path_helpers import validate_path

# File suffixes accepted as Boltz query inputs.
_FASTA_SUFFIXES = (".fasta", ".fa")
_YAML_SUFFIXES = (".yaml", ".yml")
_INPUT_SUFFIXES = _FASTA_SUFFIXES + _YAML_SUFFIXES


class BoltzParameters(BaseModel):
    """Boltz inference parameters."""

    seed: int = 0
    recycling_steps: int = 3
    sampling_steps: int = 200
    diffusion_samples: int = 1
    step_scale: float = 1.638
    msa_pairing_strategy: str = "greedy"
    write_full_pae: bool = False
    write_full_pde: bool = False
    use_msa_server: bool = True
    # Mapping returned by `BoltzQuery._upload_custom_msa_files` (uploaded MSA
    # file name -> GCS URL). The `BoltzQuery.from_*` constructors fill it in
    # from a user-supplied local file/directory path.
    custom_msa_paths: dict[str, str] | None = None


class BoltzQuery(Query):
    """Boltz-1 model query.

    Bundles FASTA and/or YAML inputs with `BoltzParameters` and exposes the
    payload sent to the prediction API endpoint.
    """

    MODEL = FoldingModel.BOLTZ

    def __init__(
        self,
        fasta_dict: dict[str, str],
        yaml_dict: dict[str, Any],
        query_name: str,
        parameters: BoltzParameters | None = None,
    ):
        """Create a Boltz query.

        Args:
            fasta_dict (dict[str, str]): Input name -> FASTA content.
            yaml_dict (dict[str, Any]): Input name -> parsed YAML content.
            query_name (str): User-defined query name.
            parameters (BoltzParameters | None, optional): Inference parameters.
                Defaults to a new `BoltzParameters()` for each query.
        """
        self.fasta_dict = fasta_dict
        self.yaml_dict = yaml_dict
        self.query_name = query_name
        # Build the default per call: a `BoltzParameters()` default argument is
        # evaluated once at definition time and would be shared by every query.
        self._parameters = (
            parameters if parameters is not None else BoltzParameters()
        )

    @staticmethod
    def _process_file(file_path: Path) -> tuple[dict[str, str], dict[str, Any]]:
        """Processes a single file and extracts its contents.

        Args:
            file_path (Path): Path to the file.

        Returns:
            tuple[dict[str, str], dict[str, Any]]: A ``(fasta_dict, yaml_dict)``
            pair keyed by the file stem; exactly one of them is non-empty.

        Raises:
            ValueError: If the file format is unsupported.
        """
        fasta_dict: dict[str, str] = {}
        yaml_dict: dict[str, Any] = {}
        if file_path.suffix in _FASTA_SUFFIXES:
            fasta_dict = {file_path.stem: validate_fasta(file_path, str_output=True)}
        elif file_path.suffix in _YAML_SUFFIXES:
            # safe_load: never construct arbitrary Python objects from user files.
            with file_path.open("r", encoding="utf-8") as f:
                yaml_dict = {file_path.stem: yaml.safe_load(f)}
        else:
            raise ValueError(f"Unsupported format: {file_path.suffix}")
        return fasta_dict, yaml_dict

    @classmethod
    def _build_parameters(cls, **kwargs: Any) -> BoltzParameters:
        """Build `BoltzParameters` from keyword arguments.

        A truthy ``custom_msa_paths`` value is treated as a local MSA file or
        directory path: its files are uploaded and the field is replaced by the
        mapping returned by `_upload_custom_msa_files`. A falsy value is dropped
        so the model default applies.
        """
        custom_msa_paths = kwargs.pop("custom_msa_paths", None)
        if custom_msa_paths:
            kwargs["custom_msa_paths"] = cls._upload_custom_msa_files(custom_msa_paths)
        return BoltzParameters(**kwargs)

    @classmethod
    def from_protein_sequence(
        cls, sequence: str, query_name: str | None = None, **kwargs: Any
    ) -> BoltzQuery:
        """Initialize a BoltzQuery from a str protein sequence.

        Args:
            sequence (str): The protein sequence as a FASTA-formatted string
                (it is parsed with `validate_fasta`).
            query_name (str | None, optional): User-defined query name. Defaults
                to the first ``|``-separated tag of the FASTA record description.
            **kwargs: Additional `BoltzParameters` fields. ``custom_msa_paths``
                may be a local .a3m/.csv file or directory whose files are
                uploaded first.

        Returns:
            BoltzQuery
        """
        record = validate_fasta(StringIO(sequence))
        if query_name is None:
            query_name = record.description.split("|", maxsplit=1)[0]  # first tag
        return cls(
            fasta_dict={query_name: sequence},
            yaml_dict={},
            query_name=query_name,
            parameters=cls._build_parameters(**kwargs),
        )

    @classmethod
    def from_file(
        cls, path: str | Path, query_name: str | None = None, **kwargs: Any
    ) -> BoltzQuery:
        """Initialize a BoltzQuery instance from a file.

        Supported file formats are:
        - FASTA (``.fasta``, ``.fa``)
        - YAML (``.yaml``, ``.yml``)

        Args:
            path (str | Path): Path to the file.
            query_name (str | None, optional): User-defined query name.
                Defaults to the file stem.
            **kwargs: Additional `BoltzParameters` fields. ``custom_msa_paths``
                may be a local .a3m/.csv file or directory whose files are
                uploaded first.

        Returns:
            BoltzQuery: An instance of BoltzQuery.
        """
        path = validate_path(path, is_file=True, file_suffix=_INPUT_SUFFIXES)
        fasta_dict, yaml_dict = cls._process_file(path)
        # Keyword arguments are evaluated left to right, so MSAs are uploaded
        # only after the input file has been read successfully.
        return cls(
            fasta_dict=fasta_dict,
            yaml_dict=yaml_dict,
            query_name=query_name or path.stem,
            parameters=cls._build_parameters(**kwargs),
        )

    @classmethod
    def from_directory(
        cls, path: str | Path, query_name: str | None = None, **kwargs: Any
    ) -> BoltzQuery:
        """Initialize a BoltzQuery instance from a directory.

        Every regular file in the directory with a FASTA (``.fasta``, ``.fa``)
        or YAML (``.yaml``, ``.yml``) suffix is read; other entries are ignored.

        Args:
            path (str | Path): Path to the directory.
            query_name (str | None, optional): User-defined query name.
                Defaults to the directory name.
            **kwargs: Additional `BoltzParameters` fields. ``custom_msa_paths``
                may be a local .a3m/.csv file or directory whose files are
                uploaded first.

        Returns:
            BoltzQuery: An instance of BoltzQuery.

        Raises:
            ValueError: If the directory contains no FASTA or YAML file.
        """
        path = validate_path(path, is_dir=True)
        # Sorted so the payload order is deterministic (iterdir order is not);
        # unrelated entries are skipped, as in `_upload_custom_msa_files`.
        input_files = sorted(
            entry
            for entry in path.iterdir()
            if entry.is_file() and entry.suffix in _INPUT_SUFFIXES
        )
        if not input_files:
            raise ValueError(f"No FASTA or YAML files found in directory '{path}'.")

        fasta_dict: dict[str, str] = {}
        yaml_dict: dict[str, Any] = {}
        for input_file in input_files:
            file_fasta_dict, file_yaml_dict = cls._process_file(input_file)
            fasta_dict.update(file_fasta_dict)
            yaml_dict.update(file_yaml_dict)

        # MSAs are uploaded only once the inputs were read successfully.
        return cls(
            fasta_dict=fasta_dict,
            yaml_dict=yaml_dict,
            query_name=query_name or path.name,
            parameters=cls._build_parameters(**kwargs),
        )

    @property
    def payload(self) -> dict[str, Any]:
        """Payload to send to the prediction API endpoint.

        Contains the FASTA inputs (``fasta_files``), the parsed YAML inputs
        (``yaml_files``) and the JSON-mode dump of the parameters
        (``parameters``).
        """
        return {
            "fasta_files": self.fasta_dict,
            "yaml_files": self.yaml_dict,
            "parameters": self.parameters.model_dump(mode="json"),
        }

    @property
    def parameters(self) -> BoltzParameters:
        """Parameters of the query."""
        return self._parameters

    @staticmethod
    def _upload_custom_msa_files(
        source: str | Path, headers: str | None = None
    ) -> dict[str, str]:
        """Reads MSA files from a file or directory and uploads them to GCS.

        Args:
            source (str | Path): Path to an .a3m or .csv file, or a directory
                containing such files (other directory entries are ignored).
            headers (str | None, optional): GCP authentication headers. Defaults
                to the result of `get_auth_headers()`.

        Raises:
            ValueError: If the file has an unsupported extension.
            ValueError: If a directory contains no .a3m or .csv files.
            ValueError: If the path is neither a regular file nor a directory.

        Returns:
            dict[str, str]: A mapping of uploaded file names to their GCS URLs.
        """
        # NOTE(review): `headers` is annotated `str`, but it receives whatever
        # `get_auth_headers()` returns — confirm and tighten the annotation.
        source_path = validate_path(source)
        valid_extensions = {".a3m", ".csv"}  # Allow both a3m and csv files

        if source_path.is_file():
            if source_path.suffix not in valid_extensions:
                raise ValueError(
                    f"Invalid file type: {source_path.suffix}. "
                    f"Expected one of {valid_extensions}."
                )
            msa_files = [source_path]
        elif source_path.is_dir():
            # Only regular files: a sub-directory named "x.csv" is not an MSA.
            msa_files = sorted(
                entry
                for entry in source_path.iterdir()
                if entry.is_file() and entry.suffix in valid_extensions
            )
            if not msa_files:
                raise ValueError(
                    f"Directory '{source}' contains no valid files with "
                    f"extensions {valid_extensions}."
                )
        else:
            # Exists but is neither a file nor a directory (e.g. a socket):
            # there is nothing to upload and no mapping to return.
            raise ValueError(f"'{source}' is neither a file nor a directory.")

        # Resolve authentication only once there is something to upload.
        return process_uploaded_msas(msa_files, headers or get_auth_headers())