Spaces:
Running
Running
load to cpu
Browse files
src/rstor/analyzis/interactive/model_selection.py
CHANGED
@@ -26,7 +26,7 @@ def model_selector(models_dict: dict, global_params={}, model_name="vanilla"):
|
|
26 |
def get_model_from_exp(exp: int, model_storage: Path = MODELS_PATH, device=DEVICE) -> Tuple[torch.nn.Module, dict]:
|
27 |
config = get_experiment_config(exp)
|
28 |
model, _, _ = get_training_content(config, training_mode=False)
|
29 |
-
model_path = torch.load(model_storage/f"{exp:04d}"/"best_model.pt")
|
30 |
assert model_path is not None, f"Model {exp} not found"
|
31 |
model.load_state_dict(model_path)
|
32 |
model = model.to(device)
|
|
|
26 |
def get_model_from_exp(exp: int, model_storage: Path = MODELS_PATH, device=DEVICE) -> Tuple[torch.nn.Module, dict]:
|
27 |
config = get_experiment_config(exp)
|
28 |
model, _, _ = get_training_content(config, training_mode=False)
|
29 |
+
model_path = torch.load(model_storage/f"{exp:04d}"/"best_model.pt", map_location=device)
|
30 |
assert model_path is not None, f"Model {exp} not found"
|
31 |
model.load_state_dict(model_path)
|
32 |
model = model.to(device)
|