jfaustin's picture
add dockerfile and folding studio cli
44459bb
"""Query module for Chai prediction endpoint."""
from __future__ import annotations
import shutil
import tempfile
from io import StringIO
from itertools import chain
from pathlib import Path
from typing import Any
from folding_studio_data_models import FoldingModel
from pydantic import BaseModel, field_validator
from folding_studio.commands.utils import (
a3m_to_aligned_pqt,
process_uploaded_msas,
)
from folding_studio.query import Query
from folding_studio.utils.fasta import validate_fasta
from folding_studio.utils.headers import get_auth_headers
from folding_studio.utils.path_helpers import validate_path
class ChaiParameters(BaseModel):
    """Chai1 inference parameters.

    Every field is forwarded as-is in the prediction payload built by
    ``ChaiQuery.payload``.
    """

    seed: int = 0
    num_trunk_recycles: int = 3
    num_diffn_timesteps: int = 200
    recycle_msa_subsample: int = 0
    num_trunk_samples: int = 1
    # On input: a path to a ``.csv`` restraints file. The ``read_restraints``
    # validator replaces it with the file's stripped text content.
    restraints: str | None = None
    use_msa_server: bool = False
    use_templates_server: bool = False
    # Mapping returned by ``ChaiQuery._upload_custom_msa_files`` (i.e. by
    # ``process_uploaded_msas``); exact key/value meaning — TODO confirm there.
    custom_msa_paths: dict[str, str] | None = None

    @field_validator("restraints", mode="before")
    @classmethod
    def read_restraints(
        cls: type[ChaiParameters], restraints: str | Path | None
    ) -> str | None:
        """Read a restraints CSV file and return its content as a string.

        Args:
            restraints (str | Path | None): Path to a ``.csv`` restraints file,
                or None when no restraints are used.

        Returns:
            str | None: The stripped text content of the file, or None when
            ``restraints`` is None.

        Raises:
            Any error raised by ``validate_path`` (presumably for a missing
            path, a non-file, or a wrong suffix — confirm in that helper).
        """
        if restraints is None:
            return None
        # Must be a one-element tuple: ``(".csv")`` is just the string ".csv",
        # which turns a membership-style suffix check into a substring test.
        path = validate_path(restraints, is_file=True, file_suffix=(".csv",))
        with path.open(newline="", encoding="utf-8") as csvfile:
            return csvfile.read().strip()
