# Imported from commit 44459bb ("add dockerfile and folding studio cli", author: jfaustin).
"""AF2 folding submission command."""
from pathlib import Path
from typing import List, Optional
import typer
from folding_studio_data_models import (
FeatureMode,
)
from folding_studio_data_models.request.folding import FoldingModel
from typing_extensions import Annotated
from folding_studio.api_call.predict import (
batch_prediction,
batch_prediction_from_file,
simple_prediction,
)
from folding_studio.commands.predict.utils import (
print_instructions_batch,
print_instructions_simple,
validate_model_subset,
validate_source_path,
)
from folding_studio.config import FOLDING_API_KEY
from folding_studio.console import console
from folding_studio.utils.data_model import (
BatchInputFile,
PredictRequestCustomFiles,
PredictRequestParams,
)
from folding_studio.utils.input_validation import (
extract_and_validate_custom_msas,
extract_and_validate_custom_templates,
validate_initial_guess,
)
def af2( # pylint: disable=dangerous-default-value, too-many-arguments, too-many-locals
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a fasta file, a directory of fasta files "
                "or a csv/json file describing a batch prediction request."
            ),
            callback=validate_source_path,
            exists=True,
        ),
    ],
    project_code: Annotated[
        str,
        typer.Option(
            help=(
                "Project code. If unknown, contact your PM or the Folding Studio team."
            ),
            # NOTE: `exists=True` removed here — that typer parameter only applies
            # to Path-typed options, not strings.
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    cache: Annotated[
        bool,
        typer.Option(help="Use cached experiment results if any."),
    ] = True,
    template_mode: Annotated[
        FeatureMode,
        typer.Option(help="Mode of the template features generation."),
    ] = FeatureMode.SEARCH,
    custom_template: Annotated[
        List[Path],
        typer.Option(
            help=(
                "Path to a custom template or a directory of custom templates. "
                "To pass multiple inputs, simply repeat the flag "
                "(e.g. `--custom_template template_1.cif --custom_template template_2.cif`)."
            ),
            callback=extract_and_validate_custom_templates,
            exists=True,
        ),
    ] = [],
    custom_template_id: Annotated[
        List[str],
        typer.Option(
            help=(
                "ID of a custom template. "
                "To pass multiple inputs, simply repeat the flag "
                "(e.g. `--custom_template_id template_ID_1 --custom_template_id template_ID_2`)."
            )
        ),
    ] = [],
    initial_guess_file: Annotated[
        Path | None,
        typer.Option(
            help=("Path to an initial guess file."),
            callback=validate_initial_guess,
            exists=True,
        ),
    ] = None,
    templates_masks_file: Annotated[
        Path | None,
        typer.Option(
            help=("Path to a templates masks file."),
            exists=True,
        ),
    ] = None,
    msa_mode: Annotated[
        FeatureMode,
        typer.Option(help="Mode of the MSA features generation."),
    ] = FeatureMode.SEARCH,
    custom_msa: Annotated[
        List[Path],
        typer.Option(
            help=(
                "Path to a custom msa or a directory of custom msas. "
                "To pass multiple inputs, simply repeat the flag "
                "(e.g. `--custom_msa msa_1.sto --custom_msa msa_2.sto`)."
            ),
            callback=extract_and_validate_custom_msas,
            exists=True,
        ),
    ] = [],
    max_msa_clusters: Annotated[
        int,
        typer.Option(help="Max number of MSA clusters to search."),
    ] = -1,
    max_extra_msa: Annotated[
        int,
        typer.Option(
            help="Max extra non-clustered MSA representation to use as source."
        ),
    ] = -1,
    gap_trick: Annotated[
        bool,
        typer.Option(
            help="Activate gap trick, allowing to model complexes with monomer models."
        ),
    ] = False,
    num_recycle: Annotated[
        int,
        typer.Option(
            help="Number of refinement iterations of the predicted structures."
        ),
    ] = 3,
    model_subset: Annotated[
        list[int],
        typer.Option(
            help="Subset of AF2 model ids to use, between 1 and 5 included.",
            callback=validate_model_subset,
        ),
    ] = [],
    random_seed: Annotated[
        int,
        typer.Option(
            help=(
                "Random seed used during the MSA sampling. "
                "Different random seed values will introduce variations in the predictions."
            )
        ),
    ] = 0,
    num_seed: Annotated[
        Optional[int],
        typer.Option(
            help="Number of random seeds to use. Creates a batch prediction.", min=2
        ),
    ] = None,
    metadata_file: Annotated[
        Optional[Path],
        typer.Option(
            help=(
                "Path to the file where the job metadata returned by the server are written."
            ),
        ),
    ] = None,
):
    """Asynchronous AF2 folding submission.

    Read more at https://int-bio-foldingstudio-gcp.nw.r.appspot.com/how-to-guides/af2_openfold/single_af2_job/.

    If the source is a CSV or JSON file describing a batch prediction request, all the other
    options will be overlooked.

    Dispatch rules:
        * batch configuration file (csv/json) -> ``batch_prediction_from_file``
          (takes precedence over every other option, including ``--num_seed``);
        * directory of fasta files, or a single file with ``--num_seed`` ->
          ``batch_prediction``;
        * otherwise -> ``simple_prediction``.
    """
    if FOLDING_API_KEY:
        console.print(":key: Using detected API key for authentication.")
    else:
        console.print(":yellow_circle: Using JWT for authentication.")

    # A suffix matching a registered batch input format marks a batch
    # configuration file (csv/json) rather than a raw fasta source.
    is_config_file = source.suffix in BatchInputFile.__members__.values()
    is_multi_seed = num_seed is not None
    # Batch mode: directory of sources, a batch configuration file, or a
    # single file fanned out over several random seeds.
    is_batch = source.is_dir() or is_config_file or is_multi_seed

    params = PredictRequestParams(
        ignore_cache=not cache,
        template_mode=template_mode,
        custom_template_ids=custom_template_id,
        msa_mode=msa_mode,
        max_msa_clusters=max_msa_clusters,
        max_extra_msa=max_extra_msa,
        gap_trick=gap_trick,
        num_recycle=num_recycle,
        random_seed=random_seed,
        model_subset=model_subset,
    )
    custom_files = PredictRequestCustomFiles(
        templates=custom_template,
        msas=custom_msa,
        initial_guess_files=[initial_guess_file] if initial_guess_file else None,
        templates_masks_files=[templates_masks_file] if templates_masks_file else None,
    )

    if is_batch:
        if is_config_file:
            # Honor the documented contract: a configuration file overrides
            # all other input options (previously `--num_seed` short-circuited
            # this branch and submitted the config file as a plain source).
            console.print(
                f"Submitting batch jobs configuration file [bold]{source}[/bold]"
            )
            console.print(
                "Input options are [bold yellow]ignored[/bold yellow] in favor of the configuration file content."
            )
            response = batch_prediction_from_file(
                file=source,
                project_code=project_code,
            )
        elif source.is_dir():
            # Submit every regular file in the directory. Checked before the
            # multi-seed case so a directory combined with `--num_seed` is
            # expanded into its files instead of being passed as one "file".
            response = batch_prediction(
                files=[f for f in source.iterdir() if f.is_file()],
                folding_model=FoldingModel.AF2,
                params=params,
                custom_files=custom_files,
                num_seed=num_seed,
                project_code=project_code,
            )
        else:
            # Single source file fanned out over `num_seed` random seeds.
            response = batch_prediction(
                files=[source],
                folding_model=FoldingModel.AF2,
                params=params,
                custom_files=custom_files,
                num_seed=num_seed,
                project_code=project_code,
            )
        print_instructions_batch(response_json=response, metadata_file=metadata_file)
    else:
        response = simple_prediction(
            file=source,
            folding_model=FoldingModel.AF2,
            params=params,
            custom_files=custom_files,
            project_code=project_code,
        )
        print_instructions_simple(response_json=response, metadata_file=metadata_file)