Flexi-Propagator / model_io_adv_dif.py
Khalid Rafiq
Add all required modules and requirements.txt
ab72d17
raw
history blame
829 Bytes
import torch
from dataclasses import dataclass, asdict
from data_adv_dif import IntervalSplit
from config_adv_dif import Config
def save_model(path, model, tau_interval_split, alpha_interval_split, config):
    """Write ``model``'s weights plus its interval splits and config to ``path``.

    The three dataclass arguments are converted with ``dataclasses.asdict`` so
    they are stored as plain dicts next to the state dict.

    Args:
        path: destination handed to ``torch.save``.
        model: object exposing ``state_dict()``; stored under
            ``'model_state_dict'``.
        tau_interval_split: dataclass instance stored under
            ``'tau_interval_split'``.
        alpha_interval_split: dataclass instance stored under
            ``'alpha_interval_split'``.
        config: dataclass instance stored under ``'config'``.
    """
    payload = {'model_state_dict': model.state_dict()}
    # Keys are inserted in the same order as before: alpha, tau, then config.
    payload['alpha_interval_split'] = asdict(alpha_interval_split)
    payload['tau_interval_split'] = asdict(tau_interval_split)
    payload['config'] = asdict(config)
    torch.save(payload, path)
def load_model(path, model, *, map_location=None, strict=True):
    """Load a checkpoint written by ``save_model`` and restore it into ``model``.

    Args:
        path: checkpoint location handed to ``torch.load``.
        model: object whose ``load_state_dict`` receives the saved weights
            (assumed to match the saved architecture -- verify with caller).
        map_location: forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to load
            a checkpoint that was saved from a GPU model on a machine without
            CUDA; ``None`` (default) keeps the previous behavior of restoring
            tensors to the device they were saved from.
        strict: forwarded to ``model.load_state_dict``; ``True`` (default)
            matches the previous behavior.

    Returns:
        ``(model, alpha_interval_split, tau_interval_split, config)``. Note the
        alpha/tau order here is the reverse of ``save_model``'s parameter order.

    Note:
        ``torch.load`` unpickles the file, so only load trusted checkpoints.
        On torch versions where ``weights_only=True`` is the default, a config
        holding non-builtin types may fail to load -- TODO confirm for Config.
    """
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
    # save_model stored these via dataclasses.asdict. Only top-level fields
    # are rebuilt here; any nested dataclass fields come back as plain dicts.
    # NOTE(review): confirm IntervalSplit/Config have no nested dataclasses.
    alpha_interval_split = IntervalSplit(**checkpoint['alpha_interval_split'])
    tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
    config = Config(**checkpoint['config'])
    return model, alpha_interval_split, tau_interval_split, config