balthou commited on
Commit
bd92962
·
1 Parent(s): f51f8b4

load to cpu

Browse files
src/rstor/analyzis/interactive/model_selection.py CHANGED
@@ -26,7 +26,7 @@ def model_selector(models_dict: dict, global_params={}, model_name="vanilla"):
26
  def get_model_from_exp(exp: int, model_storage: Path = MODELS_PATH, device=DEVICE) -> Tuple[torch.nn.Module, dict]:
27
  config = get_experiment_config(exp)
28
  model, _, _ = get_training_content(config, training_mode=False)
29
- model_path = torch.load(model_storage/f"{exp:04d}"/"best_model.pt")
30
  assert model_path is not None, f"Model {exp} not found"
31
  model.load_state_dict(model_path)
32
  model = model.to(device)
 
26
  def get_model_from_exp(exp: int, model_storage: Path = MODELS_PATH, device=DEVICE) -> Tuple[torch.nn.Module, dict]:
27
  config = get_experiment_config(exp)
28
  model, _, _ = get_training_content(config, training_mode=False)
29
+ model_path = torch.load(model_storage/f"{exp:04d}"/"best_model.pt", map_location=device)
30
  assert model_path is not None, f"Model {exp} not found"
31
  model.load_state_dict(model_path)
32
  model = model.to(device)