"""Query module for SoloSeq prediction endpoint.""" from __future__ import annotations from io import StringIO from pathlib import Path from typing import Any from folding_studio_data_models import FoldingModel from pydantic import BaseModel, Field from folding_studio.query import Query from folding_studio.utils.fasta import validate_fasta from folding_studio.utils.path_helpers import validate_path MAX_AA_LENGTH = 1024 class SoloSeqParameters(BaseModel): """SoloSeq inference parameters.""" data_random_seed: int = Field(alias="seed", default=0) skip_relaxation: bool = False subtract_plddt: bool = False class SoloSeqQuery(Query): """SoloSeq model query.""" MODEL = FoldingModel.SOLOSEQ def __init__( self, fasta_files: dict[str, str], query_name: str, parameters: SoloSeqParameters = SoloSeqParameters(), ): if not fasta_files: raise ValueError("FASTA files dictionary cannot be empty.") self.fasta_files = fasta_files self.query_name = query_name self._parameters = parameters def __eq__(self, value): if not isinstance(value, SoloSeqQuery): return False return ( self.fasta_files == value.fasta_files and self.query_name == value.query_name and self.parameters == value.parameters ) @classmethod def from_protein_sequence( cls: SoloSeqQuery, sequence: str, query_name: str | None = None, **kwargs ) -> SoloSeqQuery: """Initialize a SoloSeqQuery instance from a str protein sequence. Args: sequence (str): Protein amino-acid sequence query_name (str | None, optional): User-defined query name. Defaults to None. seed (int, optional): Random seed. Defaults to 0. skip_relaxation (bool, optional): Run the skip_relaxation process. Defaults to False. subtract_plddt (bool, optional): Output (100 - pLDDT) instead of the pLDDT itself. Defaults to False. Raises: NotAMonomer: If the sequence is not a monomer complex. Returns: SoloSeqQuery """ record = validate_fasta( StringIO(sequence), allow_multimer=False, max_aa_length=MAX_AA_LENGTH ) query_name = ( query_name if query_name is not None else record.description.split("|", maxsplit=1)[0] # first tag ) return cls( fasta_files={query_name: sequence}, query_name=query_name, parameters=SoloSeqParameters(**kwargs), ) @classmethod def from_file( cls: SoloSeqQuery, path: str | Path, query_name: str | None = None, **kwargs ) -> SoloSeqQuery: """Initialize a SoloSeqQuery instance from a file. Supported file format are: - FASTA Args: path (str | Path): Path of the FASTA file. query_name (str | None, optional): User-defined query name. Defaults to None. seed (int, optional): Random seed. Defaults to 0. skip_relaxation (bool, optional): Run the skip_relaxation process. Defaults to False. subtract_plddt (bool, optional): Output (100 - pLDDT) instead of the pLDDT itself. Defaults to False. Raises: NotAMonomer: If the FASTA file contains non-monomer complex. Returns: SoloSeqQuery """ path = validate_path(path, is_file=True, file_suffix=(".fasta", ".fa")) record = validate_fasta(path, allow_multimer=False, max_aa_length=MAX_AA_LENGTH) query_name = query_name or path.stem return cls( fasta_files={path.stem: record.format("fasta").strip()}, query_name=query_name, parameters=SoloSeqParameters(**kwargs), ) @classmethod def from_directory( cls: SoloSeqQuery, path: str | Path, query_name: str | None = None, **kwargs ) -> SoloSeqQuery: """Initialize a SoloSeqQuery instance from a directory. Supported file format in directory are: - FASTA Args: path (str | Path): Path to a directory of FASTA files. query_name (str | None, optional): User-defined query name. Defaults to None. seed (int, optional): Random seed. Defaults to 0. skip_relaxation (bool, optional): Run the skip_relaxation process. Defaults to False. subtract_plddt (bool, optional): Output (100 - pLDDT) instead of the pLDDT itself. Defaults to False. Raises: ValueError: If no FASTA file are present in the directory. NotAMonomer: If a FASTA file in the directory contains non monomer complex. Returns: SoloSeqQuery """ path = validate_path(path, is_dir=True) fasta_files = {} for filepath in (f for f in path.iterdir() if f.suffix in (".fasta", ".fa")): record = validate_fasta( filepath, allow_multimer=False, max_aa_length=MAX_AA_LENGTH ) fasta_files[filepath.stem] = record.format("fasta").strip() if not fasta_files: raise ValueError(f"No FASTA files found in directory '{path}'.") query_name = query_name or path.name return cls( fasta_files=fasta_files, query_name=query_name, parameters=SoloSeqParameters(**kwargs), ) @property def payload(self) -> dict[str, Any]: """Payload to send to the prediction API endpoint.""" return { "fasta_files": self.fasta_files, "parameters": self.parameters.model_dump(mode="json"), } @property def parameters(self) -> SoloSeqParameters: """Parameters of the query.""" return self._parameters