|
"""Query module for for Protenix prediction endpoint.""" |
|
|
|
from __future__ import annotations

from io import StringIO
from itertools import chain
from pathlib import Path
from typing import Any

from folding_studio_data_models import FoldingModel
from pydantic import BaseModel, ConfigDict, Field

from folding_studio.query import Query
from folding_studio.utils.fasta import validate_fasta
from folding_studio.utils.path_helpers import validate_path
|
|
|
|
|
class ProtenixParameters(BaseModel):
    """Protenix inference parameters.

    Attributes:
        seeds: Random seed value. Accepted from the ``seed`` keyword (its alias)
            or the ``seeds`` field name; numeric inputs are coerced to ``str``.
            Defaults to ``"0"``.
        use_msa_server: Whether to use the MSA server for inference.
            Defaults to ``True``.
    """

    # Allow populating ``seeds`` by its field name as well as by the ``seed``
    # alias. Without this, ``ProtenixParameters(seeds=...)`` is treated as an
    # unknown extra input and silently ignored, leaving the default seed.
    model_config = ConfigDict(populate_by_name=True)

    seeds: str = Field(alias="seed", default="0", coerce_numbers_to_str=True)
    use_msa_server: bool = True
|
|
|
|
|
class ProtenixQuery(Query):
    """Protenix model query."""

    MODEL = FoldingModel.PROTENIX

    def __init__(
        self,
        fasta_files: dict[str, Any],
        query_name: str,
        parameters: ProtenixParameters | None = None,
    ):
        """Create a Protenix query.

        Args:
            fasta_files (dict[str, Any]): Mapping of an identifier (query name or
                FASTA file stem) to the corresponding FASTA content.
            query_name (str): User-defined query name.
            parameters (ProtenixParameters | None, optional): Inference
                parameters. Defaults to a new ``ProtenixParameters()`` instance.

        Raises:
            ValueError: If ``fasta_files`` is empty.
        """
        if not fasta_files:
            raise ValueError("FASTA files dictionary cannot be empty.")

        self.fasta_files = fasta_files
        self.query_name = query_name
        # Build the default per instance: a default-argument instance would be a
        # single mutable object shared by every query created without parameters.
        self._parameters = (
            parameters if parameters is not None else ProtenixParameters()
        )

    @classmethod
    def from_protein_sequence(
        cls, sequence: str, query_name: str | None = None, **kwargs
    ) -> ProtenixQuery:
        """Initialize a ProtenixQuery from a protein sequence string.

        The sequence is validated as FASTA before being stored.

        Args:
            sequence (str): The protein sequence as FASTA-formatted text.
            query_name (str | None, optional): User-defined query name. Defaults
                to the part of the FASTA record description before the first "|".
            seed (int | str, optional): Random seed. Defaults to 0.
            use_msa_server (bool, optional): Use the MSA server for inference.
                Defaults to True.

        Other keyword arguments are forwarded to ``ProtenixParameters``.

        Returns:
            ProtenixQuery: An instance whose ``fasta_files`` maps the query name to
                the given sequence string.
        """
        record = validate_fasta(StringIO(sequence))

        query_name = (
            query_name
            if query_name is not None
            else record.description.split("|", maxsplit=1)[0]
        )
        return cls(
            fasta_files={query_name: sequence},
            query_name=query_name,
            parameters=ProtenixParameters(**kwargs),
        )

    @classmethod
    def from_file(
        cls, path: str | Path, query_name: str | None = None, **kwargs
    ) -> ProtenixQuery:
        """Initialize a ProtenixQuery instance from a file.

        Supported file formats are:
            - FASTA (``.fasta`` or ``.fa`` suffix)

        Args:
            path (str | Path): Path of the FASTA file.
            query_name (str | None, optional): User-defined query name. Defaults
                to the file stem.
            seed (int | str, optional): Random seed. Defaults to 0.
            use_msa_server (bool, optional): Use the MSA server for inference.
                Defaults to True.

        Other keyword arguments are forwarded to ``ProtenixParameters``.

        Returns:
            ProtenixQuery: An instance whose ``fasta_files`` maps the file stem
                to the validated FASTA content.
        """
        path = validate_path(path, is_file=True, file_suffix=(".fasta", ".fa"))
        query_name = query_name or path.stem
        return cls(
            fasta_files={path.stem: validate_fasta(path, str_output=True)},
            query_name=query_name,
            parameters=ProtenixParameters(**kwargs),
        )

    @classmethod
    def from_directory(
        cls, path: str | Path, query_name: str | None = None, **kwargs
    ) -> ProtenixQuery:
        """Initialize a ProtenixQuery instance from a directory.

        Supported file formats in the directory are:
            - FASTA (``.fasta`` or ``.fa`` suffix)

        Files are read in sorted path order, so the resulting ``fasta_files``
        mapping (and hence the payload) is deterministic.

        Args:
            path (str | Path): Path to a directory of FASTA files.
            query_name (str | None, optional): User-defined query name. Defaults
                to the directory name.
            seed (int | str, optional): Random seed. Defaults to 0.
            use_msa_server (bool, optional): Use the MSA server for inference.
                Defaults to True.

        Other keyword arguments are forwarded to ``ProtenixParameters``.

        Raises:
            ValueError: If no FASTA files are present in the directory, or if two
                FASTA files share the same stem (e.g. ``x.fasta`` and ``x.fa``).

        Returns:
            ProtenixQuery: An instance whose ``fasta_files`` maps each file stem
                to its validated FASTA content.
        """
        path = validate_path(path, is_dir=True)
        fasta_files: dict[str, Any] = {}
        # glob() yields entries in arbitrary filesystem order; sort for determinism.
        for file in sorted(chain(path.glob("*.fasta"), path.glob("*.fa"))):
            if not file.is_file():
                # A directory whose name ends in .fasta/.fa is not a FASTA file.
                continue
            if file.stem in fasta_files:
                # Both entries would map to the same key; refuse to silently
                # drop one of them.
                raise ValueError(
                    f"Duplicate FASTA file stem '{file.stem}' in directory '{path}'."
                )
            fasta_files[file.stem] = validate_fasta(file, str_output=True)
        if not fasta_files:
            raise ValueError(f"No FASTA files found in directory '{path}'.")
        query_name = query_name or path.name
        return cls(
            fasta_files=fasta_files,
            query_name=query_name,
            parameters=ProtenixParameters(**kwargs),
        )

    @property
    def payload(self) -> dict[str, Any]:
        """Payload to send to the prediction API endpoint."""
        return {
            "fasta_files": self.fasta_files,
            "use_msa_server": self.parameters.use_msa_server,
            "seeds": self.parameters.seeds,
        }

    @property
    def parameters(self) -> ProtenixParameters:
        """Parameters of the query."""
        return self._parameters
|
|