jfaustin's picture
add dockerfile and folding studio cli
44459bb
raw
history blame
7.89 kB
"""Boltz-1 query to prediction endpoint."""
from __future__ import annotations
from io import StringIO
from pathlib import Path
from typing import Any
import yaml
from folding_studio_data_models import FoldingModel
from pydantic import BaseModel
from folding_studio.commands.utils import process_uploaded_msas
from folding_studio.query import Query
from folding_studio.utils.fasta import validate_fasta
from folding_studio.utils.headers import get_auth_headers
from folding_studio.utils.path_helpers import validate_path
class BoltzParameters(BaseModel):
    """Boltz inference parameters.

    Pydantic model holding the options attached to a ``BoltzQuery``. The query
    serializes it with ``model_dump(mode="json")`` under the ``"parameters"``
    key of its payload. Every field has a default, so ``BoltzParameters()``
    with no arguments is a valid configuration.
    """

    # Inference options. Their exact meaning is defined by the prediction
    # backend and is not visible in this file -- NOTE(review): confirm against
    # the Boltz documentation before changing any default.
    seed: int = 0
    recycling_steps: int = 3
    sampling_steps: int = 200
    diffusion_samples: int = 1
    step_scale: float = 1.638
    msa_pairing_strategy: str = "greedy"
    write_full_pae: bool = False
    write_full_pde: bool = False
    use_msa_server: bool = True
    # Set by the ``BoltzQuery`` alternate constructors to the mapping returned
    # by ``BoltzQuery._upload_custom_msa_files``; stays ``None`` when the
    # caller supplies no custom MSA path.
    custom_msa_paths: dict[str, str] | None = None
