"""AF2 folding submission command.""" from pathlib import Path from typing import List, Optional import typer from folding_studio_data_models import ( FeatureMode, ) from folding_studio_data_models.request.folding import FoldingModel from typing_extensions import Annotated from folding_studio.api_call.predict import ( batch_prediction, batch_prediction_from_file, simple_prediction, ) from folding_studio.commands.predict.utils import ( print_instructions_batch, print_instructions_simple, validate_model_subset, validate_source_path, ) from folding_studio.config import FOLDING_API_KEY from folding_studio.console import console from folding_studio.utils.data_model import ( BatchInputFile, PredictRequestCustomFiles, PredictRequestParams, ) from folding_studio.utils.input_validation import ( extract_and_validate_custom_msas, extract_and_validate_custom_templates, validate_initial_guess, ) def af2( # pylint: disable=dangerous-default-value, too-many-arguments, too-many-locals source: Annotated[ Path, typer.Argument( help=( "Path to the data source. Either a fasta file, a directory of fasta files " "or a csv/json file describing a batch prediction request." ), callback=validate_source_path, exists=True, ), ], project_code: Annotated[ str, typer.Option( help=( "Project code. If unknown, contact your PM or the Folding Studio team." ), exists=True, envvar="FOLDING_PROJECT_CODE", ), ], cache: Annotated[ bool, typer.Option(help="Use cached experiment results if any."), ] = True, template_mode: Annotated[ FeatureMode, typer.Option(help="Mode of the template features generation."), ] = FeatureMode.SEARCH, custom_template: Annotated[ List[Path], typer.Option( help=( "Path to a custom template or a directory of custom templates. " "To pass multiple inputs, simply repeat the flag " "(e.g. `--custom_template template_1.cif --custom_template template_2.cif`)." ), callback=extract_and_validate_custom_templates, exists=True, ), ] = [], custom_template_id: Annotated[ List[str], typer.Option( help=( "ID of a custom template. " "To pass multiple inputs, simply repeat the flag " "(e.g. `--custom_template_id template_ID_1 --custom_template_id template_ID_2`)." ) ), ] = [], initial_guess_file: Annotated[ Path | None, typer.Option( help=("Path to an initial guess file."), callback=validate_initial_guess, exists=True, ), ] = None, templates_masks_file: Annotated[ Path | None, typer.Option( help=("Path to a templates masks file."), exists=True, ), ] = None, msa_mode: Annotated[ FeatureMode, typer.Option(help="Mode of the MSA features generation."), ] = FeatureMode.SEARCH, custom_msa: Annotated[ List[Path], typer.Option( help=( "Path to a custom msa or a directory of custom msas. " "To pass multiple inputs, simply repeat the flag " "(e.g. `--custom_msa msa_1.sto --custom_msa msa_2.sto`)." ), callback=extract_and_validate_custom_msas, exists=True, ), ] = [], max_msa_clusters: Annotated[ int, typer.Option(help="Max number of MSA clusters to search."), ] = -1, max_extra_msa: Annotated[ int, typer.Option( help="Max extra non-clustered MSA representation to use as source." ), ] = -1, gap_trick: Annotated[ bool, typer.Option( help="Activate gap trick, allowing to model complexes with monomer models." ), ] = False, num_recycle: Annotated[ int, typer.Option( help="Number of refinement iterations of the predicted structures." ), ] = 3, model_subset: Annotated[ list[int], typer.Option( help="Subset of AF2 model ids to use, between 1 and 5 included.", callback=validate_model_subset, ), ] = [], random_seed: Annotated[ int, typer.Option( help=( "Random seed used during the MSA sampling. " "Different random seed values will introduce variations in the predictions." ) ), ] = 0, num_seed: Annotated[ Optional[int], typer.Option( help="Number of random seeds to use. Creates a batch prediction.", min=2 ), ] = None, metadata_file: Annotated[ Optional[Path], typer.Option( help=( "Path to the file where the job metadata returned by the server are written." ), ), ] = None, ): """Asynchronous AF2 folding submission. Read more at https://int-bio-foldingstudio-gcp.nw.r.appspot.com/how-to-guides/af2_openfold/single_af2_job/. If the source is a CSV or JSON file describing a batch prediction request, all the other options will be overlooked. """ if FOLDING_API_KEY: console.print(":key: Using detected API key for authentication.") else: console.print(":yellow_circle: Using JWT for authentication.") is_batch = source.is_dir() or source.suffix in BatchInputFile.__members__.values() is_multi_seed = num_seed is not None is_batch = is_batch or is_multi_seed params = PredictRequestParams( ignore_cache=not cache, template_mode=template_mode, custom_template_ids=custom_template_id, msa_mode=msa_mode, max_msa_clusters=max_msa_clusters, max_extra_msa=max_extra_msa, gap_trick=gap_trick, num_recycle=num_recycle, random_seed=random_seed, model_subset=model_subset, ) custom_files = PredictRequestCustomFiles( templates=custom_template, msas=custom_msa, initial_guess_files=[initial_guess_file] if initial_guess_file else None, templates_masks_files=[templates_masks_file] if templates_masks_file else None, ) if is_batch: if is_multi_seed: response = batch_prediction( files=[source], folding_model=FoldingModel.AF2, params=params, custom_files=custom_files, num_seed=num_seed, project_code=project_code, ) elif source.is_file(): console.print( f"Submitting batch jobs configuration file [bold]{source}[/bold]" ) console.print( "Input options are [bold yellow]ignored[/bold yellow] in favor of the configuration file content." ) response = batch_prediction_from_file( file=source, project_code=project_code, ) elif source.is_dir(): response = batch_prediction( files=list(f for f in source.iterdir() if f.is_file()), folding_model=FoldingModel.AF2, params=params, custom_files=custom_files, num_seed=num_seed, project_code=project_code, ) print_instructions_batch(response_json=response, metadata_file=metadata_file) else: response = simple_prediction( file=source, folding_model=FoldingModel.AF2, params=params, custom_files=custom_files, project_code=project_code, ) print_instructions_simple(response_json=response, metadata_file=metadata_file)