"""Query module for Chai prediction endpoint."""

from __future__ import annotations

import shutil
import tempfile
from io import StringIO
from itertools import chain
from pathlib import Path
from typing import Any

from folding_studio_data_models import FoldingModel
from pydantic import BaseModel, field_validator

from folding_studio.commands.utils import (
    a3m_to_aligned_pqt,
    process_uploaded_msas,
)
from folding_studio.query import Query
from folding_studio.utils.fasta import validate_fasta
from folding_studio.utils.headers import get_auth_headers
from folding_studio.utils.path_helpers import validate_path


class ChaiParameters(BaseModel):
    """Chai1 inference parameters."""

    seed: int = 0
    num_trunk_recycles: int = 3
    num_diffn_timesteps: int = 200
    recycle_msa_subsample: int = 0
    num_trunk_samples: int = 1
    restraints: str | None = None
    use_msa_server: bool = False
    use_templates_server: bool = False
    custom_msa_paths: dict[str, str] | None = None

    @field_validator("restraints", mode="before")
    @classmethod
    def read_restraints(
        cls: type[ChaiParameters], restraints: str | Path | None
    ) -> str | None:
        """Read restraints from a CSV file and return its content as a string."""
        if restraints is None:
            return None
        path = validate_path(restraints, is_file=True, file_suffix=(".csv",))
        with path.open(newline="", encoding="utf-8") as csvfile:
            return csvfile.read().strip()
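
    # A minimal usage sketch (illustrative only; ``restraints.csv`` is a
    # hypothetical local file, not something shipped with this package).
    # Because of the validator above, the field ends up holding the CSV text
    # rather than the path:
    #
    #     params = ChaiParameters(seed=42, restraints="restraints.csv")
    #     assert isinstance(params.restraints, str)

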
class ChaiQuery(Query):
    """Chai1 model query."""

    MODEL = FoldingModel.CHAI

    def __init__(
        self,
        fasta_files: dict[str, str],
        query_name: str,
        parameters: ChaiParameters | None = None,
    ):
        """Initialize a ChaiQuery instance."""
        if not fasta_files:
            raise ValueError("FASTA files dictionary cannot be empty.")

        self.fasta_files = fasta_files
        self.query_name = query_name
        self._parameters = parameters or ChaiParameters()

    @classmethod
    def from_protein_sequence(
        cls: type[ChaiQuery], sequence: str, query_name: str | None = None, **kwargs
    ) -> ChaiQuery:
        """Initialize a ChaiQuery instance from a protein sequence string.

        Args:
            sequence (str): Protein amino-acid sequence.
            query_name (str | None, optional): User-defined query name. Defaults to None.

        Raises:
            NotAMonomer: If the sequence is not a monomer complex.

        Returns:
            ChaiQuery
        """
        record = validate_fasta(StringIO(sequence))

        custom_msa_paths = kwargs.pop("custom_msa_paths", None)
        if custom_msa_paths:
            kwargs["custom_msa_paths"] = cls._upload_custom_msa_files(custom_msa_paths)

        query_name = (
            query_name
            if query_name is not None
            else record.description.split("|", maxsplit=1)[0]
        )
        return cls(
            fasta_files={query_name: sequence},
            query_name=query_name,
            parameters=ChaiParameters(**kwargs),
        )
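
    # Example use of ``from_protein_sequence`` (illustrative only; a
    # single-record FASTA string with an arbitrary placeholder sequence is
    # assumed, and extra keyword arguments are forwarded to ``ChaiParameters``):
    #
    #     query = ChaiQuery.from_protein_sequence(
    #         ">demo\nMKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", query_name="demo", seed=42
    #     )
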
    @classmethod
    def from_file(
        cls: type[ChaiQuery], path: str | Path, query_name: str | None = None, **kwargs
    ) -> ChaiQuery:
        """Initialize a ChaiQuery instance from a file.

        Supported file formats are:
            - FASTA

        Args:
            path (str | Path): Path of the FASTA file.
            query_name (str | None, optional): User-defined query name. Defaults to None.

        Returns:
            ChaiQuery
        """
        path = validate_path(path, is_file=True, file_suffix=(".fasta", ".fa"))

        custom_msa_paths = kwargs.pop("custom_msa_paths", None)
        if custom_msa_paths:
            kwargs["custom_msa_paths"] = cls._upload_custom_msa_files(custom_msa_paths)

        query_name = query_name or path.stem
        return cls(
            fasta_files={path.stem: validate_fasta(path, str_output=True)},
            query_name=query_name,
            parameters=ChaiParameters(**kwargs),
        )
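
    # Example use of ``from_file`` (illustrative only; ``target.fasta`` is a
    # hypothetical local file):
    #
    #     query = ChaiQuery.from_file("target.fasta", use_msa_server=True)
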
    @classmethod
    def from_directory(
        cls: type[ChaiQuery], path: str | Path, query_name: str | None = None, **kwargs
    ) -> ChaiQuery:
        """Initialize a ChaiQuery instance from a directory.

        Supported file formats in the directory are:
            - FASTA

        Args:
            path (str | Path): Path to a directory of FASTA files.
            query_name (str | None, optional): User-defined query name. Defaults to None.
            **kwargs: Additional Chai inference parameters (e.g. ``seed``) forwarded
                to ``ChaiParameters``.

        Raises:
            ValueError: If no FASTA files are present in the directory.

        Returns:
            ChaiQuery
        """
        path = validate_path(path, is_dir=True)
        custom_msa_paths = kwargs.pop("custom_msa_paths", None)
        if custom_msa_paths:
            kwargs["custom_msa_paths"] = cls._upload_custom_msa_files(custom_msa_paths)
        fasta_files = {
            file.stem: validate_fasta(file, str_output=True)
            for file in chain(path.glob("*.fasta"), path.glob("*.fa"))
        }
        if not fasta_files:
            raise ValueError(f"No FASTA files found in directory '{path}'.")
        query_name = query_name or path.name
        return cls(
            fasta_files=fasta_files,
            query_name=query_name,
            parameters=ChaiParameters(**kwargs),
        )
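
    # Example use of ``from_directory`` (illustrative only; ``fastas/`` is a
    # hypothetical directory of FASTA files):
    #
    #     query = ChaiQuery.from_directory("fastas", query_name="batch", seed=7)
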
    @property
    def payload(self) -> dict[str, Any]:
        """Payload to send to the prediction API endpoint."""
        return {
            "fasta_files": self.fasta_files,
            "use_msa_server": self.parameters.use_msa_server,
            "use_templates_server": self.parameters.use_templates_server,
            "num_trunk_recycles": self.parameters.num_trunk_recycles,
            "seed": self.parameters.seed,
            "num_diffn_timesteps": self.parameters.num_diffn_timesteps,
            "restraints": self.parameters.restraints,
            "recycle_msa_subsample": self.parameters.recycle_msa_subsample,
            "num_trunk_samples": self.parameters.num_trunk_samples,
            "custom_msa_paths": self.parameters.custom_msa_paths,
        }

    @property
    def parameters(self) -> ChaiParameters:
        """Parameters of the query."""
        return self._parameters

    @staticmethod
    def _upload_custom_msa_files(
        source: str, headers: dict[str, str] | None = None
    ) -> dict[str, str]:
        """Read A3M or aligned.pqt MSA files from a file or directory and upload them to GCS.

        Args:
            source (str): Path to an .a3m or .aligned.pqt file, or to a directory
                containing .a3m or .aligned.pqt files.
            headers (dict[str, str] | None, optional): Authentication headers used for
                the upload. Defaults to None, in which case they are fetched with
                ``get_auth_headers``.

        Raises:
            ValueError: If the file has an unsupported extension.
            ValueError: If the directory contains no supported files.

        Returns:
            dict[str, str]: Mapping of the uploaded MSA files to their uploaded locations.
        """
        headers = headers or get_auth_headers()
        source_path = validate_path(source)

        if source_path.is_file():
            if source_path.suffix == ".a3m":
                with tempfile.TemporaryDirectory() as tmpdir:
                    tmp_path = Path(tmpdir)
                    shutil.copy(source_path, tmp_path / source_path.name)
                    pqt_file = a3m_to_aligned_pqt(str(tmp_path))
                    return process_uploaded_msas([Path(pqt_file)], headers)
            elif source_path.name.endswith(".aligned.pqt"):
                return process_uploaded_msas([source_path], headers)
            else:
                raise ValueError(
                    f"Invalid file type: {source_path.suffix}. Expected '.a3m' or a file ending with '.aligned.pqt'."
                )

        elif source_path.is_dir():
            pqt_files = list(source_path.glob("*.aligned.pqt"))
            if pqt_files:
                return process_uploaded_msas(pqt_files, headers)

            a3m_files = list(source_path.glob("*.a3m"))
            if not a3m_files:
                raise ValueError(
                    f"Directory '{source}' contains no files ending with '.aligned.pqt' or '.a3m'."
                )
            with tempfile.TemporaryDirectory() as tmpdir:
                tmp_path = Path(tmpdir)
                for file in a3m_files:
                    shutil.copy(file, tmp_path / file.name)
                pqt_file = a3m_to_aligned_pqt(str(tmp_path))
                return process_uploaded_msas([Path(pqt_file)], headers)
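

# End-to-end sketch (illustrative only): the paths below are hypothetical and
# valid Folding Studio credentials are assumed. ``custom_msa_paths`` points at
# local .a3m or .aligned.pqt files, which ``_upload_custom_msa_files`` converts
# (if needed) and uploads before the query is built.
#
#     query = ChaiQuery.from_file(
#         "target.fasta", seed=42, custom_msa_paths="msas/"
#     )
#     payload = query.payload  # dict sent to the Chai prediction endpoint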