|
"""Boltz-1 query to prediction endpoint.""" |
|
|
|
from __future__ import annotations |
|
|
|
from io import StringIO |
|
from pathlib import Path |
|
from typing import Any |
|
|
|
import yaml |
|
from folding_studio_data_models import FoldingModel |
|
from pydantic import BaseModel |
|
|
|
from folding_studio.commands.utils import process_uploaded_msas |
|
from folding_studio.query import Query |
|
from folding_studio.utils.fasta import validate_fasta |
|
from folding_studio.utils.headers import get_auth_headers |
|
from folding_studio.utils.path_helpers import validate_path |
|
|
|
|
|
class BoltzParameters(BaseModel):
    """Inference parameters attached to a Boltz query.

    Every field carries a default, so ``BoltzParameters()`` is a valid
    configuration on its own.

    NOTE(review): the precise meaning of each option is defined by the Boltz
    backend rather than by this module -- confirm there before relying on the
    grouping hints below.
    """

    # --- sampling / model options (names presumably mirror Boltz options) ---
    seed: int = 0
    recycling_steps: int = 3
    sampling_steps: int = 200
    diffusion_samples: int = 1
    step_scale: float = 1.638

    # --- MSA handling ---
    msa_pairing_strategy: str = "greedy"

    # --- output toggles ---
    write_full_pae: bool = False
    write_full_pde: bool = False

    # --- MSA source ---
    use_msa_server: bool = True
    # String-to-string mapping, or None when no custom MSA is supplied.
    custom_msa_paths: dict[str, str] | None = None
