# Commit 44459bb by jfaustin: "add dockerfile and folding studio cli".
"""Query module for SoloSeq prediction endpoint."""
from __future__ import annotations
from io import StringIO
from pathlib import Path
from typing import Any
from folding_studio_data_models import FoldingModel
from pydantic import BaseModel, Field
from folding_studio.query import Query
from folding_studio.utils.fasta import validate_fasta
from folding_studio.utils.path_helpers import validate_path
# Maximum sequence length (in amino acids) passed as ``max_aa_length`` to
# ``validate_fasta`` by every SoloSeqQuery alternate constructor.
MAX_AA_LENGTH = 1024
class SoloSeqParameters(BaseModel):
    """SoloSeq inference parameters.

    Attributes:
        data_random_seed (int): Random seed. Accepted under its alias ``seed``
            and also under its field name ``data_random_seed``. Defaults to 0.
        skip_relaxation (bool): Flag forwarded to the prediction API; going by
            its name it requests skipping the relaxation step. Defaults to False.
        subtract_plddt (bool): Output (100 - pLDDT) instead of the pLDDT itself.
            Defaults to False.
    """

    # Pydantic ignores unknown keys by default, so without this setting
    # ``SoloSeqParameters(data_random_seed=42)`` would silently keep the seed
    # at 0. ``seed=...`` keeps working, and ``model_dump`` output is unchanged.
    model_config = {"populate_by_name": True}

    data_random_seed: int = Field(alias="seed", default=0)
    skip_relaxation: bool = False
    subtract_plddt: bool = False
