|
"""Shared utils for the predict command.""" |
|
|
|
import json |
|
from datetime import datetime |
|
from pathlib import Path |
|
|
|
import typer |
|
from folding_studio_data_models import BatchPublication, MessageStatus, Publication |
|
from rich.markdown import Markdown |
|
|
|
from folding_studio.console import console |
|
from folding_studio.utils.data_model import BatchInputFile, SimpleInputFile |
|
|
|
|
|
def validate_source_path(path: Path) -> Path:
    """Check that the prediction source is something the predict command can handle.

    A source may be either a single file whose suffix is a supported simple or
    batch input type, or a non-empty directory whose regular files all carry a
    supported simple input suffix (sub-directories are not inspected).

    Args:
        path (Path): Source path given on the command line.

    Raises:
        typer.BadParameter: If the source is an empty directory.
        typer.BadParameter: If the source is a directory containing unsupported files.
        typer.BadParameter: If the source is an unsupported file.

    Returns:
        Path: The unchanged source path.
    """
    simple_suffixes = tuple(member.value for member in SimpleInputFile)
    batch_suffixes = tuple(member.value for member in BatchInputFile)
    every_suffix = simple_suffixes + batch_suffixes

    # Single-file source: either kind of input file is acceptable.
    if not path.is_dir():
        if path.suffix not in every_suffix:
            raise typer.BadParameter(
                f"The source file '{path}' is not supported. "
                f"Only {every_suffix} files are supported."
            )
        return path

    # Directory source: must be non-empty, and only simple input files are allowed.
    entries = list(path.iterdir())
    if not entries:
        raise typer.BadParameter(f"The source directory `{path}` is empty.")

    found_unsupported = any(
        entry.is_file() and entry.suffix not in simple_suffixes for entry in entries
    )
    if found_unsupported:
        raise typer.BadParameter(
            f"The source directory '{path}' contains unsupported files. "
            f"Only {simple_suffixes} files are supported."
        )
    return path
|
|
|
|
|
def validate_model_subset(model_subset: list[int]) -> list[int]: |
|
"""Validate the model_subset argument. |
|
|
|
Args: |
|
model_subset (list[int]): List of model subset requested. |
|
|
|
Raises: |
|
typer.BadParameter: If more than 5 model ids are specified. |
|
typer.BadParameter: If model ids not between 1 and 5 (included). |
|
|
|
Returns: |
|
list[int]: List of model subset requested. |
|
""" |
|
if len(model_subset) == 0: |
|
return model_subset |
|
elif len(model_subset) > 5: |
|
raise typer.BadParameter( |
|
f"--model_subset accept 5 model ids at most but `{len(model_subset)}` were specified." |
|
) |
|
elif min(model_subset) < 1 or max(model_subset) > 5: |
|
raise typer.BadParameter( |
|
"Model subset id out of supported range. --model_subset accepts ids between 1 and 5 (included)." |
|
) |
|
return model_subset |
|
|
|
|
|
def print_instructions_simple(response_json: dict, metadata_file: Path | None) -> None:
    """Print pretty instructions after successful call to predict endpoint.

    The message depends on the publication status returned by the server:

    * ``NOT_PUBLISHED_DONE``: results were found in the cache; show how to
      download them.
    * ``NOT_PUBLISHED_PENDING``: the experiment is still running; show how to
      check its status later.
    * ``PUBLISHED``: the job was submitted; write the raw server response to
      ``metadata_file`` and show how to check its status.

    Args:
        response_json (dict): Server JSON response, validated as a ``Publication``.
        metadata_file (Path | None): File path where job submission metadata are
            written (only for ``PUBLISHED``). When falsy, a timestamped
            ``simple_prediction_<YYYYmmddHHMMSS>.json`` file relative to the
            current working directory is used.

    Raises:
        ValueError: If the publication status is not one of the handled values.
    """
    pub = Publication.model_validate(response_json)
    experiment_id = pub.message.experiment_id

    if pub.status == MessageStatus.NOT_PUBLISHED_DONE:
        console.print(
            f"The results of your experiment {experiment_id} were found in the cache."
        )
        console.print("Use the following command to download the prediction results:")
        # Closed fence so the snippet renders as a self-contained shell code block.
        console.print(
            Markdown(f"```shell\nfolding experiment results {experiment_id}\n```")
        )
    elif pub.status == MessageStatus.NOT_PUBLISHED_PENDING:
        console.print(
            f"Your experiment [bold]{experiment_id}[/bold] is [bold green]still running.[/bold green]"
        )
        console.print(
            "Use the following command to check on its status at a later time."
        )
        console.print(
            Markdown(f"```shell\nfolding experiment status {experiment_id}\n```")
        )
    elif pub.status == MessageStatus.PUBLISHED:
        console.print("[bold green]Experiment submitted successfully ![/bold green]")
        console.print(f"The experiment id is [bold]{experiment_id}[/bold]")

        if not metadata_file:
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            metadata_file = Path(f"simple_prediction_{timestamp}.json")
        # Explicit encoding so the JSON file does not depend on the platform locale.
        with open(metadata_file, "w", encoding="utf-8") as f:
            json.dump(response_json, f, indent=4)

        console.print(
            f"Prediction job metadata written to [bold]{metadata_file}[/bold]"
        )
        console.print("You can query your experiment status with the command:")
        console.print(
            Markdown(f"```shell\nfolding experiment status {experiment_id}\n```")
        )
    else:
        raise ValueError(f"Unknown publication status: {pub.status}")
