"""Shared utils for the predict command."""
import json
from datetime import datetime
from pathlib import Path
import typer
from folding_studio_data_models import BatchPublication, MessageStatus, Publication
from rich.markdown import Markdown
from folding_studio.console import console
from folding_studio.utils.data_model import BatchInputFile, SimpleInputFile


def validate_source_path(path: Path) -> Path:
    """Validate the prediction source path.

    Args:
        path (Path): Source path.

    Raises:
        typer.BadParameter: If the source is an empty directory.
        typer.BadParameter: If the source is a directory containing unsupported files.
        typer.BadParameter: If the source is an unsupported file.

    Returns:
        Path: The source.
    """
    supported_simple_prediction = tuple(item.value for item in SimpleInputFile)
    supported_batch_prediction = tuple(item.value for item in BatchInputFile)

    if path.is_dir():
        if not any(path.iterdir()):
            raise typer.BadParameter(f"The source directory `{path}` is empty.")
        for file in path.iterdir():
            if file.is_file() and file.suffix not in supported_simple_prediction:
                raise typer.BadParameter(
                    f"The source directory '{path}' contains unsupported files. "
                    f"Only {supported_simple_prediction} files are supported."
                )
    elif path.suffix not in supported_simple_prediction + supported_batch_prediction:
        raise typer.BadParameter(
            f"The source file '{path}' is not supported. "
            f"Only {supported_simple_prediction + supported_batch_prediction} files are supported."
        )
    return path
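
# A minimal usage sketch (illustrative only, not part of this module): wiring
# `validate_source_path` as a Typer callback so the check runs during argument
# parsing. The command name `predict` and the option layout are assumptions.
#
#     app = typer.Typer()
#
#     @app.command()
#     def predict(
#         source: Path = typer.Argument(..., callback=validate_source_path),
#     ) -> None:
#         ...  # hand the validated path off to the prediction client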


def validate_model_subset(model_subset: list[int]) -> list[int]:
    """Validate the model_subset argument.

    Args:
        model_subset (list[int]): List of requested model ids.

    Raises:
        typer.BadParameter: If more than 5 model ids are specified.
        typer.BadParameter: If a model id is not between 1 and 5 (inclusive).

    Returns:
        list[int]: List of requested model ids.
    """
    if len(model_subset) == 0:
        return model_subset
    if len(model_subset) > 5:
        raise typer.BadParameter(
            f"--model_subset accepts 5 model ids at most but `{len(model_subset)}` were specified."
        )
    if min(model_subset) < 1 or max(model_subset) > 5:
        raise typer.BadParameter(
            "Model subset id out of supported range. --model_subset accepts ids between 1 and 5 (inclusive)."
        )
    return model_subset
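
# A minimal usage sketch (illustrative only): `--model_subset` as a repeatable
# Typer option validated by `validate_model_subset`. The option name mirrors
# the error messages above; the enclosing command is an assumption.
#
#     @app.command()
#     def predict(
#         model_subset: list[int] = typer.Option(
#             [], "--model_subset", callback=validate_model_subset
#         ),
#     ) -> None:
#         ...  # e.g. `folding predict --model_subset 1 --model_subset 3`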


def print_instructions_simple(response_json: dict, metadata_file: Path | None) -> None:
    """Print pretty instructions after a successful call to the predict endpoint.

    Args:
        response_json (dict): Server json response.
        metadata_file (Path | None): File path where job submission metadata is written.
    """
    pub = Publication.model_validate(response_json)
    experiment_id = pub.message.experiment_id

    if pub.status == MessageStatus.NOT_PUBLISHED_DONE:
        console.print(
            f"The results of your experiment {experiment_id} were found in the cache."
        )
        console.print("Use the following command to download the prediction results:")
        md = f"""```shell
folding experiment results {experiment_id}
```"""
        console.print(Markdown(md))
    elif pub.status == MessageStatus.NOT_PUBLISHED_PENDING:
        console.print(
            f"Your experiment [bold]{experiment_id}[/bold] is [bold green]still running.[/bold green]"
        )
        console.print(
            "Use the following command to check on its status at a later time."
        )
        md = f"""```shell
folding experiment status {experiment_id}
```"""
        console.print(Markdown(md))
    elif pub.status == MessageStatus.PUBLISHED:
        console.print("[bold green]Experiment submitted successfully![/bold green]")
        console.print(f"The experiment id is [bold]{experiment_id}[/bold]")
        if not metadata_file:
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            metadata_file = f"simple_prediction_{timestamp}.json"
        with open(metadata_file, "w") as f:
            json.dump(response_json, f, indent=4)
        console.print(
            f"Prediction job metadata written to [bold]{metadata_file}[/bold]"
        )
        console.print("You can query your experiment status with the command:")
        md = f"""```shell
folding experiment status {experiment_id}
```"""
        console.print(Markdown(md))
    else:
        raise ValueError(f"Unknown publication status: {pub.status}")
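
# Hedged example of calling `print_instructions_simple`. Only `status` and
# `message.experiment_id` are read above; the exact field names, status value,
# and experiment id shown here are assumptions, not taken from the API schema.
#
#     response_json = {
#         "status": "PUBLISHED",
#         "message": {"experiment_id": "exp-1234"},
#     }
#     print_instructions_simple(response_json, metadata_file=None)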


def print_instructions_batch(response_json: dict, metadata_file: Path | None) -> None:
    """Print pretty instructions after a successful call to the batch predict endpoint.

    Args:
        response_json (dict): Server json response.
        metadata_file (Path | None): File path where job submission metadata is written.
    """
    pub = BatchPublication.model_validate(response_json)

    non_cached_exps = [
        non_cached_pub.message.experiment_id for non_cached_pub in pub.publications
    ]
    cached_exps = [
        cached_pub.message.experiment_id for cached_pub in pub.cached_publications
    ]
    done_exps = [
        cached_pub.message.experiment_id
        for cached_pub in pub.cached_publications
        if cached_pub.status == MessageStatus.NOT_PUBLISHED_DONE
    ]

    if not metadata_file:
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        metadata_file = f"batch_prediction_{timestamp}.json"
    with open(metadata_file, "w") as f:
        json.dump(response_json, f, indent=4)
    console.print(f"Batch prediction job metadata written to {metadata_file}")
    console.print("This file contains your experiment ids.")

    if pub.cached:
        console.print(
            "The results of your experiments were [bold]all found in the cache.[/bold]"
        )
        console.print("The experiment ids are:")
        console.print(f"{cached_exps}")
        console.print(
            "Use the `folding experiment status id` command to check on their status. For example:"
        )
        md = f"""```shell
folding experiment status {cached_exps[0]}
```"""
        console.print(Markdown(md))
    else:
        console.print(
            "[bold green]Batch prediction job submitted successfully![/bold green]"
        )
        console.print(
            f"The following experiments have been [bold]submitted[/bold] (see [bold]{metadata_file}[/bold] for the full list):"
        )
        console.print(non_cached_exps)
        console.print(
            "For example, you can query an experiment status with the command:"
        )
        md = f"""```shell
folding experiment status {non_cached_exps[0]}
```"""
        console.print(Markdown(md))

    if done_exps:
        console.print(
            f"The results of the following experiments [bold]were found in the cache[/bold] (see [bold]{metadata_file}[/bold] for the full list):"
        )
        console.print(done_exps)
        console.print(
            "Use the `folding experiment results id` command to download the prediction results. For example:"
        )
        md = f"""```shell
folding experiment results {done_exps[0]}
```"""
        console.print(Markdown(md))
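
# Hedged example of calling `print_instructions_batch` with an explicit
# metadata path. Only `publications`, `cached_publications`, `cached`, `status`
# and `message.experiment_id` are read above; the response object and file
# name below are assumptions for illustration.
#
#     print_instructions_batch(
#         response_json=batch_response.json(),  # hypothetical HTTP response
#         metadata_file=Path("batch_prediction_metadata.json"),
#     )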