"""Query module for for Protenix prediction endpoint.""" from __future__ import annotations from io import StringIO from itertools import chain from pathlib import Path from typing import Any from folding_studio_data_models import FoldingModel from pydantic import BaseModel, Field from folding_studio.query import Query from folding_studio.utils.fasta import validate_fasta from folding_studio.utils.path_helpers import validate_path class ProtenixParameters(BaseModel): """Protenix inference parameters.""" seeds: str = Field(alias="seed", default="0", coerce_numbers_to_str=True) use_msa_server: bool = True class ProtenixQuery(Query): """Protenix model query.""" MODEL = FoldingModel.PROTENIX def __init__( self, fasta_files: dict[str, Any], query_name: str, parameters: ProtenixParameters = ProtenixParameters(), ): if not fasta_files: raise ValueError("FASTA files dictionary cannot be empty.") self.fasta_files = fasta_files self.query_name = query_name self._parameters = parameters @classmethod def from_protein_sequence( cls, sequence: str, query_name: str | None = None, **kwargs ) -> ProtenixQuery: """Initialize a ProtenixQuery from a str protein sequence. Args: sequence (str): The protein sequence in string format. query_name (str | None, optional): User-defined query name. Defaults to None. seed (int, optional): Random seed. Defaults to 0. use_msa_server (bool, optional): Use the MSA server for inference. Defaults to False. Returns: ProtenixQuery: An instance of ProtenixQuery with the sequence stored as a FASTA file. """ record = validate_fasta(StringIO(sequence)) query_name = ( query_name if query_name is not None else record.description.split("|", maxsplit=1)[0] # first tag ) return cls( fasta_files={query_name: sequence}, query_name=query_name, parameters=ProtenixParameters(**kwargs), ) @classmethod def from_file( cls: ProtenixQuery, path: str | Path, query_name: str | None = None, **kwargs ) -> ProtenixQuery: """Initialize a ProtenixQuery instance from a file. Supported file format are: - FASTA Args: path (str | Path): Path of the FASTA file. query_name (str | None, optional): User-defined query name. Defaults to None. seed (int, optional): Random seed. Defaults to 0. use_msa_server (bool, optional): Use the MSA server for inference. Defaults to False. Returns: ProtenixQuery """ path = validate_path(path, is_file=True, file_suffix=(".fasta", ".fa")) query_name = query_name or path.stem return cls( fasta_files={path.stem: validate_fasta(path, str_output=True)}, query_name=query_name, parameters=ProtenixParameters(**kwargs), ) @classmethod def from_directory( cls: ProtenixQuery, path: str | Path, query_name: str | None = None, **kwargs ) -> ProtenixQuery: """Initialize a ProtenixQuery instance from a directory. Supported file format in directory are: - FASTA Args: path (str | Path): Path to a directory of FASTA files. query_name (str | None, optional): User-defined query name. Defaults to None. seed (int, optional): Random seed. Defaults to 0. use_msa_server (bool, optional): Use the MSA server for inference. Defaults to False. Raises: ValueError: If no FASTA file are present in the directory. Returns: ProtenixQuery """ path = validate_path(path, is_dir=True) fasta_files = {} for file in chain(path.glob("*.fasta"), path.glob("*.fa")): fasta_files[file.stem] = validate_fasta(file, str_output=True) if not fasta_files: raise ValueError(f"No FASTA files found in directory '{path}'.") query_name = query_name or path.name return cls( fasta_files=fasta_files, query_name=query_name, parameters=ProtenixParameters(**kwargs), ) @property def payload(self) -> dict[str, Any]: """Payload to send to the prediction API endpoint.""" return { "fasta_files": self.fasta_files, "use_msa_server": self.parameters.use_msa_server, "seeds": self.parameters.seeds, } @property def parameters(self) -> ProtenixParameters: """Parameters of the query.""" return self._parameters