class BoltzQuery(Query):
    """Boltz-1 model query.

    Bundles the FASTA and/or YAML inputs of a prediction together with the
    inference parameters, and exposes them as the payload sent to the
    prediction endpoint. Instances are normally built through one of the
    alternate constructors: ``from_protein_sequence``, ``from_file`` or
    ``from_directory``.
    """

    MODEL = FoldingModel.BOLTZ

    def __init__(
        self,
        fasta_dict: dict[str, str],
        yaml_dict: dict[str, Any],
        query_name: str,
        parameters: BoltzParameters | None = None,
    ):
        """Store the query inputs.

        Args:
            fasta_dict: FASTA contents keyed by input name.
            yaml_dict: Parsed YAML contents keyed by input name.
            query_name: Name of the query.
            parameters: Inference parameters. When omitted, a fresh
                ``BoltzParameters()`` is created for this query.
        """
        self.fasta_dict = fasta_dict
        self.yaml_dict = yaml_dict
        self.query_name = query_name
        # Build the default here rather than in the signature: a default
        # evaluated at definition time would be one mutable model instance
        # shared by every query created without explicit parameters.
        self._parameters = BoltzParameters() if parameters is None else parameters

    @staticmethod
    def _process_file(file_path: Path) -> tuple[dict[str, str], dict[str, Any]]:
        """Read a single input file and return its content keyed by file stem.

        Args:
            file_path (Path): Path to the file.

        Returns:
            tuple[dict[str, str], dict[str, Any]]: ``(fasta_dict, yaml_dict)``.
            Exactly one of the two holds a single entry; the other is empty.

        Raises:
            ValueError: If the file suffix is not .fasta, .fa, .yaml or .yml.
        """
        suffix = file_path.suffix
        if suffix in (".fasta", ".fa"):
            fasta_content = validate_fasta(file_path, str_output=True)
            return {file_path.stem: fasta_content}, {}
        if suffix in (".yaml", ".yml"):
            with file_path.open("r", encoding="utf-8") as f:
                return {}, {file_path.stem: yaml.safe_load(f)}
        raise ValueError(f"Unsupported format: {file_path.suffix}")

    @classmethod
    def from_protein_sequence(
        cls: type[BoltzQuery],
        sequence: str,
        query_name: str | None = None,
        **kwargs: Any,
    ) -> BoltzQuery:
        """Initialize a BoltzQuery from a str protein sequence.

        Args:
            sequence (str): The protein sequence in string format.
            query_name (str | None, optional): User-defined query name.
                Defaults to the first ``|``-separated tag of the validated
                record description.
            **kwargs: Additional parameters for the query, forwarded to
                ``BoltzParameters``. ``custom_msa_paths``, when given, is
                uploaded first and replaced by the resulting mapping.

        Returns:
            BoltzQuery
        """
        record = validate_fasta(StringIO(sequence))

        custom_msa_paths = kwargs.pop("custom_msa_paths", None)
        if custom_msa_paths:
            kwargs["custom_msa_paths"] = cls._upload_custom_msa_files(custom_msa_paths)

        if query_name is None:
            query_name = record.description.split("|", maxsplit=1)[0]  # first tag

        # The constructor takes ``fasta_dict`` and a mandatory ``yaml_dict``;
        # a sequence-only query has no YAML input.
        return cls(
            fasta_dict={query_name: sequence},
            yaml_dict={},
            query_name=query_name,
            parameters=BoltzParameters(**kwargs),
        )

    @classmethod
    def from_file(
        cls: type[BoltzQuery],
        path: str | Path,
        query_name: str | None = None,
        **kwargs: Any,
    ) -> BoltzQuery:
        """Initialize a BoltzQuery instance from a file.

        Supported file format are:
        - FASTA
        - YAML

        Args:
            path (str | Path): Path to the file.
            query_name (str | None, optional): User-defined query name.
                Defaults to the file stem.
            **kwargs: Additional parameters for the query, forwarded to
                ``BoltzParameters``. ``custom_msa_paths``, when given, is
                uploaded first and replaced by the resulting mapping.

        Returns:
            BoltzQuery: An instance of BoltzQuery.
        """
        path = validate_path(
            path, is_file=True, file_suffix=(".fasta", ".fa", ".yaml", ".yml")
        )
        fasta_dict, yaml_dict = cls._process_file(path)
        query_name = query_name or path.stem

        custom_msa_paths = kwargs.pop("custom_msa_paths", None)
        if custom_msa_paths:
            kwargs["custom_msa_paths"] = cls._upload_custom_msa_files(custom_msa_paths)

        return cls(
            fasta_dict=fasta_dict,
            yaml_dict=yaml_dict,
            query_name=query_name,
            parameters=BoltzParameters(**kwargs),
        )

    @classmethod
    def from_directory(
        cls: type[BoltzQuery],
        path: str | Path,
        query_name: str | None = None,
        **kwargs: Any,
    ) -> BoltzQuery:
        """Initialize a BoltzQuery instance from a directory.

        Supported file format in directory are:
        - FASTA
        - YAML

        Args:
            path (str | Path): Path to the directory.
            query_name (str | None, optional): User-defined query name.
                Defaults to the directory name.
            **kwargs: Additional parameters for the query, forwarded to
                ``BoltzParameters``. ``custom_msa_paths``, when given, is
                uploaded and replaced by the resulting mapping.

        Returns:
            BoltzQuery: An instance of BoltzQuery.

        Raises:
            ValueError: If the directory holds an entry with an unsupported
                suffix, or no FASTA/YAML file at all.
        """
        # Validate and read the local inputs before uploading anything, as
        # the other constructors do, so that a bad directory fails without
        # first pushing MSA files to remote storage.
        path = validate_path(path, is_dir=True)

        fasta_dict: dict[str, str] = {}
        yaml_dict: dict[str, Any] = {}
        # NOTE(review): every entry goes through ``_process_file``, so a
        # single unsupported entry aborts the whole call -- confirm whether
        # such entries should be skipped instead.
        for file in path.iterdir():
            file_fasta_dict, file_yaml_dict = cls._process_file(file)
            fasta_dict.update(file_fasta_dict)
            yaml_dict.update(file_yaml_dict)

        if not (fasta_dict or yaml_dict):
            raise ValueError(f"No FASTA or YAML files found in directory '{path}'.")

        custom_msa_paths = kwargs.pop("custom_msa_paths", None)
        if custom_msa_paths:
            kwargs["custom_msa_paths"] = cls._upload_custom_msa_files(custom_msa_paths)

        query_name = query_name or path.name
        return cls(
            fasta_dict=fasta_dict,
            yaml_dict=yaml_dict,
            query_name=query_name,
            parameters=BoltzParameters(**kwargs),
        )

    @property
    def payload(self) -> dict[str, Any]:
        """Payload to send to the prediction API endpoint."""
        return {
            "fasta_files": self.fasta_dict,
            "yaml_files": self.yaml_dict,
            "parameters": self.parameters.model_dump(mode="json"),
        }

    @property
    def parameters(self) -> BoltzParameters:
        """Parameters of the query."""
        return self._parameters

    @staticmethod
    def _upload_custom_msa_files(
        source: str, headers: str | None = None
    ) -> dict[str, str]:
        """Reads MSA files from a file or directory and uploads them to GCS.

        Args:
            source (str): Path to an .a3m or .csv file, or a directory
                containing such files.
            headers (str | None, optional): GCP authentication headers.
                Fetched with ``get_auth_headers()`` when not provided.

        Raises:
            ValueError: If the file has an unsupported extension.
            ValueError: If a directory contains no .a3m or .csv files.
            ValueError: If the path is neither a file nor a directory.

        Returns:
            dict[str, str]: A mapping of uploaded file names to their GCS URLs.
        """
        headers = headers or get_auth_headers()
        source_path = validate_path(source)
        valid_extensions = {".a3m", ".csv"}  # Allow both a3m and csv files

        # Single file: it must itself carry a supported extension.
        if source_path.is_file():
            if source_path.suffix not in valid_extensions:
                raise ValueError(
                    f"Invalid file type: {source_path.suffix}. Expected one of {valid_extensions}."
                )
            return process_uploaded_msas([source_path], headers)

        # Directory: upload every direct child with a supported extension.
        if source_path.is_dir():
            valid_files = [
                file
                for file in source_path.iterdir()
                if file.suffix in valid_extensions
            ]
            if not valid_files:
                raise ValueError(
                    f"Directory '{source}' contains no valid files with extensions {valid_extensions}."
                )
            return process_uploaded_msas(valid_files, headers)

        # Fail loudly instead of implicitly returning None, which callers
        # would store as "no custom MSA" and silently ignore.
        raise ValueError(f"'{source}' is neither a file nor a directory.")