File size: 4,834 Bytes
44459bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
"""Query module for for Protenix prediction endpoint."""
from __future__ import annotations
from io import StringIO
from itertools import chain
from pathlib import Path
from typing import Any
from folding_studio_data_models import FoldingModel
from pydantic import BaseModel, Field
from folding_studio.query import Query
from folding_studio.utils.fasta import validate_fasta
from folding_studio.utils.path_helpers import validate_path
class ProtenixParameters(BaseModel):
    """Protenix inference parameters.

    Attributes:
        seeds (str): Random seed(s) for inference. Can be given with the
            ``seed`` keyword (its alias) or with ``seeds``; numeric values are
            coerced to ``str``. Defaults to ``"0"``.
        use_msa_server (bool): Whether to use the MSA server for inference.
            Defaults to ``True``.
    """

    # Accept the field name ("seeds") as well as its alias ("seed"). Without
    # this, ``seeds=...`` is treated as an unknown key and silently ignored,
    # leaving the default seed in place.
    model_config = {"populate_by_name": True}

    seeds: str = Field(alias="seed", default="0", coerce_numbers_to_str=True)
    use_msa_server: bool = True
class ProtenixQuery(Query):
    """Protenix model query.

    Holds one or more FASTA inputs (a mapping of name to FASTA content), a
    user-defined query name and the ``ProtenixParameters`` used to build the
    prediction payload.
    """

    MODEL = FoldingModel.PROTENIX

    def __init__(
        self,
        fasta_files: dict[str, Any],
        query_name: str,
        parameters: ProtenixParameters | None = None,
    ):
        """Create a Protenix query.

        Args:
            fasta_files (dict[str, Any]): Mapping of FASTA name to FASTA
                content. Must not be empty.
            query_name (str): User-defined query name.
            parameters (ProtenixParameters | None, optional): Inference
                parameters. Defaults to None, meaning a new
                ``ProtenixParameters()`` with default values.

        Raises:
            ValueError: If ``fasta_files`` is empty.
        """
        if not fasta_files:
            raise ValueError("FASTA files dictionary cannot be empty.")
        self.fasta_files = fasta_files
        self.query_name = query_name
        # Build the default per instance: a default argument value is evaluated
        # once, so a ``ProtenixParameters()`` default would be one mutable
        # object shared by every query created without explicit parameters.
        self._parameters = (
            parameters if parameters is not None else ProtenixParameters()
        )

    @classmethod
    def from_protein_sequence(
        cls, sequence: str, query_name: str | None = None, **kwargs
    ) -> ProtenixQuery:
        """Initialize a ProtenixQuery from a str protein sequence.

        Args:
            sequence (str): The protein sequence in (FASTA) string format.
            query_name (str | None, optional): User-defined query name.
                Defaults to None, in which case the first ``|``-separated tag
                of the FASTA record description is used.
            seed (int | str, optional): Random seed(s); numbers are coerced to
                str. Defaults to "0".
            use_msa_server (bool, optional): Use the MSA server for inference.
                Defaults to True.

        Other keyword arguments are forwarded to ``ProtenixParameters``.

        Returns:
            ProtenixQuery: An instance whose ``fasta_files`` maps
            ``query_name`` to the given sequence string.
        """
        record = validate_fasta(StringIO(sequence))
        query_name = (
            query_name
            if query_name is not None
            else record.description.split("|", maxsplit=1)[0]  # first tag
        )
        return cls(
            fasta_files={query_name: sequence},
            query_name=query_name,
            parameters=ProtenixParameters(**kwargs),
        )

    @classmethod
    def from_file(
        cls, path: str | Path, query_name: str | None = None, **kwargs
    ) -> ProtenixQuery:
        """Initialize a ProtenixQuery instance from a file.

        Supported file formats are:
            - FASTA (``.fasta`` or ``.fa`` suffix)

        Args:
            path (str | Path): Path of the FASTA file.
            query_name (str | None, optional): User-defined query name.
                Defaults to None, in which case the file stem is used.
            seed (int | str, optional): Random seed(s); numbers are coerced to
                str. Defaults to "0".
            use_msa_server (bool, optional): Use the MSA server for inference.
                Defaults to True.

        Other keyword arguments are forwarded to ``ProtenixParameters``.

        Returns:
            ProtenixQuery: An instance whose ``fasta_files`` maps the file stem
            to the validated FASTA content.
        """
        path = validate_path(path, is_file=True, file_suffix=(".fasta", ".fa"))
        query_name = query_name or path.stem
        return cls(
            fasta_files={path.stem: validate_fasta(path, str_output=True)},
            query_name=query_name,
            parameters=ProtenixParameters(**kwargs),
        )

    @classmethod
    def from_directory(
        cls, path: str | Path, query_name: str | None = None, **kwargs
    ) -> ProtenixQuery:
        """Initialize a ProtenixQuery instance from a directory.

        Supported file formats in the directory are:
            - FASTA (``*.fasta`` or ``*.fa``, non-recursive)

        Args:
            path (str | Path): Path to a directory of FASTA files.
            query_name (str | None, optional): User-defined query name.
                Defaults to None, in which case the directory name is used.
            seed (int | str, optional): Random seed(s); numbers are coerced to
                str. Defaults to "0".
            use_msa_server (bool, optional): Use the MSA server for inference.
                Defaults to True.

        Other keyword arguments are forwarded to ``ProtenixParameters``.

        Raises:
            ValueError: If no FASTA file is present in the directory, or if two
                FASTA files share the same stem (e.g. ``x.fasta`` and ``x.fa``).

        Returns:
            ProtenixQuery: An instance whose ``fasta_files`` maps each file stem
            to its validated FASTA content, in sorted path order.
        """
        path = validate_path(path, is_dir=True)
        fasta_files: dict[str, Any] = {}
        # Sort so the mapping (and hence the payload) is deterministic; glob
        # yields entries in filesystem-dependent order.
        candidates = sorted(chain(path.glob("*.fasta"), path.glob("*.fa")))
        for file in candidates:
            if not file.is_file():
                # A directory named e.g. "foo.fasta" is not a FASTA input.
                continue
            if file.stem in fasta_files:
                # Files are keyed by stem; refuse to silently drop one of them.
                raise ValueError(
                    f"Duplicate FASTA file name '{file.stem}' in directory "
                    f"'{path}'; file names without extension must be unique."
                )
            fasta_files[file.stem] = validate_fasta(file, str_output=True)
        if not fasta_files:
            raise ValueError(f"No FASTA files found in directory '{path}'.")
        query_name = query_name or path.name
        return cls(
            fasta_files=fasta_files,
            query_name=query_name,
            parameters=ProtenixParameters(**kwargs),
        )

    @property
    def payload(self) -> dict[str, Any]:
        """Payload to send to the prediction API endpoint.

        Returns:
            dict[str, Any]: The FASTA mapping plus the ``use_msa_server`` and
            ``seeds`` values taken from ``parameters``.
        """
        return {
            "fasta_files": self.fasta_files,
            "use_msa_server": self.parameters.use_msa_server,
            "seeds": self.parameters.seeds,
        }

    @property
    def parameters(self) -> ProtenixParameters:
        """Parameters of the query."""
        return self._parameters
|