|
|
|
|
|
def print_instructions_batch(response_json: dict, metadata_file: Path | None) -> None:
    """Print pretty instructions after successful call to batch predict endpoint.

    The raw server response is always written to ``metadata_file``. Then:

    * if the response reports the whole batch as cached, list the cached
      experiment ids and show how to check their status;
    * otherwise, list the submitted experiments with a status example and,
      if some cached experiments are already ``NOT_PUBLISHED_DONE``, list those
      with an example of downloading their results.

    Args:
        response_json (dict): Server JSON response, validated as a ``BatchPublication``.
        metadata_file (Path | None): File path where job submission metadata are
            written. When falsy, a timestamped
            ``batch_prediction_<YYYYmmddHHMMSS>.json`` file relative to the
            current working directory is used.
    """
    pub = BatchPublication.model_validate(response_json)
    non_cached_exps = [p.message.experiment_id for p in pub.publications]
    cached_exps = [p.message.experiment_id for p in pub.cached_publications]
    # Only cached experiments whose status is NOT_PUBLISHED_DONE have results to
    # download; other cached ones (e.g. pending) may not have results yet.
    done_exps = [
        p.message.experiment_id
        for p in pub.cached_publications
        if p.status == MessageStatus.NOT_PUBLISHED_DONE
    ]

    if not metadata_file:
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        metadata_file = Path(f"batch_prediction_{timestamp}.json")
    # Explicit encoding so the JSON file does not depend on the platform locale.
    with open(metadata_file, "w", encoding="utf-8") as f:
        json.dump(response_json, f, indent=4)
    console.print(f"Batch prediction job metadata written to {metadata_file}")
    console.print("This file contains your experiments ids.")

    if pub.cached:
        console.print(
            "The results of your experiments were [bold]all found in the cache.[/bold]"
        )
        console.print("The experiment ids are:")
        console.print(f"{cached_exps}")
        console.print(
            "Use the `folding experiment status id` command to check on their status. For example:"
        )
        # Guard the example so an empty list cannot raise IndexError.
        if cached_exps:
            console.print(
                Markdown(f"```shell\nfolding experiment status {cached_exps[0]}\n```")
            )
    else:
        console.print(
            "[bold green]Batch prediction job submitted successfully ![/bold green]"
        )

        console.print(
            f"The following experiments have been [bold]submitted[/bold] (see [bold]{metadata_file}[/bold] for the full list):"
        )
        console.print(non_cached_exps)
        if non_cached_exps:
            console.print(
                "For example, you can query an experiment status with the command:"
            )
            console.print(
                Markdown(
                    f"```shell\nfolding experiment status {non_cached_exps[0]}\n```"
                )
            )

        if done_exps:
            console.print(
                f"The results of the following experiments [bold]were found in the cache[/bold] (see [bold]{metadata_file}[/bold] for the full list):"
            )
            console.print(done_exps)
            console.print(
                "Use the `folding experiment results id` command to download the prediction results. For example:"
            )
            # Use a finished experiment: cached_exps[0] may not have results yet.
            console.print(
                Markdown(f"```shell\nfolding experiment results {done_exps[0]}\n```")
            )
|
|