|
"""AF2 folding submission command.""" |
|
|
|
from pathlib import Path |
|
from typing import List, Optional |
|
|
|
import typer |
|
from folding_studio_data_models import ( |
|
FeatureMode, |
|
) |
|
from folding_studio_data_models.request.folding import FoldingModel |
|
from typing_extensions import Annotated |
|
|
|
from folding_studio.api_call.predict import ( |
|
batch_prediction, |
|
batch_prediction_from_file, |
|
simple_prediction, |
|
) |
|
from folding_studio.commands.predict.utils import ( |
|
print_instructions_batch, |
|
print_instructions_simple, |
|
validate_model_subset, |
|
validate_source_path, |
|
) |
|
from folding_studio.config import FOLDING_API_KEY |
|
from folding_studio.console import console |
|
from folding_studio.utils.data_model import ( |
|
BatchInputFile, |
|
PredictRequestCustomFiles, |
|
PredictRequestParams, |
|
) |
|
from folding_studio.utils.input_validation import ( |
|
extract_and_validate_custom_msas, |
|
extract_and_validate_custom_templates, |
|
validate_initial_guess, |
|
) |
|
|
|
|
|
def af2(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a fasta file, a directory of fasta files "
                "or a csv/json file describing a batch prediction request."
            ),
            callback=validate_source_path,
            exists=True,
        ),
    ],
    project_code: Annotated[
        str,
        typer.Option(
            help=(
                "Project code. If unknown, contact your PM or the Folding Studio team."
            ),
            # NOTE(review): dropped the previous `exists=True` here — it is a
            # Path-validation flag and has no effect on a `str` option.
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    cache: Annotated[
        bool,
        typer.Option(help="Use cached experiment results if any."),
    ] = True,
    template_mode: Annotated[
        FeatureMode,
        typer.Option(help="Mode of the template features generation."),
    ] = FeatureMode.SEARCH,
    custom_template: Annotated[
        List[Path],
        typer.Option(
            help=(
                "Path to a custom template or a directory of custom templates. "
                "To pass multiple inputs, simply repeat the flag "
                "(e.g. `--custom_template template_1.cif --custom_template template_2.cif`)."
            ),
            callback=extract_and_validate_custom_templates,
            exists=True,
        ),
    ] = [],
    custom_template_id: Annotated[
        List[str],
        typer.Option(
            help=(
                "ID of a custom template. "
                "To pass multiple inputs, simply repeat the flag "
                "(e.g. `--custom_template_id template_ID_1 --custom_template_id template_ID_2`)."
            )
        ),
    ] = [],
    initial_guess_file: Annotated[
        Path | None,
        typer.Option(
            help="Path to an initial guess file.",
            callback=validate_initial_guess,
            exists=True,
        ),
    ] = None,
    templates_masks_file: Annotated[
        Path | None,
        typer.Option(
            help="Path to a templates masks file.",
            exists=True,
        ),
    ] = None,
    msa_mode: Annotated[
        FeatureMode,
        typer.Option(help="Mode of the MSA features generation."),
    ] = FeatureMode.SEARCH,
    custom_msa: Annotated[
        List[Path],
        typer.Option(
            help=(
                "Path to a custom msa or a directory of custom msas. "
                "To pass multiple inputs, simply repeat the flag "
                "(e.g. `--custom_msa msa_1.sto --custom_msa msa_2.sto`)."
            ),
            callback=extract_and_validate_custom_msas,
            exists=True,
        ),
    ] = [],
    max_msa_clusters: Annotated[
        int,
        typer.Option(help="Max number of MSA clusters to search."),
    ] = -1,
    max_extra_msa: Annotated[
        int,
        typer.Option(
            help="Max extra non-clustered MSA representation to use as source."
        ),
    ] = -1,
    gap_trick: Annotated[
        bool,
        typer.Option(
            help="Activate gap trick, allowing to model complexes with monomer models."
        ),
    ] = False,
    num_recycle: Annotated[
        int,
        typer.Option(
            help="Number of refinement iterations of the predicted structures."
        ),
    ] = 3,
    model_subset: Annotated[
        List[int],
        typer.Option(
            help="Subset of AF2 model ids to use, between 1 and 5 included.",
            callback=validate_model_subset,
        ),
    ] = [],
    random_seed: Annotated[
        int,
        typer.Option(
            help=(
                "Random seed used during the MSA sampling. "
                "Different random seed values will introduce variations in the predictions."
            )
        ),
    ] = 0,
    num_seed: Annotated[
        Optional[int],
        typer.Option(
            help="Number of random seeds to use. Creates a batch prediction.", min=2
        ),
    ] = None,
    metadata_file: Annotated[
        Optional[Path],
        typer.Option(
            help=(
                "Path to the file where the job metadata returned by the server are written."
            ),
        ),
    ] = None,
):
    """Asynchronous AF2 folding submission.

    Read more at https://int-bio-foldingstudio-gcp.nw.r.appspot.com/how-to-guides/af2_openfold/single_af2_job/.

    If the source is a CSV or JSON file describing a batch prediction request, all the other
    options will be overlooked.
    """
    # Tell the user which authentication scheme is in effect for this call.
    if FOLDING_API_KEY:
        console.print(":key: Using detected API key for authentication.")
    else:
        console.print(":yellow_circle: Using JWT for authentication.")

    # Batch mode is triggered by a directory of fasta files, a csv/json batch
    # configuration file, or an explicit multi-seed request.
    is_batch = source.is_dir() or source.suffix in BatchInputFile.__members__.values()
    is_multi_seed = num_seed is not None
    is_batch = is_batch or is_multi_seed

    # Common prediction parameters shared by the simple and batch paths.
    params = PredictRequestParams(
        ignore_cache=not cache,
        template_mode=template_mode,
        custom_template_ids=custom_template_id,
        msa_mode=msa_mode,
        max_msa_clusters=max_msa_clusters,
        max_extra_msa=max_extra_msa,
        gap_trick=gap_trick,
        num_recycle=num_recycle,
        random_seed=random_seed,
        model_subset=model_subset,
    )

    # Optional user-supplied files uploaded alongside the request.
    custom_files = PredictRequestCustomFiles(
        templates=custom_template,
        msas=custom_msa,
        initial_guess_files=[initial_guess_file] if initial_guess_file else None,
        templates_masks_files=[templates_masks_file] if templates_masks_file else None,
    )

    if is_batch:
        if is_multi_seed:
            # Fix: expand a directory source into its files so that a
            # multi-seed submission behaves like the plain directory batch
            # below (previously the directory path itself was submitted as a
            # single "file").
            files = (
                [f for f in source.iterdir() if f.is_file()]
                if source.is_dir()
                else [source]
            )
            response = batch_prediction(
                files=files,
                folding_model=FoldingModel.AF2,
                params=params,
                custom_files=custom_files,
                num_seed=num_seed,
                project_code=project_code,
            )
        elif source.is_file():
            # csv/json batch configuration file: the file content fully
            # describes the jobs, so the CLI options are ignored.
            console.print(
                f"Submitting batch jobs configuration file [bold]{source}[/bold]"
            )
            console.print(
                "Input options are [bold yellow]ignored[/bold yellow] in favor of the configuration file content."
            )
            response = batch_prediction_from_file(
                file=source,
                project_code=project_code,
            )
        elif source.is_dir():
            # Directory of fasta files: one job per file.
            response = batch_prediction(
                files=[f for f in source.iterdir() if f.is_file()],
                folding_model=FoldingModel.AF2,
                params=params,
                custom_files=custom_files,
                num_seed=num_seed,
                project_code=project_code,
            )
        print_instructions_batch(response_json=response, metadata_file=metadata_file)
    else:
        # Single fasta file: submit one simple prediction job.
        response = simple_prediction(
            file=source,
            folding_model=FoldingModel.AF2,
            params=params,
            custom_files=custom_files,
            project_code=project_code,
        )
        print_instructions_simple(response_json=response, metadata_file=metadata_file)
|
|