# jfaustin: add dockerfile and folding studio cli (commit 44459bb)
"""Query module for the Protenix prediction endpoint."""
from __future__ import annotations
from io import StringIO
from itertools import chain
from pathlib import Path
from typing import Any
from folding_studio_data_models import FoldingModel
from pydantic import BaseModel, Field
from folding_studio.query import Query
from folding_studio.utils.fasta import validate_fasta
from folding_studio.utils.path_helpers import validate_path
class ProtenixParameters(BaseModel):
    """Inference settings attached to a Protenix query."""

    # Supplied by callers under the ``seed`` alias. Numeric input is coerced
    # to ``str`` (``coerce_numbers_to_str=True``); the default is ``"0"``.
    seeds: str = Field(default="0", alias="seed", coerce_numbers_to_str=True)

    # Flag forwarded as ``use_msa_server`` in the query payload; on by default.
    use_msa_server: bool = True
class ProtenixQuery(Query):
    """Protenix model query.

    Holds a mapping of FASTA contents keyed by name, a query name and the
    inference parameters, and exposes them as the payload sent to the
    prediction endpoint.
    """

    MODEL = FoldingModel.PROTENIX

    def __init__(
        self,
        fasta_files: dict[str, Any],
        query_name: str,
        parameters: ProtenixParameters | None = None,
    ):
        """Initialize a ProtenixQuery.

        Args:
            fasta_files (dict[str, Any]): Mapping of a name to FASTA content.
            query_name (str): Name of the query.
            parameters (ProtenixParameters | None, optional): Inference
                parameters. When omitted, a fresh ``ProtenixParameters()``
                with default values is created for this query.

        Raises:
            ValueError: If ``fasta_files`` is empty.
        """
        if not fasta_files:
            raise ValueError("FASTA files dictionary cannot be empty.")
        self.fasta_files = fasta_files
        self.query_name = query_name
        # A default written as ``ProtenixParameters()`` in the signature is
        # evaluated once and shared by every query; build one per instance so
        # mutating one query's parameters cannot leak into another.
        self._parameters = (
            parameters if parameters is not None else ProtenixParameters()
        )

    @classmethod
    def from_protein_sequence(
        cls, sequence: str, query_name: str | None = None, **kwargs
    ) -> ProtenixQuery:
        """Initialize a ProtenixQuery from a str protein sequence.

        The string is parsed with ``validate_fasta`` and stored unchanged in
        the query under ``query_name``.

        Args:
            sequence (str): The protein sequence in string format.
            query_name (str | None, optional): User-defined query name. When
                None, the first ``|``-separated tag of the parsed record
                description is used. Defaults to None.
            **kwargs: Forwarded to ``ProtenixParameters``:
                seed (int | str, optional): Random seed. Defaults to 0.
                use_msa_server (bool, optional): Use the MSA server for
                    inference. Defaults to True.

        Returns:
            ProtenixQuery: An instance of ProtenixQuery with the sequence
            stored as a FASTA file.
        """
        record = validate_fasta(StringIO(sequence))
        if query_name is None:
            query_name = record.description.split("|", maxsplit=1)[0]  # first tag
        return cls(
            fasta_files={query_name: sequence},
            query_name=query_name,
            parameters=ProtenixParameters(**kwargs),
        )

    @classmethod
    def from_file(
        cls, path: str | Path, query_name: str | None = None, **kwargs
    ) -> ProtenixQuery:
        """Initialize a ProtenixQuery instance from a file.

        Supported file format are:
            - FASTA

        Args:
            path (str | Path): Path of the FASTA file (``.fasta`` or ``.fa``).
            query_name (str | None, optional): User-defined query name. Falls
                back to the file stem when not provided. Defaults to None.
            **kwargs: Forwarded to ``ProtenixParameters``:
                seed (int | str, optional): Random seed. Defaults to 0.
                use_msa_server (bool, optional): Use the MSA server for
                    inference. Defaults to True.

        Returns:
            ProtenixQuery
        """
        path = validate_path(path, is_file=True, file_suffix=(".fasta", ".fa"))
        query_name = query_name or path.stem
        return cls(
            fasta_files={path.stem: validate_fasta(path, str_output=True)},
            query_name=query_name,
            parameters=ProtenixParameters(**kwargs),
        )

    @classmethod
    def from_directory(
        cls, path: str | Path, query_name: str | None = None, **kwargs
    ) -> ProtenixQuery:
        """Initialize a ProtenixQuery instance from a directory.

        Supported file format in directory are:
            - FASTA

        Args:
            path (str | Path): Path to a directory of FASTA files
                (``*.fasta`` and ``*.fa``, non-recursive).
            query_name (str | None, optional): User-defined query name. Falls
                back to the directory name when not provided. Defaults to None.
            **kwargs: Forwarded to ``ProtenixParameters``:
                seed (int | str, optional): Random seed. Defaults to 0.
                use_msa_server (bool, optional): Use the MSA server for
                    inference. Defaults to True.

        Raises:
            ValueError: If no FASTA file are present in the directory.

        Returns:
            ProtenixQuery
        """
        path = validate_path(path, is_dir=True)
        fasta_files = {}
        # Each glob is sorted so the payload order does not depend on the
        # filesystem's enumeration order. ``*.fa`` is still visited after
        # ``*.fasta``, so on a duplicate stem the ``.fa`` file wins, as before.
        for file in chain(sorted(path.glob("*.fasta")), sorted(path.glob("*.fa"))):
            fasta_files[file.stem] = validate_fasta(file, str_output=True)
        if not fasta_files:
            raise ValueError(f"No FASTA files found in directory '{path}'.")
        query_name = query_name or path.name
        return cls(
            fasta_files=fasta_files,
            query_name=query_name,
            parameters=ProtenixParameters(**kwargs),
        )

    @property
    def payload(self) -> dict[str, Any]:
        """Payload to send to the prediction API endpoint."""
        return {
            "fasta_files": self.fasta_files,
            "use_msa_server": self.parameters.use_msa_server,
            "seeds": self.parameters.seeds,
        }

    @property
    def parameters(self) -> ProtenixParameters:
        """Parameters of the query."""
        return self._parameters