File size: 829 Bytes
ab72d17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from dataclasses import dataclass, asdict
from data_adv_dif import IntervalSplit
from config_adv_dif import Config


def save_model(path, model, tau_interval_split, alpha_interval_split, config):
    """Write a checkpoint for *model* and its accompanying settings to *path*.

    The checkpoint is a single dict serialized with ``torch.save`` holding:

    * ``'model_state_dict'`` -- ``model.state_dict()``
    * ``'alpha_interval_split'`` -- ``asdict(alpha_interval_split)``
    * ``'tau_interval_split'`` -- ``asdict(tau_interval_split)``
    * ``'config'`` -- ``asdict(config)``

    Args:
        path: Destination handed straight to ``torch.save``.
        model: Object exposing ``state_dict()``.
        tau_interval_split: Dataclass instance stored under
            ``'tau_interval_split'``.
        alpha_interval_split: Dataclass instance stored under
            ``'alpha_interval_split'``.
        config: Dataclass instance stored under ``'config'``.

    Raises:
        TypeError: From ``dataclasses.asdict`` when one of the last three
            arguments is not a dataclass instance.
    """
    # Entries are listed in the order they end up in the saved dict.
    checkpoint = dict(
        model_state_dict=model.state_dict(),
        alpha_interval_split=asdict(alpha_interval_split),
        tau_interval_split=asdict(tau_interval_split),
        config=asdict(config),
    )
    torch.save(checkpoint, path)


def load_model(path, model, map_location=None):
    """Restore a checkpoint written by ``save_model``.

    The weights stored under ``'model_state_dict'`` are loaded into *model*
    in place, and the saved dictionaries are turned back into
    ``IntervalSplit`` / ``Config`` objects by keyword expansion.

    Args:
        path: Checkpoint location handed to ``torch.load``.
        model: Object exposing ``load_state_dict``; it is mutated and the
            same object is returned.
        map_location: Forwarded to ``torch.load`` to remap tensor storages
            (for example ``'cpu'`` to open a checkpoint saved from CUDA
            tensors on a host without a GPU). ``None`` keeps
            ``torch.load``'s default behaviour.

    Returns:
        Tuple ``(model, alpha_interval_split, tau_interval_split, config)``.
        Note that alpha comes before tau here, whereas ``save_model`` takes
        tau before alpha.

    Raises:
        KeyError: If the checkpoint lacks one of the expected entries.
    """
    # NOTE: torch.load deserializes with pickle; only open checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    alpha_interval_split = IntervalSplit(**checkpoint['alpha_interval_split'])
    tau_interval_split = IntervalSplit(**checkpoint['tau_interval_split'])
    config = Config(**checkpoint['config'])
    return model, alpha_interval_split, tau_interval_split, config