File size: 5,998 Bytes
44459bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
"""Query module for SoloSeq prediction endpoint."""
from __future__ import annotations
from io import StringIO
from pathlib import Path
from typing import Any
from folding_studio_data_models import FoldingModel
from pydantic import BaseModel, Field
from folding_studio.query import Query
from folding_studio.utils.fasta import validate_fasta
from folding_studio.utils.path_helpers import validate_path
# Upper bound handed to `validate_fasta` as `max_aa_length` by every
# SoloSeqQuery constructor in this module.
MAX_AA_LENGTH = 1024
class SoloSeqParameters(BaseModel):
    """Inference options accepted by the SoloSeq model.

    Attributes:
        data_random_seed: Random seed. Populated through the ``seed`` alias.
            Defaults to 0.
        skip_relaxation: Relaxation toggle; going by its name, ``True``
            presumably skips the relaxation step. Defaults to False.
        subtract_plddt: Output (100 - pLDDT) instead of the pLDDT itself.
            Defaults to False.
    """

    data_random_seed: int = Field(default=0, alias="seed")
    skip_relaxation: bool = False
    subtract_plddt: bool = False
class SoloSeqQuery(Query):
    """SoloSeq model query.

    Holds a mapping of FASTA entries together with the SoloSeq inference
    parameters and exposes them as the payload sent to the prediction
    endpoint. Instances are normally built through one of the
    ``from_*`` alternate constructors.
    """

    MODEL = FoldingModel.SOLOSEQ

    def __init__(
        self,
        fasta_files: dict[str, str],
        query_name: str,
        parameters: SoloSeqParameters | None = None,
    ) -> None:
        """Initialize the query.

        Args:
            fasta_files (dict[str, str]): Mapping of entry name to FASTA content.
            query_name (str): Name identifying the query.
            parameters (SoloSeqParameters | None, optional): Inference
                parameters. When None, a fresh ``SoloSeqParameters()`` with
                default values is created. Defaults to None.

        Raises:
            ValueError: If ``fasta_files`` is empty.
        """
        if not fasta_files:
            raise ValueError("FASTA files dictionary cannot be empty.")
        self.fasta_files = fasta_files
        self.query_name = query_name
        # Build the default per instance. A default evaluated in the signature
        # would be a single mutable model shared by every query, so mutating
        # one query's parameters would silently affect all the others.
        self._parameters = (
            parameters if parameters is not None else SoloSeqParameters()
        )

    def __eq__(self, value: object) -> bool:
        """Compare FASTA entries, query name and parameters.

        Any object that is not a ``SoloSeqQuery`` compares unequal.
        """
        if not isinstance(value, SoloSeqQuery):
            return False
        return (
            self.fasta_files == value.fasta_files
            and self.query_name == value.query_name
            and self.parameters == value.parameters
        )

    @classmethod
    def from_protein_sequence(
        cls: type[SoloSeqQuery],
        sequence: str,
        query_name: str | None = None,
        **kwargs,
    ) -> SoloSeqQuery:
        """Initialize a SoloSeqQuery instance from a str protein sequence.

        The text is validated with ``validate_fasta`` (monomer only, at most
        ``MAX_AA_LENGTH`` residues) and then stored unchanged under the query
        name.

        NOTE(review): the text is parsed as FASTA and the default query name
        is taken from the record description, so ``sequence`` presumably has
        to include a FASTA header line -- confirm against ``validate_fasta``.

        Args:
            sequence (str): Protein amino-acid sequence.
            query_name (str | None, optional): User-defined query name. When
                None, the part of the record description before the first
                ``|`` is used. Defaults to None.
            **kwargs: Forwarded to ``SoloSeqParameters``:
                seed (int, optional): Random seed. Defaults to 0.
                skip_relaxation (bool, optional): Whether to skip the
                    relaxation process. Defaults to False.
                subtract_plddt (bool, optional): Output (100 - pLDDT) instead
                    of the pLDDT itself. Defaults to False.

        Raises:
            NotAMonomer: If the sequence is not a monomer complex.

        Returns:
            SoloSeqQuery
        """
        record = validate_fasta(
            StringIO(sequence), allow_multimer=False, max_aa_length=MAX_AA_LENGTH
        )
        if query_name is None:
            # First "|"-separated tag of the FASTA description.
            query_name = record.description.split("|", maxsplit=1)[0]
        return cls(
            fasta_files={query_name: sequence},
            query_name=query_name,
            parameters=SoloSeqParameters(**kwargs),
        )

    @classmethod
    def from_file(
        cls: type[SoloSeqQuery],
        path: str | Path,
        query_name: str | None = None,
        **kwargs,
    ) -> SoloSeqQuery:
        """Initialize a SoloSeqQuery instance from a file.

        Supported file formats are:
            - FASTA (``.fasta`` or ``.fa``)

        The single entry is stored under the file stem.

        Args:
            path (str | Path): Path of the FASTA file.
            query_name (str | None, optional): User-defined query name. When
                None or empty, the file stem is used. Defaults to None.
            **kwargs: Forwarded to ``SoloSeqParameters``:
                seed (int, optional): Random seed. Defaults to 0.
                skip_relaxation (bool, optional): Whether to skip the
                    relaxation process. Defaults to False.
                subtract_plddt (bool, optional): Output (100 - pLDDT) instead
                    of the pLDDT itself. Defaults to False.

        Raises:
            NotAMonomer: If the FASTA file contains a non-monomer complex.

        Returns:
            SoloSeqQuery
        """
        path = validate_path(path, is_file=True, file_suffix=(".fasta", ".fa"))
        record = validate_fasta(
            path, allow_multimer=False, max_aa_length=MAX_AA_LENGTH
        )
        query_name = query_name or path.stem
        return cls(
            fasta_files={path.stem: record.format("fasta").strip()},
            query_name=query_name,
            parameters=SoloSeqParameters(**kwargs),
        )

    @classmethod
    def from_directory(
        cls: type[SoloSeqQuery],
        path: str | Path,
        query_name: str | None = None,
        **kwargs,
    ) -> SoloSeqQuery:
        """Initialize a SoloSeqQuery instance from a directory.

        Supported file formats in the directory are:
            - FASTA (``.fasta`` or ``.fa``)

        Only direct children of the directory are considered. Each file is
        stored under its stem.

        NOTE(review): two files sharing a stem (e.g. ``a.fa`` and
        ``a.fasta``) map to the same key, so the later one in sorted order
        overwrites the earlier one.

        Args:
            path (str | Path): Path to a directory of FASTA files.
            query_name (str | None, optional): User-defined query name. When
                None or empty, the directory name is used. Defaults to None.
            **kwargs: Forwarded to ``SoloSeqParameters``:
                seed (int, optional): Random seed. Defaults to 0.
                skip_relaxation (bool, optional): Whether to skip the
                    relaxation process. Defaults to False.
                subtract_plddt (bool, optional): Output (100 - pLDDT) instead
                    of the pLDDT itself. Defaults to False.

        Raises:
            ValueError: If no FASTA files are present in the directory.
            NotAMonomer: If a FASTA file in the directory contains a
                non-monomer complex.

        Returns:
            SoloSeqQuery
        """
        path = validate_path(path, is_dir=True)
        fasta_files: dict[str, str] = {}
        # Path.iterdir() yields entries in arbitrary order; sorting keeps the
        # resulting mapping (and therefore the payload) reproducible.
        for filepath in sorted(path.iterdir()):
            if filepath.suffix not in (".fasta", ".fa"):
                continue
            record = validate_fasta(
                filepath, allow_multimer=False, max_aa_length=MAX_AA_LENGTH
            )
            fasta_files[filepath.stem] = record.format("fasta").strip()
        if not fasta_files:
            raise ValueError(f"No FASTA files found in directory '{path}'.")
        query_name = query_name or path.name
        return cls(
            fasta_files=fasta_files,
            query_name=query_name,
            parameters=SoloSeqParameters(**kwargs),
        )

    @property
    def payload(self) -> dict[str, Any]:
        """Payload to send to the prediction API endpoint."""
        return {
            "fasta_files": self.fasta_files,
            "parameters": self.parameters.model_dump(mode="json"),
        }

    @property
    def parameters(self) -> SoloSeqParameters:
        """Parameters of the query."""
        return self._parameters
|