Spaces:
Sleeping
Sleeping
File size: 829 Bytes
ab72d17 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
from dataclasses import dataclass, asdict
from data_adv_dif import IntervalSplit
from config_adv_dif import Config
def save_model(path, model, tau_interval_split, alpha_interval_split, config):
    """Write model weights and run metadata to a single checkpoint file.

    Args:
        path: Destination handed directly to ``torch.save``.
        model: Object whose ``state_dict()`` is stored under
            ``'model_state_dict'``.
        tau_interval_split: Dataclass instance, flattened with
            ``dataclasses.asdict`` and stored under ``'tau_interval_split'``.
        alpha_interval_split: Dataclass instance, flattened with
            ``dataclasses.asdict`` and stored under ``'alpha_interval_split'``.
        config: Dataclass instance, flattened with ``dataclasses.asdict``
            and stored under ``'config'``.

    Raises:
        TypeError: from ``dataclasses.asdict`` if any of the three metadata
            arguments is not a dataclass instance.

    Note: tau is passed before alpha here, whereas ``load_model`` returns
    alpha before tau.
    """
    # Keyword order mirrors the checkpoint's key order and the order in
    # which each entry is computed.
    payload = dict(
        model_state_dict=model.state_dict(),
        alpha_interval_split=asdict(alpha_interval_split),
        tau_interval_split=asdict(tau_interval_split),
        config=asdict(config),
    )
    torch.save(payload, path)
def load_model(path, model, map_location="cpu"):
    """Restore a checkpoint produced by ``save_model``.

    The model's weights are overwritten in place with the saved
    ``state_dict``. The interval splits and config are rebuilt from their
    stored dict form.

    Args:
        path: Checkpoint location passed to ``torch.load``.
        model: Module with the same parameter layout as the saved one. It
            is updated in place via ``load_state_dict`` and also returned.
        map_location: Forwarded to ``torch.load``. The default, ``"cpu"``,
            lets a checkpoint saved from GPU tensors be read on a CPU-only
            machine. ``load_state_dict`` then copies the values into the
            model's existing parameters. Pass ``None`` to restore tensors
            to the device they were saved from (torch's default behavior).

    Returns:
        Tuple ``(model, alpha_interval_split, tau_interval_split, config)``.
        Note that alpha comes before tau, the reverse of the argument order
        of ``save_model``.
    """
    # NOTE(review): recent torch versions default torch.load to
    # weights_only=True. That only works if the saved split/config dicts
    # hold plain Python primitives -- confirm the field types of
    # IntervalSplit and Config.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    # NOTE(review): dataclasses.asdict recurses into nested dataclasses, so
    # any dataclass-typed field comes back here as a plain dict and is not
    # re-wrapped. Verify that IntervalSplit/Config have only flat fields.
    alpha_interval_split = IntervalSplit(**checkpoint['alpha_interval_split'])
    tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
    config = Config(**checkpoint['config'])
    return model, alpha_interval_split, tau_interval_split, config
|