from ..custom_types import *
from .. import constants
from tqdm import tqdm
from . import files_utils
import os
from .. import options
from ..models import models_utils, occ_gmm


LI = Union[T, float, int]  # scalar-like values the Logger can accumulate
Models = {'spaghetti': occ_gmm.Spaghetti}  # registry of constructible model classes


def is_model_clean(model: nn.Module) -> bool:
    # A model is "clean" if none of its parameters contain NaNs.
    for wh in model.parameters():
        if torch.isnan(wh).any():
            return False
    return True


def model_factory(opt: options.Options, override_model: Optional[str], device: D) -> models_utils.Model:
    if override_model is None:
        return Models[opt.model_name](opt).to(device)
    return Models[override_model](opt).to(device)


def load_model(opt, device, suffix: str = '', override_model: Optional[str] = None) -> models_utils.Model:
    model_path = f'{opt.cp_folder}/model{"_" + suffix if suffix else ""}'
    model = model_factory(opt, override_model, device)
    name = opt.model_name if override_model is None else override_model
    if os.path.isfile(model_path):
        print(f'loading {name} model from {model_path}')
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        print(f'init {name} model')
    return model


def save_model(model, path):
    if constants.DEBUG:
        return False
    print(f'saving model in {path}')
    torch.save(model.state_dict(), path)
    return True
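

# Illustrative sketch (not part of the original module): a save/load round-trip
# with the helpers above. The `model_backup` path and the `backup` suffix are
# hypothetical; load_model resolves the same `model_{suffix}` file layout.
def _example_save_load(opt: options.Options, model: models_utils.Model):
    if save_model(model, f'{opt.cp_folder}/model_backup'):  # returns False in DEBUG mode
        model = load_model(opt, opt.device, suffix='backup')
    return model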


def model_lc(opt: options.Options, override_model: Optional[str] = None) -> Tuple[occ_gmm.Spaghetti, options.Options]:

    def save_model(model_: models_utils.Model, suffix: str = ''):
        nonlocal already_init
        if override_model is not None and suffix == '':
            suffix = override_model
        model_path = f'{opt.cp_folder}/model{"_" + suffix if suffix else ""}'
        if constants.DEBUG or 'debug' in opt.tag:
            return False
        if not already_init:
            files_utils.init_folders(model_path)
            files_utils.save_pickle(opt, params_path)
            already_init = True
        if is_model_clean(model_):
            print(f'saving {opt.model_name} model at {model_path}')
            torch.save(model_.state_dict(), model_path)
        elif os.path.isfile(model_path):
            print('model is corrupted, restoring the last saved checkpoint')
            print(f'loading {opt.model_name} model from {model_path}')
            model_.load_state_dict(torch.load(model_path, map_location=opt.device))
        return True

    already_init = False
    params_path = f'{opt.cp_folder}/options.pkl'
    opt_ = files_utils.load_pickle(params_path)

    if opt_ is not None:
        opt_.device = opt.device
        opt = opt_
        already_init = True
    model = load_model(opt, opt.device, override_model=override_model)
    model.save_model = save_model
    return model, opt
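

# Illustrative sketch (not part of the original module): the intended
# resume-and-checkpoint flow around model_lc. The training step is elided.
def _example_model_lc_usage(opt: options.Options):
    model, opt = model_lc(opt)  # restores options.pkl and weights when they exist
    # ... run training steps on `model` ...
    model.save_model(model)  # the attached hook checkpoints only NaN-free weights
    return model, opt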


class Logger:

    def __init__(self, level: int = 0):
        self.level_dictionary = dict()
        self.iter_dictionary = dict()
        self.level = level
        self.progress: Union[N, tqdm] = None
        self.iters = 0
        self.tag = ''

    @staticmethod
    def aggregate(dictionary: dict, parent_dictionary: Union[dict, N] = None) -> dict:
        aggregate_dictionary = dict()
        for key in dictionary:
            if 'counter' not in key:
                aggregate_dictionary[key] = dictionary[key] / float(dictionary[f"{key}_counter"])
                if parent_dictionary is not None:
                    Logger.stash(parent_dictionary, (key,  aggregate_dictionary[key]))
        return aggregate_dictionary

    @staticmethod
    def flatten(items: Tuple[Union[Dict[str, LI], str, LI], ...]) -> List[Union[str, LI]]:
        flat_items = []
        for item in items:
            if type(item) is dict:
                for key, value in item.items():
                    flat_items.append(key)
                    flat_items.append(value)
            else:
                flat_items.append(item)
        return flat_items

    @staticmethod
    def stash(dictionary: Dict[str, LI], items: Tuple[Union[Dict[str, LI], str, LI], ...]) -> Dict[str, LI]:
        flat_items = Logger.flatten(items)
        for i in range(0, len(flat_items), 2):
            key, item = flat_items[i], flat_items[i + 1]
            if isinstance(item, T):  # also catches tensor subclasses such as nn.Parameter
                item = item.item()
            if key not in dictionary:
                dictionary[key] = 0
                dictionary[f"{key}_counter"] = 0
            dictionary[key] += item
            dictionary[f"{key}_counter"] += 1
        return dictionary

    def stash_iter(self, *items: Union[Dict[str, LI], str, LI]):
        self.iter_dictionary = self.stash(self.iter_dictionary, items)
        return self

    def stash_level(self, *items: Union[Dict[str, LI], str, LI]):
        self.level_dictionary = self.stash(self.level_dictionary, items)

    def reset_iter(self, *items: Union[Dict[str, LI], str, LI]):
        if len(items) > 0:
            self.stash_iter(*items)
        aggregate_dictionary = self.aggregate(self.iter_dictionary, self.level_dictionary)
        self.progress.set_postfix(aggregate_dictionary)
        self.progress.update()
        self.iter_dictionary = dict()
        return self

    def start(self, iters: int = -1, tag: str = ''):
        if self.progress is not None:
            self.stop()
        if iters < 0:
            iters = self.iters
        if tag == '':
            tag = self.tag
        self.iters, self.tag = iters, tag
        self.progress = tqdm(total=self.iters, desc=f'{self.tag} {self.level}')
        return self

    def stop(self, aggregate: bool = True):
        aggregate_dictionary = dict()
        if aggregate:
            aggregate_dictionary = self.aggregate(self.level_dictionary)
            self.progress.set_postfix(aggregate_dictionary)
        self.level_dictionary = dict()
        self.progress.close()
        self.progress = None
        self.level += 1
        return aggregate_dictionary

    def reset_level(self, aggregate: bool = True):
        self.stop(aggregate)
        self.start()
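

# Illustrative sketch (not part of the original module): driving a tqdm bar
# with Logger over one epoch; the constant loss value is a placeholder.
def _example_logger_usage(num_iters: int = 100):
    logger = Logger().start(num_iters, tag='train')
    for _ in range(num_iters):
        logger.stash_iter('loss', 0.5)  # accumulate per-iteration scalars
        logger.reset_iter()  # average them, show on the bar, then clear
    return logger.stop()  # close the bar and return the epoch averages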


class LinearWarmupScheduler:

    def __init__(self, optimizer, target_lr, num_iters):
        self.cur_iter = 0.
        self.target_lr = target_lr
        self.num_iters = num_iters
        self.finished = False
        self.optimizer = optimizer
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]
        self.delta_lrs = [target_lr - base_lr for base_lr in self.base_lrs]

    def get_lr(self):
        if self.cur_iter >= self.num_iters:
            return [self.target_lr] * len(self.base_lrs)
        alpha = self.cur_iter / self.num_iters
        return [base_lr + delta_lr * alpha for base_lr, delta_lr in zip(self.base_lrs, self.delta_lrs)]

    def step(self):
        if not self.finished:
            for group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                group['lr'] = lr
            self.cur_iter += 1.
            self.finished = self.cur_iter > self.num_iters
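

# Illustrative sketch (not part of the original module): warming up an Adam
# optimizer's learning rate; the lr values and iteration count are hypothetical.
def _example_warmup_usage(model: nn.Module, warmup_iters: int = 500):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-7)
    warmup = LinearWarmupScheduler(optimizer, target_lr=1e-4, num_iters=warmup_iters)
    for _ in range(warmup_iters):
        # ... forward / backward / optimizer.step() here ...
        warmup.step()  # linearly raises the lr; a no-op once the warmup ends
    return optimizer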