"""Query module for Chai prediction endpoint.""" from __future__ import annotations import shutil import tempfile from io import StringIO from itertools import chain from pathlib import Path from typing import Any from folding_studio_data_models import FoldingModel from pydantic import BaseModel, field_validator from folding_studio.commands.utils import ( a3m_to_aligned_pqt, process_uploaded_msas, ) from folding_studio.query import Query from folding_studio.utils.fasta import validate_fasta from folding_studio.utils.headers import get_auth_headers from folding_studio.utils.path_helpers import validate_path class ChaiParameters(BaseModel): """Chai1 inference parameters.""" seed: int = 0 num_trunk_recycles: int = 3 num_diffn_timesteps: int = 200 recycle_msa_subsample: int = 0 num_trunk_samples: int = 1 restraints: str | None = None use_msa_server: bool = False use_templates_server: bool = False custom_msa_paths: dict[str, str] | None = None @field_validator("restraints", mode="before") def read_restraints( cls: ChaiParameters, restraints: str | Path | None ) -> str | None: """Reads restraints from a CSV file and returns its content as a string.""" if restraints is None: return path = validate_path(restraints, is_file=True, file_suffix=(".csv")) with path.open(newline="", encoding="utf-8") as csvfile: return csvfile.read().strip() class ChaiQuery(Query): """Chai1 model query.""" MODEL = FoldingModel.CHAI def __init__( self, fasta_files: dict[str, str], query_name: str, parameters: ChaiParameters = ChaiParameters(), ): """Initializes a ChaiQuery instance.""" if not fasta_files: raise ValueError("FASTA files dictionary cannot be empty.") self.fasta_files = fasta_files self.query_name = query_name self._parameters = parameters @classmethod def from_protein_sequence( cls: ChaiQuery, sequence: str, query_name: str | None = None, **kwargs ) -> ChaiQuery: """Initialize a ChaiQuery instance from a str protein sequence. Args: sequence (str): Protein amino-acid sequence query_name (str | None, optional): User-defined query name. Defaults to None. Raises: NotAMonomer: If the sequence is not a monomer complex. Returns: ChaiQuery """ record = validate_fasta(StringIO(sequence)) custom_msa_paths = kwargs.pop("custom_msa_paths", None) if custom_msa_paths: kwargs["custom_msa_paths"] = cls._upload_custom_msa_files(custom_msa_paths) query_name = ( query_name if query_name is not None else record.description.split("|", maxsplit=1)[0] # first tag ) return cls( fasta_files={query_name: sequence}, query_name=query_name, parameters=ChaiParameters(**kwargs), ) @classmethod def from_file( cls: ChaiQuery, path: str | Path, query_name: str | None = None, **kwargs ) -> ChaiQuery: """Initialize a ChaiQuery instance from a file. Supported file format are: - FASTA Args: path (str | Path): Path of the FASTA file. query_name (str | None, optional): User-defined query name. Defaults to None. Returns: ChaiQuery """ path = validate_path(path, is_file=True, file_suffix=(".fasta", ".fa")) custom_msa_paths = kwargs.pop("custom_msa_paths", None) if custom_msa_paths: kwargs["custom_msa_paths"] = cls._upload_custom_msa_files(custom_msa_paths) query_name = query_name or path.stem return cls( fasta_files={path.stem: validate_fasta(path, str_output=True)}, query_name=query_name, parameters=ChaiParameters(**kwargs), ) @classmethod def from_directory( cls: ChaiQuery, path: str | Path, query_name: str | None = None, **kwargs ) -> ChaiQuery: """Initialize a ChaiQuery instance from a directory. Supported file format in directory are: - FASTA Args: path (str | Path): Path to a directory of FASTA files. query_name (str | None, optional): User-defined query name. Defaults to None. seed (int, optional): Random seed. Defaults to 0. Raises: ValueError: If no FASTA file are present in the directory. Returns: ChaiQuery """ path = validate_path(path, is_dir=True) custom_msa_paths = kwargs.pop("custom_msa_paths", None) if custom_msa_paths: kwargs["custom_msa_paths"] = cls._upload_custom_msa_files(custom_msa_paths) print(kwargs["custom_msa_paths"]) fasta_files = { file.stem: validate_fasta(file, str_output=True) for file in chain(path.glob("*.fasta"), path.glob("*.fa")) } if not fasta_files: raise ValueError(f"No FASTA files found in directory '{path}'.") query_name = query_name or path.name return cls( fasta_files=fasta_files, query_name=query_name, parameters=ChaiParameters(**kwargs), ) @property def payload(self) -> dict[str, Any]: """Payload to send to the prediction API endpoint.""" return { "fasta_files": self.fasta_files, "use_msa_server": self.parameters.use_msa_server, "use_templates_server": self.parameters.use_templates_server, "num_trunk_recycles": self.parameters.num_trunk_recycles, "seed": self.parameters.seed, "num_diffn_timesteps": self.parameters.num_diffn_timesteps, "restraints": self.parameters.restraints, "recycle_msa_subsample": self.parameters.recycle_msa_subsample, "num_trunk_samples": self.parameters.num_trunk_samples, "custom_msa_paths": self.parameters.custom_msa_paths, } @property def parameters(self) -> ChaiParameters: """Parameters of the query.""" return self._parameters @staticmethod def _upload_custom_msa_files( source: str, headers: str | None = None ) -> dict[str, str]: """Read A3M or MSA files from a file or directory and uploads them to GCS. Args: source (str): Path to an .a3m or .aligned.pqt file or a directory containing .a3m or .aligned.pqt files identity_token (str | None, optional): GCP identity token. Defaults to None. Raises: ValueError: If file has unsupported extension. ValueError: If directory has no supported file. Returns: dict[str, str]: _description_ """ headers = headers or get_auth_headers() source_path = validate_path(source) # Process if source is a file. if source_path.is_file(): if source_path.suffix == ".a3m": with tempfile.TemporaryDirectory() as tmpdir: tmp_path = Path(tmpdir) shutil.copy(source_path, tmp_path / source_path.name) pqt_file = a3m_to_aligned_pqt(str(tmp_path)) return process_uploaded_msas([Path(pqt_file)], headers) elif source_path.name.endswith(".aligned.pqt"): return process_uploaded_msas([source_path], headers) else: raise ValueError( f"Invalid file type: {source_path.suffix}. Expected '.a3m' or a file ending with '.aligned.pqt'." ) # Process if source is a directory. elif source_path.is_dir(): pqt_files = list(source_path.glob("*.aligned.pqt")) if pqt_files: return process_uploaded_msas(pqt_files, headers) a3m_files = list(source_path.glob("*.a3m")) if not a3m_files: raise ValueError( f"Directory '{source}' contains no files ending with '.aligned.pqt' or '.a3m'." ) with tempfile.TemporaryDirectory() as tmpdir: tmp_path = Path(tmpdir) for file in a3m_files: shutil.copy(file, tmp_path / file.name) pqt_file = a3m_to_aligned_pqt(str(tmp_path)) return process_uploaded_msas([Path(pqt_file)], headers)