class SoloSeqQuery(Query):
    """SoloSeq model query.

    Holds a mapping of record name to FASTA content, a user-defined query
    name and the SoloSeq inference parameters. Instances can be built directly
    or through the ``from_protein_sequence``, ``from_file`` and
    ``from_directory`` alternate constructors, which validate each record as a
    monomer of at most ``MAX_AA_LENGTH`` amino acids.
    """

    MODEL = FoldingModel.SOLOSEQ

    def __init__(
        self,
        fasta_files: dict[str, str],
        query_name: str,
        parameters: SoloSeqParameters | None = None,
    ):
        """Initialize a SoloSeqQuery.

        Args:
            fasta_files (dict[str, str]): Mapping of record name to FASTA content.
            query_name (str): User-defined query name.
            parameters (SoloSeqParameters | None, optional): Inference parameters.
                Defaults to None, meaning a fresh ``SoloSeqParameters()``.

        Raises:
            ValueError: If ``fasta_files`` is empty.
        """
        if not fasta_files:
            raise ValueError("FASTA files dictionary cannot be empty.")
        self.fasta_files = fasta_files
        self.query_name = query_name
        # Build the default per instance: a default-argument instance would be
        # created once at definition time and shared by every query.
        self._parameters = (
            parameters if parameters is not None else SoloSeqParameters()
        )

    def __eq__(self, value: object) -> bool:
        """Queries are equal when FASTA files, name and parameters all match."""
        if not isinstance(value, SoloSeqQuery):
            # Let Python try the reflected comparison / fall back to identity.
            return NotImplemented
        return (
            self.fasta_files == value.fasta_files
            and self.query_name == value.query_name
            and self.parameters == value.parameters
        )

    @classmethod
    def from_protein_sequence(
        cls: type[SoloSeqQuery],
        sequence: str,
        query_name: str | None = None,
        **kwargs,
    ) -> SoloSeqQuery:
        """Initialize a SoloSeqQuery instance from a str protein sequence.

        Args:
            sequence (str): Protein amino-acid sequence, parsed with
                ``validate_fasta``.
            query_name (str | None, optional): User-defined query name. Defaults
                to None, in which case the first ``|``-separated tag of the
                record description is used.
            seed (int, optional): Random seed. Defaults to 0.
            skip_relaxation (bool, optional): Skip-relaxation flag forwarded to
                the API. Defaults to False.
            subtract_plddt (bool, optional): Output (100 - pLDDT) instead
                of the pLDDT itself. Defaults to False.

        Raises:
            NotAMonomer: If the sequence is not a monomer complex.

        Returns:
            SoloSeqQuery
        """
        record = validate_fasta(
            StringIO(sequence), allow_multimer=False, max_aa_length=MAX_AA_LENGTH
        )
        if query_name is None:
            query_name = record.description.split("|", maxsplit=1)[0]  # first tag
        return cls(
            fasta_files={query_name: sequence},
            query_name=query_name,
            parameters=SoloSeqParameters(**kwargs),
        )

    @classmethod
    def from_file(
        cls: type[SoloSeqQuery],
        path: str | Path,
        query_name: str | None = None,
        **kwargs,
    ) -> SoloSeqQuery:
        """Initialize a SoloSeqQuery instance from a file.

        Supported file formats are:
            - FASTA (``.fasta`` / ``.fa``)

        Args:
            path (str | Path): Path of the FASTA file.
            query_name (str | None, optional): User-defined query name. Defaults
                to None (or empty), in which case the file stem is used.
            seed (int, optional): Random seed. Defaults to 0.
            skip_relaxation (bool, optional): Skip-relaxation flag forwarded to
                the API. Defaults to False.
            subtract_plddt (bool, optional): Output (100 - pLDDT) instead
                of the pLDDT itself. Defaults to False.

        Raises:
            NotAMonomer: If the FASTA file contains non-monomer complex.

        Returns:
            SoloSeqQuery
        """
        path = validate_path(path, is_file=True, file_suffix=(".fasta", ".fa"))
        record = validate_fasta(path, allow_multimer=False, max_aa_length=MAX_AA_LENGTH)
        query_name = query_name or path.stem
        return cls(
            # The record is keyed by file stem, independently of query_name.
            fasta_files={path.stem: record.format("fasta").strip()},
            query_name=query_name,
            parameters=SoloSeqParameters(**kwargs),
        )

    @classmethod
    def from_directory(
        cls: type[SoloSeqQuery],
        path: str | Path,
        query_name: str | None = None,
        **kwargs,
    ) -> SoloSeqQuery:
        """Initialize a SoloSeqQuery instance from a directory.

        Supported file formats in the directory are:
            - FASTA (``.fasta`` / ``.fa``)

        Files are read in sorted path order and keyed by their stem.

        Args:
            path (str | Path): Path to a directory of FASTA files.
            query_name (str | None, optional): User-defined query name. Defaults
                to None (or empty), in which case the directory name is used.
            seed (int, optional): Random seed. Defaults to 0.
            skip_relaxation (bool, optional): Skip-relaxation flag forwarded to
                the API. Defaults to False.
            subtract_plddt (bool, optional): Output (100 - pLDDT) instead
                of the pLDDT itself. Defaults to False.

        Raises:
            ValueError: If no FASTA file is present in the directory, or if two
                FASTA files share the same stem (e.g. ``a.fa`` and ``a.fasta``).
            NotAMonomer: If a FASTA file in the directory contains non monomer complex.

        Returns:
            SoloSeqQuery
        """
        path = validate_path(path, is_dir=True)
        fasta_files: dict[str, str] = {}
        # Sorted so the mapping (and hence the payload) does not depend on the
        # arbitrary order in which the filesystem lists directory entries.
        fasta_paths = sorted(
            f for f in path.iterdir() if f.is_file() and f.suffix in (".fasta", ".fa")
        )
        for filepath in fasta_paths:
            # Records are keyed by stem, so same-stem files would overwrite each
            # other; refuse instead of silently dropping one of them.
            if filepath.stem in fasta_files:
                raise ValueError(
                    f"Several FASTA files in directory '{path}' share the name "
                    f"'{filepath.stem}'."
                )
            record = validate_fasta(
                filepath, allow_multimer=False, max_aa_length=MAX_AA_LENGTH
            )
            fasta_files[filepath.stem] = record.format("fasta").strip()
        if not fasta_files:
            raise ValueError(f"No FASTA files found in directory '{path}'.")
        query_name = query_name or path.name
        return cls(
            fasta_files=fasta_files,
            query_name=query_name,
            parameters=SoloSeqParameters(**kwargs),
        )

    @property
    def payload(self) -> dict[str, Any]:
        """Payload to send to the prediction API endpoint."""
        return {
            "fasta_files": self.fasta_files,
            "parameters": self.parameters.model_dump(mode="json"),
        }

    @property
    def parameters(self) -> SoloSeqParameters:
        """Parameters of the query."""
        return self._parameters