"""Shared utils for the predict command."""
import json
from datetime import datetime
from pathlib import Path
import typer
from folding_studio_data_models import BatchPublication, MessageStatus, Publication
from rich.markdown import Markdown
from folding_studio.console import console
from folding_studio.utils.data_model import BatchInputFile, SimpleInputFile
def validate_source_path(path: Path) -> Path:
"""Validate the prediction source path.
Args:
path (Path): Source path.
Raises:
typer.BadParameter: If the source is an empty directory.
typer.BadParameter: If the source is a directory containing unsupported files.
typer.BadParameter: If the source is an unsupported file.
Returns:
Path: The source.
"""
supported_simple_prediction = tuple(item.value for item in SimpleInputFile)
supported_batch_prediction = tuple(item.value for item in BatchInputFile)
if path.is_dir():
if not any(path.iterdir()):
raise typer.BadParameter(f"The source directory `{path}` is empty.")
for file in path.iterdir():
if file.is_file():
if file.suffix not in supported_simple_prediction:
raise typer.BadParameter(
f"The source directory '{path}' contains unsupported files. "
f"Only {supported_simple_prediction} files are supported."
)
elif path.suffix not in supported_simple_prediction + supported_batch_prediction:
raise typer.BadParameter(
f"The source file '{path}' is not supported. "
f"Only {supported_simple_prediction + supported_batch_prediction} files are supported."
)
return path
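
# Illustrative usage sketch (comment only, not part of the module). validate_source_path
# has the shape of a typer parameter callback, so it could be wired to a command argument
# like this; the command and argument names below are hypothetical:
#
#     app = typer.Typer()
#
#     @app.command()
#     def predict(
#         source: Path = typer.Argument(..., callback=validate_source_path),
#     ):
#         ...
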
def validate_model_subset(model_subset: list[int]) -> list[int]:
"""Validate the model_subset argument.
Args:
model_subset (list[int]): List of model subset requested.
Raises:
typer.BadParameter: If more than 5 model ids are specified.
typer.BadParameter: If model ids not between 1 and 5 (included).
Returns:
list[int]: List of model subset requested.
"""
if len(model_subset) == 0:
return model_subset
elif len(model_subset) > 5:
raise typer.BadParameter(
f"--model_subset accept 5 model ids at most but `{len(model_subset)}` were specified."
)
elif min(model_subset) < 1 or max(model_subset) > 5:
raise typer.BadParameter(
"Model subset id out of supported range. --model_subset accepts ids between 1 and 5 (included)."
)
return model_subset
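
# Illustrative behaviour (comment only, not part of the module):
#
#     validate_model_subset([])                  -> []
#     validate_model_subset([1, 3, 5])           -> [1, 3, 5]
#     validate_model_subset([0, 6])              -> raises typer.BadParameter (ids outside 1-5)
#     validate_model_subset([1, 1, 2, 3, 4, 5])  -> raises typer.BadParameter (more than 5 ids)
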
def print_instructions_simple(response_json: dict, metadata_file: Path | None) -> None:
"""Print pretty instructions after successful call to predict endpoint.
Args:
response_json (dict): Server json response
metadata_file: (Path | None): File path where job submission metadata are written.
"""
pub = Publication.model_validate(response_json)
experiment_id = pub.message.experiment_id
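    # Three outcomes are handled below: results already available in the cache,
    # an experiment that is still running, and a freshly submitted experiment.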
if pub.status == MessageStatus.NOT_PUBLISHED_DONE:
console.print(
f"The results of your experiment {experiment_id} were found in the cache."
)
console.print("Use the following command to download the prediction results:")
md = f"""```shell
folding experiment results {experiment_id}
"""
console.print(Markdown(md))
elif pub.status == MessageStatus.NOT_PUBLISHED_PENDING:
console.print(
f"Your experiment [bold]{experiment_id}[/bold] is [bold green]still running.[/bold green]"
)
console.print(
"Use the following command to check on its status at a later time."
)
md = f"""```shell
folding experiment status {experiment_id}
"""
console.print(Markdown(md))
elif pub.status == MessageStatus.PUBLISHED:
console.print("[bold green]Experiment submitted successfully ![/bold green]")
console.print(f"The experiment id is [bold]{experiment_id}[/bold]")
if not metadata_file:
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
metadata_file = f"simple_prediction_{timestamp}.json"
with open(metadata_file, "w") as f:
json.dump(response_json, f, indent=4)
console.print(
f"Prediction job metadata written to [bold]{metadata_file}[/bold]"
)
console.print("You can query your experiment status with the command:")
md = f"""```shell
folding experiment status {experiment_id}
"""
console.print(Markdown(md))
else:
raise ValueError(f"Unknown publication status: {pub.status}")
def print_instructions_batch(response_json: dict, metadata_file: Path | None) -> None:
"""Print pretty instructions after successful call to batch predict endpoint.
Args:
response_json (dict): Server json response
metadata_file: (Path | None): File path where job submission metadata are written.
"""
pub = BatchPublication.model_validate(response_json)
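    # Split the batch response into newly submitted experiments, cached experiments,
    # and the subset of cached experiments whose results are already available.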
non_cached_exps = [
non_cached_pub.message.experiment_id for non_cached_pub in pub.publications
]
cached_exps = [
cached_pub.message.experiment_id for cached_pub in pub.cached_publications
]
done_exps = [
cached_pub.message.experiment_id
for cached_pub in pub.cached_publications
if cached_pub.status == MessageStatus.NOT_PUBLISHED_DONE
]
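    # Persist the raw server response so the user keeps a local record of every experiment id.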
if not metadata_file:
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
metadata_file = f"batch_prediction_{timestamp}.json"
with open(metadata_file, "w") as f:
json.dump(response_json, f, indent=4)
console.print(f"Batch prediction job metadata written to {metadata_file}")
console.print("This file contains your experiments ids.")
if pub.cached:
console.print(
"The results of your experiments were [bold]all found in the cache.[/bold]"
)
console.print("The experiment ids are:")
console.print(f"{cached_exps}")
console.print(
"Use the `folding experiment status id` command to check on their status. For example:"
)
md = f"""```shell
folding experiment status {cached_exps[0]}
"""
console.print(Markdown(md))
else:
console.print(
"[bold green]Batch prediction job submitted successfully ![/bold green]"
)
console.print(
f"The following experiments have been [bold]submitted[/bold] (see [bold]{metadata_file}[/bold] for the full list):"
)
console.print(non_cached_exps)
console.print(
"For example, you can query an experiment status with the command:"
)
md = f"""```shell
folding experiment status {non_cached_exps[0]}
"""
console.print(Markdown(md))
if done_exps:
console.print(
f"The results of the following experiments [bold]were found in the cache[/bold] (see [bold]{metadata_file}[/bold] for the full list):"
)
console.print(done_exps)
console.print(
"Use the `folding experiment results id` command to download the prediction results. For example:"
)
md = f"""```shell
folding experiment results {done_exps[0]}
"""
console.print(Markdown(md))
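
# Illustrative end-to-end flow (comment only, not part of the module). The submission call
# and the input file name are hypothetical; supported extensions are defined by
# SimpleInputFile and BatchInputFile.
#
#     source = validate_source_path(Path("inputs/my_protein.fasta"))
#     models = validate_model_subset([1, 2])
#     response_json = submit_prediction(source, model_subset=models)  # hypothetical API call
#     print_instructions_simple(response_json, metadata_file=None)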