|
"""Query module for SoloSeq prediction endpoint.""" |
|
|
|
from __future__ import annotations |
|
|
|
from io import StringIO |
|
from pathlib import Path |
|
from typing import Any |
|
|
|
from folding_studio_data_models import FoldingModel |
|
from pydantic import BaseModel, Field |
|
|
|
from folding_studio.query import Query |
|
from folding_studio.utils.fasta import validate_fasta |
|
from folding_studio.utils.path_helpers import validate_path |
|
|
|
# Length limit handed to ``validate_fasta`` as ``max_aa_length`` by every
# ``SoloSeqQuery`` constructor in this module.
# NOTE(review): presumably counted in amino-acid residues per sequence — confirm
# against ``validate_fasta``.
MAX_AA_LENGTH = 1024
|
|
|
|
|
class SoloSeqParameters(BaseModel):
    """Inference settings attached to a SoloSeq query."""

    # Random seed; declared with the alias ``seed`` and defaults to 0.
    data_random_seed: int = Field(default=0, alias="seed")

    # Defaults to False.
    # NOTE(review): presumably skips the relaxation step when True — confirm
    # against the prediction API.
    skip_relaxation: bool = False

    # When True, output (100 - pLDDT) instead of the pLDDT itself.
    # Defaults to False.
    subtract_plddt: bool = False