|
|
|
|
|
class BoltzQuery(Query):
    """Boltz1 model query.

    Bundles the FASTA/YAML inputs, a query name and the inference parameters,
    and exposes them as the payload sent to the prediction API endpoint.
    """

    MODEL = FoldingModel.BOLTZ

    def __init__(
        self,
        fasta_dict: dict[str, str],
        yaml_dict: dict[str, Any],
        query_name: str,
        parameters: BoltzParameters | None = None,
    ):
        """Initialize a BoltzQuery.

        Args:
            fasta_dict (dict[str, str]): FASTA contents keyed by name.
            yaml_dict (dict[str, Any]): Parsed YAML contents keyed by name.
            query_name (str): Name of the query.
            parameters (BoltzParameters | None, optional): Inference
                parameters. Defaults to None, in which case a fresh
                ``BoltzParameters()`` is created for this query.
        """
        self.fasta_dict = fasta_dict
        self.yaml_dict = yaml_dict
        self.query_name = query_name
        # A ``BoltzParameters()`` default in the signature would be evaluated
        # once and shared (and mutated) across every query built without
        # explicit parameters, so build a new instance per query instead.
        self._parameters = parameters if parameters is not None else BoltzParameters()

    @staticmethod
    def _process_file(file_path: Path) -> tuple[dict[str, str], dict[str, Any]]:
        """Process a single file and extract its contents.

        The file stem is used as the dictionary key. Exactly one of the two
        returned dictionaries is populated, depending on the file suffix.

        Args:
            file_path (Path): Path to the file.

        Returns:
            tuple[dict[str, str], dict[str, Any]]: A tuple containing the
                FASTA and YAML dictionaries.

        Raises:
            ValueError: If the file format is unsupported.
        """
        suffix = file_path.suffix
        if suffix in (".fasta", ".fa"):
            fasta_content = validate_fasta(file_path, str_output=True)
            return {file_path.stem: fasta_content}, {}
        if suffix in (".yaml", ".yml"):
            with file_path.open("r", encoding="utf-8") as f:
                return {}, {file_path.stem: yaml.safe_load(f)}
        raise ValueError(f"Unsupported format: {file_path.suffix}")

    @classmethod
    def _build_parameters(cls, kwargs: dict[str, Any]) -> BoltzParameters:
        """Build the inference parameters from constructor keyword arguments.

        If a truthy ``custom_msa_paths`` entry is present, it is treated as a
        local file or directory, uploaded with ``_upload_custom_msa_files``
        and replaced by the mapping that upload returns. A falsy entry is
        dropped so the model default applies.

        Args:
            kwargs (dict[str, Any]): Keyword arguments for BoltzParameters.
                The mapping is copied, not modified.

        Returns:
            BoltzParameters: The validated parameters.
        """
        options = dict(kwargs)
        custom_msa_paths = options.pop("custom_msa_paths", None)
        if custom_msa_paths:
            options["custom_msa_paths"] = cls._upload_custom_msa_files(custom_msa_paths)
        return BoltzParameters(**options)

    @classmethod
    def from_protein_sequence(
        cls: type[BoltzQuery],
        sequence: str,
        query_name: str | None = None,
        **kwargs: Any,
    ) -> BoltzQuery:
        """Initialize a BoltzQuery from a str protein sequence.

        Args:
            sequence (str): The protein sequence in string format. It is
                parsed with ``validate_fasta`` and stored unchanged as the
                FASTA content of the query.
            query_name (str | None, optional): User-defined query name.
                Defaults to None, in which case the part of the parsed
                record description before the first ``|`` is used.
            **kwargs: Additional parameters for the query.

        Returns:
            BoltzQuery
        """
        record = validate_fasta(StringIO(sequence))
        if query_name is None:
            query_name = record.description.split("|", maxsplit=1)[0]
        # The constructor takes ``fasta_dict``/``yaml_dict``; a sequence-only
        # query has no YAML inputs, hence the empty mapping.
        return cls(
            fasta_dict={query_name: sequence},
            yaml_dict={},
            query_name=query_name,
            parameters=cls._build_parameters(kwargs),
        )

    @classmethod
    def from_file(
        cls: type[BoltzQuery],
        path: str | Path,
        query_name: str | None = None,
        **kwargs: Any,
    ) -> BoltzQuery:
        """Initialize a BoltzQuery instance from a file.

        Supported file format are:
            - FASTA
            - YAML

        Args:
            path (str | Path): Path to the file.
            query_name (str | None, optional): User-defined query name.
                Defaults to None, in which case the file stem is used.
            **kwargs: Additional parameters for the query.

        Returns:
            BoltzQuery: An instance of BoltzQuery.
        """
        path = validate_path(
            path, is_file=True, file_suffix=(".fasta", ".fa", ".yaml", ".yml")
        )
        fasta_dict, yaml_dict = cls._process_file(path)
        return cls(
            fasta_dict=fasta_dict,
            yaml_dict=yaml_dict,
            query_name=query_name or path.stem,
            parameters=cls._build_parameters(kwargs),
        )

    @classmethod
    def from_directory(
        cls: type[BoltzQuery],
        path: str | Path,
        query_name: str | None = None,
        **kwargs: Any,
    ) -> BoltzQuery:
        """Initialize a BoltzQuery instance from a directory.

        Supported file format in directory are:
            - FASTA
            - YAML

        Every entry of the directory is processed; an entry with any other
        suffix makes the call fail.

        Args:
            path (str | Path): Path to the directory.
            query_name (str | None, optional): User-defined query name.
                Defaults to None, in which case the directory name is used.
            **kwargs: Additional parameters for the query.

        Returns:
            BoltzQuery: An instance of BoltzQuery.

        Raises:
            ValueError: If an entry has an unsupported format, or if no
                FASTA or YAML file is found.
        """
        path = validate_path(path, is_dir=True)
        fasta_dict: dict[str, str] = {}
        yaml_dict: dict[str, Any] = {}
        # iterdir() yields entries in arbitrary order; sorting keeps the
        # resulting dictionaries (and stem-collision winners) deterministic.
        for file in sorted(path.iterdir()):
            file_fasta_dict, file_yaml_dict = cls._process_file(file)
            fasta_dict.update(file_fasta_dict)
            yaml_dict.update(file_yaml_dict)

        if not (fasta_dict or yaml_dict):
            raise ValueError(f"No FASTA or YAML files found in directory '{path}'.")

        # Parameters are built last so that custom MSA files are only
        # uploaded once the local inputs are known to be valid.
        return cls(
            fasta_dict=fasta_dict,
            yaml_dict=yaml_dict,
            query_name=query_name or path.name,
            parameters=cls._build_parameters(kwargs),
        )

    @property
    def payload(self) -> dict[str, Any]:
        """Payload to send to the prediction API endpoint."""
        return {
            "fasta_files": self.fasta_dict,
            "yaml_files": self.yaml_dict,
            "parameters": self.parameters.model_dump(mode="json"),
        }

    @property
    def parameters(self) -> BoltzParameters:
        """Parameters of the query."""
        return self._parameters

    @staticmethod
    def _upload_custom_msa_files(
        source: str, headers: str | None = None
    ) -> dict[str, str]:
        """Read MSA files from a file or directory and upload them.

        The upload itself is delegated to ``process_uploaded_msas``.

        Args:
            source (str): Path to an .a3m or .csv file, or a directory
                containing such files.
            headers (str | None, optional): Authentication headers. Defaults
                to None, in which case ``get_auth_headers()`` is used.

        Raises:
            ValueError: If the file has an unsupported extension.
            ValueError: If a directory contains no .a3m or .csv files.
            ValueError: If the source is neither a file nor a directory.

        Returns:
            dict[str, str]: The result of ``process_uploaded_msas``
                (presumably a mapping of uploaded file names to their
                remote URLs -- confirm against that helper).
        """
        headers = headers or get_auth_headers()
        source_path = validate_path(source)
        valid_extensions = {".a3m", ".csv"}

        if source_path.is_file():
            if source_path.suffix not in valid_extensions:
                raise ValueError(
                    f"Invalid file type: {source_path.suffix}. Expected one of {valid_extensions}."
                )
            return process_uploaded_msas([source_path], headers)

        if source_path.is_dir():
            valid_files = [
                file
                for file in source_path.iterdir()
                if file.suffix in valid_extensions
            ]
            if not valid_files:
                raise ValueError(
                    f"Directory '{source}' contains no valid files with extensions {valid_extensions}."
                )
            return process_uploaded_msas(valid_files, headers)

        # Previously this fell through and returned None, silently dropping
        # the custom MSA request despite the declared dict return type.
        raise ValueError(
            f"Custom MSA source '{source}' is neither a file nor a directory."
        )
|
|