"""Shared utils for the predict command.""" import json from datetime import datetime from pathlib import Path import typer from folding_studio_data_models import BatchPublication, MessageStatus, Publication from rich.markdown import Markdown from folding_studio.console import console from folding_studio.utils.data_model import BatchInputFile, SimpleInputFile def validate_source_path(path: Path) -> Path: """Validate the prediction source path. Args: path (Path): Source path. Raises: typer.BadParameter: If the source is an empty directory. typer.BadParameter: If the source is a directory containing unsupported files. typer.BadParameter: If the source is an unsupported file. Returns: Path: The source. """ supported_simple_prediction = tuple(item.value for item in SimpleInputFile) supported_batch_prediction = tuple(item.value for item in BatchInputFile) if path.is_dir(): if not any(path.iterdir()): raise typer.BadParameter(f"The source directory `{path}` is empty.") for file in path.iterdir(): if file.is_file(): if file.suffix not in supported_simple_prediction: raise typer.BadParameter( f"The source directory '{path}' contains unsupported files. " f"Only {supported_simple_prediction} files are supported." ) elif path.suffix not in supported_simple_prediction + supported_batch_prediction: raise typer.BadParameter( f"The source file '{path}' is not supported. " f"Only {supported_simple_prediction + supported_batch_prediction} files are supported." ) return path def validate_model_subset(model_subset: list[int]) -> list[int]: """Validate the model_subset argument. Args: model_subset (list[int]): List of model subset requested. Raises: typer.BadParameter: If more than 5 model ids are specified. typer.BadParameter: If model ids not between 1 and 5 (included). Returns: list[int]: List of model subset requested. """ if len(model_subset) == 0: return model_subset elif len(model_subset) > 5: raise typer.BadParameter( f"--model_subset accept 5 model ids at most but `{len(model_subset)}` were specified." ) elif min(model_subset) < 1 or max(model_subset) > 5: raise typer.BadParameter( "Model subset id out of supported range. --model_subset accepts ids between 1 and 5 (included)." ) return model_subset def print_instructions_simple(response_json: dict, metadata_file: Path | None) -> None: """Print pretty instructions after successful call to predict endpoint. Args: response_json (dict): Server json response metadata_file: (Path | None): File path where job submission metadata are written. """ pub = Publication.model_validate(response_json) experiment_id = pub.message.experiment_id if pub.status == MessageStatus.NOT_PUBLISHED_DONE: console.print( f"The results of your experiment {experiment_id} were found in the cache." ) console.print("Use the following command to download the prediction results:") md = f"""```shell folding experiment results {experiment_id} """ console.print(Markdown(md)) elif pub.status == MessageStatus.NOT_PUBLISHED_PENDING: console.print( f"Your experiment [bold]{experiment_id}[/bold] is [bold green]still running.[/bold green]" ) console.print( "Use the following command to check on its status at a later time." ) md = f"""```shell folding experiment status {experiment_id} """ console.print(Markdown(md)) elif pub.status == MessageStatus.PUBLISHED: console.print("[bold green]Experiment submitted successfully ![/bold green]") console.print(f"The experiment id is [bold]{experiment_id}[/bold]") if not metadata_file: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") metadata_file = f"simple_prediction_{timestamp}.json" with open(metadata_file, "w") as f: json.dump(response_json, f, indent=4) console.print( f"Prediction job metadata written to [bold]{metadata_file}[/bold]" ) console.print("You can query your experiment status with the command:") md = f"""```shell folding experiment status {experiment_id} """ console.print(Markdown(md)) else: raise ValueError(f"Unknown publication status: {pub.status}") def print_instructions_batch(response_json: dict, metadata_file: Path | None) -> None: """Print pretty instructions after successful call to batch predict endpoint. Args: response_json (dict): Server json response metadata_file: (Path | None): File path where job submission metadata are written. """ pub = BatchPublication.model_validate(response_json) non_cached_exps = [ non_cached_pub.message.experiment_id for non_cached_pub in pub.publications ] cached_exps = [ cached_pub.message.experiment_id for cached_pub in pub.cached_publications ] done_exps = [ cached_pub.message.experiment_id for cached_pub in pub.cached_publications if cached_pub.status == MessageStatus.NOT_PUBLISHED_DONE ] if not metadata_file: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") metadata_file = f"batch_prediction_{timestamp}.json" with open(metadata_file, "w") as f: json.dump(response_json, f, indent=4) console.print(f"Batch prediction job metadata written to {metadata_file}") console.print("This file contains your experiments ids.") if pub.cached: console.print( "The results of your experiments were [bold]all found in the cache.[/bold]" ) console.print("The experiment ids are:") console.print(f"{cached_exps}") console.print( "Use the `folding experiment status id` command to check on their status. For example:" ) md = f"""```shell folding experiment status {cached_exps[0]} """ console.print(Markdown(md)) else: console.print( "[bold green]Batch prediction job submitted successfully ![/bold green]" ) console.print( f"The following experiments have been [bold]submitted[/bold] (see [bold]{metadata_file}[/bold] for the full list):" ) console.print(non_cached_exps) console.print( "For example, you can query an experiment status with the command:" ) md = f"""```shell folding experiment status {non_cached_exps[0]} """ console.print(Markdown(md)) if done_exps: console.print( f"The results of the following experiments [bold]were found in the cache[/bold] (see [bold]{metadata_file}[/bold] for the full list):" ) console.print(done_exps) console.print( "Use the `folding experiment results id` command to download the prediction results. For example:" ) md = f"""```shell folding experiment results {cached_exps[0]} """ console.print(Markdown(md))