"""CLI MSA search command and sub-commands."""

import itertools
import json
import shutil
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Optional

import requests
import typer
from folding_studio_data_models import (
    FeatureMode,
    MessageStatus,
    MSAPublication,
)
from rich import print  # pylint:disable=redefined-builtin
from rich.console import Console
from rich.markdown import Markdown
from rich.table import Table
from typing_extensions import Annotated

from folding_studio.api_call.msa import simple_msa
from folding_studio.config import API_URL, REQUEST_TIMEOUT
from folding_studio.utils.data_model import (
    MSARequestParams,
    SimpleInputFile,
)
from folding_studio.utils.headers import get_auth_headers

app = typer.Typer(help="Handle MSA operation")
msa_experiment_app = typer.Typer(help="Commands related to MSA experiments.")
app.add_typer(msa_experiment_app, name="experiment")

msa_experiment_ID_argument = typer.Argument(help="ID of the MSA experiment.")


def _validate_source_path(path: Path) -> Path:
    """Validate the msa job input source path.

    Args:
        path (Path): Source path.

    Raises:
        typer.BadParameter: If the source is an unsupported file.

    Returns:
        Path: The source.
    """
    # Allowed suffixes come from the SimpleInputFile enum values.
    supported_simple_msa = tuple(item.value for item in SimpleInputFile)
    if path.suffix not in supported_simple_msa:
        raise typer.BadParameter(
            f"The source file '{path}' is not supported. "
            f"Only {supported_simple_msa} files are supported."
        )
    return path


def _print_shell_command(console: Console, command: str) -> None:
    """Render a single shell command as a fenced Markdown code block.

    Args:
        console (Console): Rich console used for rendering.
        command (str): Shell command to display.
    """
    console.print(Markdown(f"```shell\n{command}\n```"))


def _print_instructions_simple(response_json: dict, metadata_file: Path | None) -> None:
    """Print pretty instructions after a successful call to the MSA search endpoint.

    Args:
        response_json (dict): Server json response, validated as an `MSAPublication`.
        metadata_file (Path | None): File path where job submission metadata are
            written. When None, a timestamped file in the current directory is used.

    Raises:
        ValueError: If the publication status is not one of the handled values.
    """
    pub = MSAPublication.model_validate(response_json)
    msa_experiment_id = pub.message.msa_experiment_id
    console = Console()

    if pub.status == MessageStatus.NOT_PUBLISHED_DONE:
        print(
            f"The results of your msa_experiment {msa_experiment_id} were found in the cache."
        )
        print("Use the following command to download the msa results:")
        _print_shell_command(
            console, f"folding msa experiment features {msa_experiment_id}"
        )
    elif pub.status == MessageStatus.NOT_PUBLISHED_PENDING:
        print(
            f"Your msa_experiment [bold]{msa_experiment_id}[/bold] is [bold green]still running.[/bold green]"
        )
        print("Use the following command to check on its status at a later time.")
        _print_shell_command(
            console, f"folding msa experiment status {msa_experiment_id}"
        )
    elif pub.status == MessageStatus.PUBLISHED:
        print("[bold green]Experiment submitted successfully ![/bold green]")
        print(f"The msa_experiment_id is [bold]{msa_experiment_id}[/bold]")

        # NOTE(review): metadata is only persisted for newly published jobs;
        # cached / pending responses are not written even when --metadata-file
        # is given. Confirm this is intended.
        if not metadata_file:
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            metadata_file = Path(f"simple_prediction_{timestamp}.json")
        with open(metadata_file, "w", encoding="utf-8") as f:
            json.dump(response_json, f, indent=4)
        print(f"Prediction job metadata written to [bold]{metadata_file}[/bold]")

        print("You can query your experiment status with the command:")
        _print_shell_command(
            console, f"folding msa experiment status {msa_experiment_id}"
        )
    else:
        raise ValueError(f"Unknown publication status: {pub.status}")