class ChaiQuery(Query):
    """Chai1 model query.

    Holds one or more FASTA contents keyed by name, a query name, and the
    ``ChaiParameters`` used to build the prediction API payload.
    """

    MODEL = FoldingModel.CHAI

    def __init__(
        self,
        fasta_files: dict[str, str],
        query_name: str,
        parameters: ChaiParameters | None = None,
    ):
        """Initialize a ChaiQuery instance.

        Args:
            fasta_files (dict[str, str]): Mapping of name to FASTA content.
            query_name (str): User-defined query name.
            parameters (ChaiParameters | None, optional): Inference parameters.
                Defaults to a fresh ``ChaiParameters()`` for each instance.

        Raises:
            ValueError: If ``fasta_files`` is empty.
        """
        if not fasta_files:
            raise ValueError("FASTA files dictionary cannot be empty.")
        self.fasta_files = fasta_files
        self.query_name = query_name
        # Build the default per instance: a ``ChaiParameters()`` default
        # argument is evaluated once and would be one mutable model shared by
        # every query.
        self._parameters = (
            parameters if parameters is not None else ChaiParameters()
        )

    @classmethod
    def _build_parameters(cls, kwargs: dict[str, Any]) -> ChaiParameters:
        """Build ``ChaiParameters`` from keyword arguments.

        A truthy ``custom_msa_paths`` entry (file or directory path) is replaced
        by the mapping returned by ``_upload_custom_msa_files``. A falsy entry
        is dropped so the field keeps its default.
        """
        params = dict(kwargs)
        custom_msa_paths = params.pop("custom_msa_paths", None)
        if custom_msa_paths:
            params["custom_msa_paths"] = cls._upload_custom_msa_files(
                custom_msa_paths
            )
        return ChaiParameters(**params)

    @classmethod
    def from_protein_sequence(
        cls: type[ChaiQuery],
        sequence: str,
        query_name: str | None = None,
        **kwargs,
    ) -> ChaiQuery:
        """Initialize a ChaiQuery instance from a str protein sequence.

        Args:
            sequence (str): Protein amino-acid sequence (FASTA formatted).
            query_name (str | None, optional): User-defined query name.
                Defaults to the first '|'-separated tag of the FASTA
                description.
            **kwargs: Forwarded to ``ChaiParameters``; ``custom_msa_paths`` is
                uploaded first.

        Raises:
            NotAMonomer: If the sequence is not a monomer complex.

        Returns:
            ChaiQuery
        """
        record = validate_fasta(StringIO(sequence))
        parameters = cls._build_parameters(kwargs)
        if query_name is None:
            # First tag of the FASTA description line.
            query_name = record.description.split("|", maxsplit=1)[0]
        return cls(
            fasta_files={query_name: sequence},
            query_name=query_name,
            parameters=parameters,
        )

    @classmethod
    def from_file(
        cls: type[ChaiQuery],
        path: str | Path,
        query_name: str | None = None,
        **kwargs,
    ) -> ChaiQuery:
        """Initialize a ChaiQuery instance from a file.

        Supported file format are:
            - FASTA

        Args:
            path (str | Path): Path of the FASTA file (``.fasta`` or ``.fa``).
            query_name (str | None, optional): User-defined query name.
                Defaults to the file stem.
            **kwargs: Forwarded to ``ChaiParameters``; ``custom_msa_paths`` is
                uploaded first.

        Returns:
            ChaiQuery
        """
        path = validate_path(path, is_file=True, file_suffix=(".fasta", ".fa"))
        # Validate the FASTA before any MSA upload so invalid input fails
        # without uploading anything.
        fasta_content = validate_fasta(path, str_output=True)
        parameters = cls._build_parameters(kwargs)
        return cls(
            fasta_files={path.stem: fasta_content},
            query_name=query_name or path.stem,
            parameters=parameters,
        )

    @classmethod
    def from_directory(
        cls: type[ChaiQuery],
        path: str | Path,
        query_name: str | None = None,
        **kwargs,
    ) -> ChaiQuery:
        """Initialize a ChaiQuery instance from a directory.

        Supported file format in directory are:
            - FASTA (``*.fasta`` and ``*.fa``, non-recursive)

        Args:
            path (str | Path): Path to a directory of FASTA files.
            query_name (str | None, optional): User-defined query name.
                Defaults to the directory name.
            **kwargs: Forwarded to ``ChaiParameters`` (e.g. ``seed``);
                ``custom_msa_paths`` is uploaded first.

        Raises:
            ValueError: If no FASTA file are present in the directory.

        Returns:
            ChaiQuery
        """
        path = validate_path(path, is_dir=True)
        fasta_files = {
            file.stem: validate_fasta(file, str_output=True)
            for file in chain(path.glob("*.fasta"), path.glob("*.fa"))
        }
        if not fasta_files:
            raise ValueError(f"No FASTA files found in directory '{path}'.")
        # Upload MSAs only once the FASTA input is known to be valid.
        parameters = cls._build_parameters(kwargs)
        return cls(
            fasta_files=fasta_files,
            query_name=query_name or path.name,
            parameters=parameters,
        )

    @property
    def payload(self) -> dict[str, Any]:
        """Payload to send to the prediction API endpoint."""
        return {
            "fasta_files": self.fasta_files,
            "use_msa_server": self.parameters.use_msa_server,
            "use_templates_server": self.parameters.use_templates_server,
            "num_trunk_recycles": self.parameters.num_trunk_recycles,
            "seed": self.parameters.seed,
            "num_diffn_timesteps": self.parameters.num_diffn_timesteps,
            "restraints": self.parameters.restraints,
            "recycle_msa_subsample": self.parameters.recycle_msa_subsample,
            "num_trunk_samples": self.parameters.num_trunk_samples,
            "custom_msa_paths": self.parameters.custom_msa_paths,
        }

    @property
    def parameters(self) -> ChaiParameters:
        """Parameters of the query."""
        return self._parameters

    @staticmethod
    def _upload_a3m_files(
        a3m_files: list[Path], headers: str | None
    ) -> dict[str, str]:
        """Convert A3M files to an aligned parquet file and upload it."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_path = Path(tmpdir)
            # a3m_to_aligned_pqt is given a directory, so stage the inputs in
            # an isolated temporary one.
            for file in a3m_files:
                shutil.copy(file, tmp_path / file.name)
            pqt_file = a3m_to_aligned_pqt(str(tmp_path))
            # Upload before leaving the ``with``: the converted file may live
            # inside the temporary directory, which is removed on exit.
            return process_uploaded_msas([Path(pqt_file)], headers)

    @staticmethod
    def _upload_custom_msa_files(
        source: str | Path, headers: str | None = None
    ) -> dict[str, str]:
        """Read A3M or aligned-parquet MSA files and upload them to GCS.

        Args:
            source (str | Path): Path to an ``.a3m`` or ``.aligned.pqt`` file,
                or a directory containing such files. In a directory,
                ``.aligned.pqt`` files take precedence; ``.a3m`` files are only
                used when no ``.aligned.pqt`` file is present.
            headers (str | None, optional): Authentication headers forwarded
                to ``process_uploaded_msas``. Defaults to
                ``get_auth_headers()``.

        Raises:
            ValueError: If file has unsupported extension.
            ValueError: If directory has no supported file.
            ValueError: If ``source`` is neither a file nor a directory.

        Returns:
            dict[str, str]: The mapping returned by ``process_uploaded_msas``.
        """
        headers = headers or get_auth_headers()
        source_path = validate_path(source)

        if source_path.is_file():
            if source_path.suffix == ".a3m":
                return ChaiQuery._upload_a3m_files([source_path], headers)
            if source_path.name.endswith(".aligned.pqt"):
                return process_uploaded_msas([source_path], headers)
            raise ValueError(
                f"Invalid file type: {source_path.suffix}. Expected '.a3m' or a file ending with '.aligned.pqt'."
            )

        if source_path.is_dir():
            pqt_files = list(source_path.glob("*.aligned.pqt"))
            if pqt_files:
                return process_uploaded_msas(pqt_files, headers)
            a3m_files = list(source_path.glob("*.a3m"))
            if not a3m_files:
                raise ValueError(
                    f"Directory '{source}' contains no files ending with '.aligned.pqt' or '.a3m'."
                )
            return ChaiQuery._upload_a3m_files(a3m_files, headers)

        # Previously this fell through and silently returned None.
        raise ValueError(f"'{source}' is neither a file nor a directory.")