|
|
|
|
|
class SoloSeqQuery(Query):
    """SoloSeq model query.

    Bundles a mapping of FASTA contents, a query name and the inference
    parameters, and exposes them as the payload sent to the prediction
    endpoint. Instances are normally built through one of the
    ``from_*`` alternate constructors, which validate the FASTA input first.
    """

    MODEL = FoldingModel.SOLOSEQ

    def __init__(
        self,
        fasta_files: dict[str, str],
        query_name: str,
        parameters: SoloSeqParameters | None = None,
    ):
        """Build a query from already-validated FASTA contents.

        Args:
            fasta_files (dict[str, str]): Mapping of a name to FASTA content.
                Must not be empty.
            query_name (str): User-facing name of the query.
            parameters (SoloSeqParameters | None, optional): Inference
                parameters. When omitted (or None), a fresh
                ``SoloSeqParameters()`` with default values is created.

        Raises:
            ValueError: If ``fasta_files`` is empty.
        """
        if not fasta_files:
            raise ValueError("FASTA files dictionary cannot be empty.")

        self.fasta_files = fasta_files
        self.query_name = query_name
        # Build the default here rather than in the signature: a default
        # written in the signature is evaluated once at definition time, so
        # every query would share (and could mutate) the same parameters object.
        self._parameters = (
            parameters if parameters is not None else SoloSeqParameters()
        )

    def __eq__(self, value: object) -> bool:
        """Compare two queries by FASTA contents, name and parameters.

        Any object that is not a ``SoloSeqQuery`` compares unequal.
        Since ``__eq__`` is overridden without ``__hash__``, instances are
        unhashable.
        """
        if not isinstance(value, SoloSeqQuery):
            return False
        return (
            self.fasta_files == value.fasta_files
            and self.query_name == value.query_name
            and self.parameters == value.parameters
        )

    @classmethod
    def from_protein_sequence(
        cls: type[SoloSeqQuery], sequence: str, query_name: str | None = None, **kwargs
    ) -> SoloSeqQuery:
        """Initialize a SoloSeqQuery instance from a str protein sequence.

        The sequence is validated with ``validate_fasta`` (multimers are not
        allowed, length is capped by ``MAX_AA_LENGTH``). The string is stored
        as given under the query name.

        Args:
            sequence (str): Protein amino-acid sequence. It is parsed by
                ``validate_fasta``, so it is presumably expected to be
                FASTA-formatted text (TODO confirm).
            query_name (str | None, optional): User-defined query name.
                Defaults to None, in which case the part of the record
                description before the first ``|`` is used.
            **kwargs: Forwarded to ``SoloSeqParameters``:
                seed (int, optional): Random seed. Defaults to 0.
                skip_relaxation (bool, optional): Value of the
                    ``skip_relaxation`` parameter. Defaults to False.
                subtract_plddt (bool, optional): Output (100 - pLDDT) instead
                    of the pLDDT itself. Defaults to False.

        Raises:
            NotAMonomer: If the sequence is not a monomer complex.

        Returns:
            SoloSeqQuery
        """
        record = validate_fasta(
            StringIO(sequence), allow_multimer=False, max_aa_length=MAX_AA_LENGTH
        )
        if query_name is None:
            # Fall back to the leading field of the FASTA description line.
            query_name = record.description.split("|", maxsplit=1)[0]
        return cls(
            fasta_files={query_name: sequence},
            query_name=query_name,
            parameters=SoloSeqParameters(**kwargs),
        )

    @classmethod
    def from_file(
        cls: type[SoloSeqQuery], path: str | Path, query_name: str | None = None, **kwargs
    ) -> SoloSeqQuery:
        """Initialize a SoloSeqQuery instance from a file.

        Supported file format are:
            - FASTA (``.fasta`` or ``.fa`` suffix)

        Args:
            path (str | Path): Path of the FASTA file.
            query_name (str | None, optional): User-defined query name.
                Defaults to None; a missing or empty name falls back to the
                file stem.
            **kwargs: Forwarded to ``SoloSeqParameters``:
                seed (int, optional): Random seed. Defaults to 0.
                skip_relaxation (bool, optional): Value of the
                    ``skip_relaxation`` parameter. Defaults to False.
                subtract_plddt (bool, optional): Output (100 - pLDDT) instead
                    of the pLDDT itself. Defaults to False.

        Raises:
            NotAMonomer: If the FASTA file contains non-monomer complex.

        Returns:
            SoloSeqQuery
        """
        path = validate_path(path, is_file=True, file_suffix=(".fasta", ".fa"))
        record = validate_fasta(path, allow_multimer=False, max_aa_length=MAX_AA_LENGTH)
        query_name = query_name or path.stem
        return cls(
            # The FASTA content is always keyed by the file stem, even when a
            # custom query name is given.
            fasta_files={path.stem: record.format("fasta").strip()},
            query_name=query_name,
            parameters=SoloSeqParameters(**kwargs),
        )

    @classmethod
    def from_directory(
        cls: type[SoloSeqQuery], path: str | Path, query_name: str | None = None, **kwargs
    ) -> SoloSeqQuery:
        """Initialize a SoloSeqQuery instance from a directory.

        Supported file format in directory are:
            - FASTA (``.fasta`` or ``.fa`` suffix)

        Only regular files directly inside the directory are considered. They
        are processed in sorted path order so the resulting mapping (and hence
        the payload) is the same on every run.

        Args:
            path (str | Path): Path to a directory of FASTA files.
            query_name (str | None, optional): User-defined query name.
                Defaults to None; a missing or empty name falls back to the
                directory name.
            **kwargs: Forwarded to ``SoloSeqParameters``:
                seed (int, optional): Random seed. Defaults to 0.
                skip_relaxation (bool, optional): Value of the
                    ``skip_relaxation`` parameter. Defaults to False.
                subtract_plddt (bool, optional): Output (100 - pLDDT) instead
                    of the pLDDT itself. Defaults to False.

        Raises:
            ValueError: If no FASTA file are present in the directory.
            NotAMonomer: If a FASTA file in the directory contains non monomer complex.

        Returns:
            SoloSeqQuery
        """
        path = validate_path(path, is_dir=True)
        # ``Path.iterdir`` yields entries in arbitrary order; sort for a
        # deterministic result and skip sub-directories that merely carry a
        # FASTA-looking suffix.
        candidates = sorted(
            entry
            for entry in path.iterdir()
            if entry.suffix in (".fasta", ".fa") and entry.is_file()
        )
        fasta_files = {}
        for filepath in candidates:
            record = validate_fasta(
                filepath, allow_multimer=False, max_aa_length=MAX_AA_LENGTH
            )
            # NOTE: files sharing a stem (e.g. ``a.fa`` and ``a.fasta``) map to
            # the same key; the one processed last overwrites the other.
            fasta_files[filepath.stem] = record.format("fasta").strip()
        if not fasta_files:
            raise ValueError(f"No FASTA files found in directory '{path}'.")
        query_name = query_name or path.name
        return cls(
            fasta_files=fasta_files,
            query_name=query_name,
            parameters=SoloSeqParameters(**kwargs),
        )

    @property
    def payload(self) -> dict[str, Any]:
        """Payload to send to the prediction API endpoint.

        Contains the FASTA mapping under ``"fasta_files"`` and the parameters
        dumped in JSON mode under ``"parameters"``.
        """
        return {
            "fasta_files": self.fasta_files,
            "parameters": self.parameters.model_dump(mode="json"),
        }

    @property
    def parameters(self) -> SoloSeqParameters:
        """Parameters of the query."""
        return self._parameters
|
|