@app.command()
def search(  # pylint: disable=dangerous-default-value, too-many-arguments, too-many-locals
    source: Annotated[
        Path,
        typer.Argument(
            help=("Path to the input fasta file."),
            callback=_validate_source_path,
            exists=True,
        ),
    ],
    project_code: Annotated[
        str,
        typer.Option(
            help=(
                "Project code. If unknown, contact your PM or the Folding Studio team."
            ),
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    cache: Annotated[
        bool,
        typer.Option(help="Use cached experiment results if any."),
    ] = True,
    msa_mode: Annotated[
        FeatureMode,
        typer.Option(help="Mode of the MSA features generation."),
    ] = FeatureMode.SEARCH,
    metadata_file: Annotated[
        Optional[Path],
        typer.Option(
            help=(
                "Path to the file where the job metadata returned by the server are written."
            ),
        ),
    ] = None,
):
    """Run an MSA tool. Read more at https://int-bio-foldingstudio-gcp.nw.r.appspot.com/tutorials/msa_search/."""
    # The CLI exposes "use cache" while the API expects "ignore cache".
    params = MSARequestParams(
        ignore_cache=not cache,
        msa_mode=msa_mode,
    )
    response = simple_msa(
        file=source,
        params=params,
        project_code=project_code,
    )
    _print_instructions_simple(response_json=response, metadata_file=metadata_file)


def _download_file_from_signed_url(
    msa_exp_id: str,
    endpoint: str,
    output: Path,
    force: bool,
    unzip: bool = False,
) -> None:
    """Download a file for an experiment id through a server-provided signed URL.

    The API endpoint is first queried for a signed URL, which is then streamed
    to `output`. Optionally, a `.zip` download is extracted next to it.

    Args:
        msa_exp_id (str): MSA Experiment id.
        endpoint (str): API endpoint to call (appended to API_URL).
        output (Path): Output file path.
        force (bool): Force file writing if it already exists.
        unzip (bool): Unzip the zip file after downloading.

    Raises:
        typer.Exit: If output file path exists but force set to false.
        typer.Exit: If unzip is set but the output is not a `.zip` file.
        typer.Exit: If unzip set to true but the directory already exists and force set to false.
        typer.Exit: If the API call fails or returns no signed URL.
        typer.Exit: If downloading from the signed URL fails.
    """
    # Extraction target: the output path without its ".zip" suffix.
    dir_path = output.with_suffix("")

    if output.exists() and not force:
        print(
            f"Warning: The file '{output}' already exists. Use the --force flag to overwrite it."
        )
        raise typer.Exit(code=1)

    if unzip:
        if output.suffix != ".zip":
            print(
                "Error: The downloaded file is not a .zip file. Please ensure the correct file format."
            )
            raise typer.Exit(code=1)
        if dir_path.exists() and not force:
            print(
                f"Warning: The --unzip flag is raised but the directory '{dir_path}' "
                "already exists. Use the --force flag to overwrite it."
            )
            raise typer.Exit(code=1)

    response = requests.get(
        API_URL + endpoint,
        params={"msa_experiment_id": msa_exp_id},
        headers=get_auth_headers(),
        timeout=REQUEST_TIMEOUT,
    )
    if not response.ok:
        print(f"Failed to download the file: {response.text}.")
        raise typer.Exit(code=1)

    try:
        signed_url = response.json()["signed_url"]
    except (ValueError, KeyError):
        # Non-JSON body (ValueError) or JSON without the expected key.
        print(f"Failed to download the file: unexpected server response {response.text}.")
        raise typer.Exit(code=1) from None

    # Close the streamed connection deterministically and refuse to write an
    # HTTP error body to disk as if it were the requested file.
    with requests.get(signed_url, stream=True, timeout=REQUEST_TIMEOUT) as file_response:
        if not file_response.ok:
            print(f"Failed to download the file: {file_response.text}.")
            raise typer.Exit(code=1)
        # Let urllib3 undo any Content-Encoding (e.g. gzip) while copying raw bytes.
        file_response.raw.decode_content = True
        output.parent.mkdir(parents=True, exist_ok=True)
        with output.open("wb") as f:
            shutil.copyfileobj(file_response.raw, f)
    print(f"File downloaded successfully to {output}.")

    if unzip:
        dir_path.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(output, "r") as zip_ref:
            zip_ref.extractall(dir_path)
        print(f"Extracted all files to {dir_path}.")


@msa_experiment_app.command()
def status(
    msa_exp_id: Annotated[str, msa_experiment_ID_argument],
):
    """Get an MSA experiment status."""
    response = requests.get(
        API_URL + "getMSAExperimentStatus",
        params={"msa_experiment_id": msa_exp_id},
        headers=get_auth_headers(),
        timeout=REQUEST_TIMEOUT,
    )
    if not response.ok:
        print(f"An error occurred : {response.text}")
        raise typer.Exit(code=1)

    message = response.json()
    print(message["status"])


@msa_experiment_app.command()
def features(
    msa_exp_id: Annotated[str, msa_experiment_ID_argument],
    output: Annotated[
        Optional[Path],
        typer.Option(
            help="Local path to download the zip to. Default to '<MSA_EXP_ID>_features.zip'."
        ),
    ] = None,
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Automatically unzip the file after its download.")
    ] = False,
):
    """Get an experiment features."""
    if output is None:
        output = Path(f"{msa_exp_id}_features.zip")

    _download_file_from_signed_url(
        msa_exp_id=msa_exp_id,
        endpoint="getZippedMSAExperimentFeatures",
        output=output,
        force=force,
        unzip=unzip,
    )


@msa_experiment_app.command()
def logs(
    msa_exp_id: Annotated[str, msa_experiment_ID_argument],
    output: Annotated[
        Optional[Path],
        typer.Option(
            help="Local path to download the logs to. Default to '<MSA_EXP_ID>_logs.txt'."
        ),
    ] = None,
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
):
    """Get an experiment logs."""
    if output is None:
        output = Path(f"{msa_exp_id}_logs.txt")

    _download_file_from_signed_url(
        msa_exp_id=msa_exp_id,
        endpoint="getExperimentLogs",
        output=output,
        force=force,
    )


@msa_experiment_app.command()
def list(
    limit: Annotated[
        int,
        typer.Option(
            help=("Max number of experiment to display in the terminal."),
        ),
    ] = 100,
    output: Annotated[
        Optional[Path],
        typer.Option(
            "--output",
            "-o",
            help=(
                "Path to the file where the job metadata returned by the server are written."
            ),
        ),
    ] = None,
):  # pylint:disable=redefined-builtin
    """Get all your done and pending experiment ids.

    The IDs are provided in the order of submission, starting with the most recent."""
    # NOTE: this function shadows the builtin `list` at module level, so no
    # code in this module may call the builtin `list(...)` at runtime.
    response = requests.get(
        API_URL + "getDoneAndPendingMSAExperiments",
        headers=get_auth_headers(),
        timeout=REQUEST_TIMEOUT,
    )
    if not response.ok:
        print(f"An error occurred : {response.text}")
        raise typer.Exit(code=1)

    # Presumably a mapping {status: [experiment_id, ...]} — inferred from the
    # iteration below; confirm against the API contract.
    response_json = response.json()

    if output:
        with open(output, "w", encoding="utf-8") as f:
            json.dump(response_json, f, indent=4)
        print(f"Done and pending MSA experiments list written to [bold]{output}[/bold]")

    table = Table(title="Done and pending MSA experiments")
    table.add_column("MSA Experiment ID", justify="right", style="cyan", no_wrap=True)
    table.add_column("Status", style="magenta")

    # Flatten the per-status groups into (id, status) rows and cap the table at
    # `limit` rows; previously whole groups were added before the check, so
    # the table could far exceed the limit.
    total_exp_nb = sum(len(exp_list) for exp_list in response_json.values())
    rows = (
        (exp_id, exp_status)
        for exp_status, exp_list in response_json.items()
        for exp_id in exp_list
    )
    for exp_id, exp_status in itertools.islice(rows, max(limit, 0)):
        table.add_row(exp_id, exp_status)

    if limit < total_exp_nb:
        print(
            f"The table below is truncated to the last [bold]{limit}[/bold] submitted MSA experiments. Increase '--limit' to see more."
        )
        if not output:
            print("Use '--output' to get the full list in file format.")
        else:
            print(f"See the full list in file format at [bold]{output}[/bold]")

    console = Console()
    console.print(table)