# Commit metadata captured with this file (not Python code):
# author: balthou -- "add large model" (revision 0ed77e1)
import torch
from interactive_pipe import KeyboardControl
from rstor.learning.experiments import get_training_content
from rstor.learning.experiments_definition import get_experiment_config
from rstor.properties import DEVICE, PRETTY_NAME
from tqdm import tqdm
from pathlib import Path
from typing import List, Tuple
from interactive_pipe import interactive
# Default root of the trained-experiment outputs (relative path, so it is
# resolved against the current working directory when used).
MODELS_PATH = Path("scripts", "output")
def model_selector(models_dict: dict, global_params={}, model_name="vanilla"):
    """Return the model chosen by ``model_name`` and publish its config.

    ``models_dict`` maps a display name to ``{"model": ..., "config": ...}``.
    ``model_name`` is either one of those names, or an integer that indexes
    the names in insertion order and wraps around (negative values included).

    The chosen entry's ``"config"`` is written to
    ``global_params["model_config"]``. The ``global_params={}`` default is
    kept deliberately: it looks like the interactive_pipe convention for the
    shared context dict (TODO confirm). Note that, as a Python default, the
    dict is shared between calls that do not pass one.

    Raises:
        KeyError: unknown string name.
        ZeroDivisionError: integer selector on an empty ``models_dict``.
        ValueError: selector that is neither ``str`` nor ``int``.
    """
    if isinstance(model_name, int):
        # bool is an int subclass, so True/False also land here (index 1/0).
        names = list(models_dict)
        selected = models_dict[names[model_name % len(names)]]
    elif isinstance(model_name, str):
        selected = models_dict[model_name]
    else:
        raise ValueError(f"Model name {model_name} not understood")
    global_params["model_config"] = selected["config"]
    return selected["model"]
def get_model_from_exp(exp: int, model_storage: Path = MODELS_PATH, device=DEVICE) -> Tuple[torch.nn.Module, dict]:
    """Rebuild the network of experiment ``exp`` and load its best weights.

    Args:
        exp: Experiment id, used both to fetch the experiment config and to
            locate ``<model_storage>/<exp:04d>/best_model.pt``.
        model_storage: Root folder of the experiment outputs (``str`` paths
            are accepted too).
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        ``(model, config)``: the model (built with ``training_mode=False``),
        with the checkpoint's state dict loaded and moved to ``device``, and
        the experiment config.

    Raises:
        FileNotFoundError: if ``best_model.pt`` does not exist for ``exp``.
    """
    config = get_experiment_config(exp)
    model, _, _ = get_training_content(config, training_mode=False)
    checkpoint_path = Path(model_storage) / f"{exp:04d}" / "best_model.pt"
    # torch.load raises on a missing file instead of returning None, so the
    # former `assert ... is not None` could never report a missing model.
    # Check explicitly; an explicit raise also survives `python -O`.
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Model {exp} not found at {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    return model, config
def get_default_models(
    exp_list: List[int] = [1000, 1001],
    model_storage: Path = MODELS_PATH,
    keyboard_control: bool = False,
    interactive_flag: bool = True
) -> dict:
    """Load several trained experiments and optionally create a model picker.

    Args:
        exp_list: Experiment ids to load, in loading order. The default list
            is only iterated, never mutated, so sharing it across calls is safe.
        model_storage: Root folder of the experiment outputs (``str`` paths
            are accepted too).
        keyboard_control: If a picker is created, drive it with
            PageDown/PageUp (``modulo=True``) instead of a named choice list.
        interactive_flag: Whether to wrap ``model_selector`` with an
            interactive_pipe control at all.

    Returns:
        ``{name: {"model": model, "config": config}}`` where ``name`` is the
        config's ``PRETTY_NAME`` entry, or the zero-padded experiment id when
        absent. A later experiment with the same name replaces the earlier entry.

    Raises:
        AssertionError: if ``model_storage`` does not exist.
        ValueError: if ``interactive_flag`` is set but ``exp_list`` is empty
            (there would be nothing to select).
    """
    # Normalize once so get_model_from_exp also receives a Path.
    model_storage = Path(model_storage)
    assert model_storage.exists(), f"Model storage {model_storage} does not exist"
    model_dict = {}
    for exp in tqdm(exp_list, desc="Loading models"):
        model, config = get_model_from_exp(exp, model_storage=model_storage)
        name = config.get(PRETTY_NAME, f"{exp:04d}")
        model_dict[name] = {
            "model": model,
            "config": config
        }
    if interactive_flag:
        exp_names = list(model_dict)
        if not exp_names:
            # Previously crashed with a bare IndexError on exp_names[0]
            # (or built a KeyboardControl with range [0, -1]).
            raise ValueError("Cannot create the model selector: exp_list is empty, no model was loaded")
        if keyboard_control:
            model_control = KeyboardControl(0, [0, len(exp_names)-1], keydown="pagedown", keyup="pageup", modulo=True)
        else:
            # (default value, choices) tuple for a named-choice control.
            model_control = (exp_names[0], exp_names)
        # The decorated function is not kept: the decorator call itself is
        # relied upon to create the model dialog.
        interactive(model_name=model_control)(model_selector)
    return model_dict