Spaces:
Running
Running
File size: 2,358 Bytes
cec5823 0ed77e1 cec5823 bd92962 cec5823 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from pathlib import Path
from typing import List, Optional, Tuple

import torch
from interactive_pipe import KeyboardControl
from interactive_pipe import interactive
from tqdm import tqdm

from rstor.learning.experiments import get_training_content
from rstor.learning.experiments_definition import get_experiment_config
from rstor.properties import DEVICE, PRETTY_NAME
# Default root directory of trained experiment checkpoints (a relative path,
# so it is resolved against the current working directory).
MODELS_PATH = Path("scripts", "output")
def model_selector(models_dict: dict, global_params={}, model_name="vanilla"):
if isinstance(model_name, str):
current_model = models_dict[model_name]
elif isinstance(model_name, int):
model_names = [name for name in models_dict.keys()]
current_model = models_dict[model_names[model_name % len(model_names)]]
else:
raise ValueError(f"Model name {model_name} not understood")
global_params["model_config"] = current_model["config"]
return current_model["model"]
def get_model_from_exp(exp: int, model_storage: Path = MODELS_PATH, device=DEVICE) -> Tuple[torch.nn.Module, dict]:
    """Rebuild experiment ``exp``'s model and load its best checkpoint.

    Args:
        exp: Experiment id. The checkpoint is read from
            ``model_storage / f"{exp:04d}" / "best_model.pt"``.
        model_storage: Root directory of per-experiment checkpoint folders
            (a ``Path`` or a plain string).
        device: Passed as ``map_location`` to ``torch.load``; the model is
            then moved to this device.

    Returns:
        ``(model, config)``: the model with the checkpoint's state dict loaded,
        and the experiment config it was built from.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    config = get_experiment_config(exp)
    checkpoint_path = Path(model_storage) / f"{exp:04d}" / "best_model.pt"
    # torch.load raises on a missing file (it never returns None), so check up
    # front with a clear message and skip building the model for nothing.
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Model {exp} not found at {checkpoint_path}")
    model, _, _ = get_training_content(config, training_mode=False)
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True on torch versions that support it).
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    return model, config
def get_default_models(
    exp_list: Optional[List[int]] = None,
    model_storage: Path = MODELS_PATH,
    keyboard_control: bool = False,
    interactive_flag: bool = True
) -> dict:
    """Load several trained experiments and optionally attach a model selector.

    Args:
        exp_list: Experiment ids to load; ``None`` (the default) loads
            experiments 1000 and 1001.
        model_storage: Root directory of per-experiment checkpoint folders
            (a ``Path`` or a plain string).
        keyboard_control: With ``interactive_flag``, select the model through a
            ``KeyboardControl`` index (PageUp/PageDown, wrapping around) instead
            of a ``(default, choices)`` control over model names.
        interactive_flag: Attach a ``model_name`` control to ``model_selector``
            via ``interactive_pipe.interactive``.

    Returns:
        Dict mapping each model's display name (its ``PRETTY_NAME`` config
        entry, else the zero-padded experiment id) to
        ``{"model": ..., "config": ...}``.

    Raises:
        FileNotFoundError: If ``model_storage`` does not exist.
        ValueError: If ``interactive_flag`` is set but no model was loaded.
    """
    if exp_list is None:
        exp_list = [1000, 1001]
    model_storage = Path(model_storage)  # also accept plain string paths
    if not model_storage.exists():
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        raise FileNotFoundError(f"Model storage {model_storage} does not exist")
    model_dict = {}
    for exp in tqdm(exp_list, desc="Loading models"):
        model, config = get_model_from_exp(exp, model_storage=model_storage)
        name = config.get(PRETTY_NAME, f"{exp:04d}")
        # NOTE(review): experiments sharing the same PRETTY_NAME overwrite each
        # other here, keeping only the last one loaded.
        model_dict[name] = {
            "model": model,
            "config": config
        }
    if interactive_flag:
        exp_names = list(model_dict)
        if not exp_names:
            # exp_names[0] / the [0, -1] keyboard range would be invalid below.
            raise ValueError("Cannot create a model selector: no model was loaded")
        if keyboard_control:
            model_control = KeyboardControl(0, [0, len(exp_names) - 1], keydown="pagedown", keyup="pageup", modulo=True)
        else:
            model_control = (exp_names[0], exp_names)
        # Create the model dialog. The wrapper returned by interactive() is
        # discarded; this assumes interactive_pipe registers the control as a
        # side effect -- TODO confirm.
        interactive(model_name=model_control)(model_selector)
    return model_dict
|