Flexi-Propagator / .ipynb_checkpoints /model_io_adv_dif-checkpoint.py
Khalid Rafiq
Fix: Map Advection-Diffusion model checkpoint to CPU
1c9f376
raw
history blame contribute delete
849 Bytes
import torch
from dataclasses import dataclass, asdict
from data_adv_dif import IntervalSplit
from config_adv_dif import Config
def save_model(path, model, tau_interval_split, alpha_interval_split, config):
    """Write ``model``'s weights plus run metadata to a single checkpoint file.

    The checkpoint is a dict with four entries, inserted in this order:
    ``'model_state_dict'`` (``model.state_dict()``), then plain-dict copies
    (via ``dataclasses.asdict``) of the alpha split, the tau split and the
    config. The metadata is therefore stored as ordinary dicts rather than
    as instances of the project's classes.

    Args:
        path: Destination handed to ``torch.save``.
        model: Module whose ``state_dict()`` is stored.
        tau_interval_split: Dataclass instance stored under
            ``'tau_interval_split'``. Note that tau precedes alpha in this
            signature.
        alpha_interval_split: Dataclass instance stored under
            ``'alpha_interval_split'``.
        config: Dataclass instance stored under ``'config'``.
    """
    checkpoint = {}
    checkpoint['model_state_dict'] = model.state_dict()
    checkpoint['alpha_interval_split'] = asdict(alpha_interval_split)
    checkpoint['tau_interval_split'] = asdict(tau_interval_split)
    checkpoint['config'] = asdict(config)
    torch.save(checkpoint, path)
def load_model(path, model, *, map_location='cpu'):
    """Restore model weights and run metadata from a ``save_model`` checkpoint.

    Args:
        path: Checkpoint location, passed straight to ``torch.load``.
        model: An already-constructed module with the same architecture as
            the saved one. Its weights are overwritten in place via
            ``load_state_dict``, which uses PyTorch's default strict key
            matching.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            a checkpoint written on a GPU machine can be opened on a
            CPU-only host. Pass e.g. ``'cuda'`` or a ``torch.device`` to
            load onto another device instead.

    Returns:
        ``(model, alpha_interval_split, tau_interval_split, config)``: the
        same ``model`` object that was passed in, plus ``IntervalSplit`` /
        ``Config`` instances rebuilt from their stored field dicts. The
        alpha split comes before the tau split here, which is the opposite
        of ``save_model``'s parameter order.

    Raises:
        KeyError: If one of the expected checkpoint entries is missing.
        RuntimeError: Raised by ``load_state_dict`` when the stored weights
            do not match ``model``'s parameters.
        TypeError: If a stored dict's keys do not match the fields accepted
            by the current ``IntervalSplit`` / ``Config`` definitions.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])

    # The metadata was stored as plain dicts (dataclasses.asdict in
    # save_model); rebuild the objects from those field dicts.
    alpha_interval_split = IntervalSplit(**checkpoint['alpha_interval_split'])
    tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
    config = Config(**checkpoint['config'])
    return model, alpha_interval_split, tau_interval_split, config