import os

from tqdm import tqdm

from ..custom_types import *
from .. import constants, options
from ..models import models_utils, occ_gmm
from . import files_utils

LI = Union[T, float, int]
Models = {'spaghetti': occ_gmm.Spaghetti}

def is_model_clean(model: nn.Module) -> bool:
    # a model is "clean" if none of its parameters contain NaNs
    for wh in model.parameters():
        if torch.isnan(wh).sum() > 0:
            return False
    return True

def model_factory(opt: options.Options, override_model: Optional[str], device: D) -> models_utils.Model:
    if override_model is None:
        return Models[opt.model_name](opt).to(device)
    return Models[override_model](opt).to(device)

def load_model(opt, device, suffix: str = '', override_model: Optional[str] = None) -> models_utils.Model:
    model_path = f'{opt.cp_folder}/model{"_" + suffix if suffix else ""}'
    model = model_factory(opt, override_model, device)
    name = opt.model_name if override_model is None else override_model
    if os.path.isfile(model_path):
        print(f'loading {name} model from {model_path}')
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        print(f'init {name} model')
    return model

def save_model(model, path):
    # skipped in debug runs so they never overwrite real checkpoints
    if constants.DEBUG:
        return False
    print(f'saving model in {path}')
    torch.save(model.state_dict(), path)
    return True
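
# Minimal round-trip sketch (illustrative only; assumes an options.Options
# instance whose cp_folder points at a checkpoint directory):
#
#     opt = options.Options()
#     model = load_model(opt, opt.device)          # restores weights if the file exists
#     save_model(model, f'{opt.cp_folder}/model')  # no-op while constants.DEBUG is set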

def model_lc(opt: options.Options, override_model: Optional[str] = None) -> Tuple[occ_gmm.Spaghetti, options.Options]:

    def save_model(model_: models_utils.Model, suffix: str = ''):
        nonlocal already_init
        if override_model is not None and suffix == '':
            suffix = override_model
        model_path = f'{opt.cp_folder}/model{"_" + suffix if suffix else ""}'
        if constants.DEBUG or 'debug' in opt.tag:
            return False
        if not already_init:
            files_utils.init_folders(model_path)
            files_utils.save_pickle(opt, params_path)
            already_init = True
        if is_model_clean(model_):
            print(f'saving {opt.model_name} model at {model_path}')
            torch.save(model_.state_dict(), model_path)
        elif os.path.isfile(model_path):
            # the in-memory weights contain NaNs; fall back to the last clean checkpoint
            print('model is corrupted')
            print(f'loading {opt.model_name} model from {model_path}')
            model.load_state_dict(torch.load(model_path, map_location=opt.device))
        return True

    already_init = False
    params_path = f'{opt.cp_folder}/options.pkl'
    opt_ = files_utils.load_pickle(params_path)
    if opt_ is not None:
        opt_.device = opt.device
        opt = opt_
        already_init = True
    model = load_model(opt, opt.device, override_model=override_model)
    model.save_model = save_model
    return model, opt
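
# Sketch of the intended checkpoint flow (`opt` is an assumed options.Options
# instance; everything else follows the code above): model_lc restores
# options.pkl when present and attaches a `save_model` closure to the model.
#
#     model, opt = model_lc(opt)
#     # ... training ...
#     model.save_model(model)  # writes the checkpoint unless DEBUG / 'debug' tag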

class Logger:

    def __init__(self, level: int = 0):
        self.level_dictionary = dict()
        self.iter_dictionary = dict()
        self.level = level
        self.progress: Union[N, tqdm] = None
        self.iters = 0
        self.tag = ''

    @staticmethod
    def aggregate(dictionary: dict, parent_dictionary: Union[dict, N] = None) -> dict:
        # average every stashed value by its counter; optionally push the
        # averages one level up
        aggregate_dictionary = dict()
        for key in dictionary:
            if 'counter' not in key:
                aggregate_dictionary[key] = dictionary[key] / float(dictionary[f"{key}_counter"])
                if parent_dictionary is not None:
                    Logger.stash(parent_dictionary, (key, aggregate_dictionary[key]))
        return aggregate_dictionary

    @staticmethod
    def flatten(items: Tuple[Union[Dict[str, LI], str, LI], ...]) -> List[Union[str, LI]]:
        # turn a mix of dicts and key/value arguments into a flat
        # [key, value, key, value, ...] list
        flat_items = []
        for item in items:
            if type(item) is dict:
                for key, value in item.items():
                    flat_items.append(key)
                    flat_items.append(value)
            else:
                flat_items.append(item)
        return flat_items

    @staticmethod
    def stash(dictionary: Dict[str, LI], items: Tuple[Union[Dict[str, LI], str, LI], ...]) -> Dict[str, LI]:
        flat_items = Logger.flatten(items)
        for i in range(0, len(flat_items), 2):
            key, item = flat_items[i], flat_items[i + 1]
            if type(item) is T:
                item = item.item()  # unwrap scalar tensors
            if key not in dictionary:
                dictionary[key] = 0
                dictionary[f"{key}_counter"] = 0
            dictionary[key] += item
            dictionary[f"{key}_counter"] += 1
        return dictionary

    def stash_iter(self, *items: Union[Dict[str, LI], str, LI]):
        self.iter_dictionary = self.stash(self.iter_dictionary, items)
        return self

    def stash_level(self, *items: Union[Dict[str, LI], str, LI]):
        self.level_dictionary = self.stash(self.level_dictionary, items)

    def reset_iter(self, *items: Union[Dict[str, LI], str, LI]):
        if len(items) > 0:
            self.stash_iter(*items)
        aggregate_dictionary = self.aggregate(self.iter_dictionary, self.level_dictionary)
        self.progress.set_postfix(aggregate_dictionary)
        self.progress.update()
        self.iter_dictionary = dict()
        return self
    def start(self, iters: int = -1, tag: str = ''):
        # iters defaults to -1 so reset_level can restart with the previous
        # iteration count and tag
        if self.progress is not None:
            self.stop()
        if iters < 0:
            iters = self.iters
        if tag == '':
            tag = self.tag
        self.iters, self.tag = iters, tag
        self.progress = tqdm(total=self.iters, desc=f'{self.tag} {self.level}')
        return self

    def stop(self, aggregate: bool = True):
        aggregate_dictionary = dict()  # ensure a defined return value when aggregate is False
        if aggregate:
            aggregate_dictionary = self.aggregate(self.level_dictionary)
            self.progress.set_postfix(aggregate_dictionary)
        self.level_dictionary = dict()
        self.progress.close()
        self.progress = None
        self.level += 1
        return aggregate_dictionary

    def reset_level(self, aggregate: bool = True):
        self.stop(aggregate)
        self.start()
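
# Toy driving loop for Logger (a sketch; the loss values are synthetic,
# everything else uses the class exactly as defined above):
#
#     logger = Logger().start(100, tag='train')
#     for i in range(100):
#         logger.reset_iter('loss', 1. / (i + 1))  # stash, then advance the bar
#     logger.stop()  # returns the averaged 'loss' over the level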

class LinearWarmupScheduler:

    def __init__(self, optimizer, target_lr, num_iters):
        self.cur_iter = 0.
        self.target_lr = target_lr
        self.num_iters = num_iters
        self.finished = False
        self.optimizer = optimizer
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]
        self.delta_lrs = [target_lr - base_lr for base_lr in self.base_lrs]

    def get_lr(self):
        if self.cur_iter >= self.num_iters:
            return [self.target_lr] * len(self.base_lrs)
        # linear interpolation between each group's base lr and the target lr
        alpha = self.cur_iter / self.num_iters
        return [base_lr + delta_lr * alpha for base_lr, delta_lr in zip(self.base_lrs, self.delta_lrs)]

    def step(self):
        if not self.finished:
            for group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                group['lr'] = lr
            self.cur_iter += 1.
            self.finished = self.cur_iter > self.num_iters
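
# Warm-up sketch (`model` and the step count are placeholders, not part of
# this module): ramps Adam linearly from its initial lr to 1e-3 over 1000
# iterations, then leaves it at the target.
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
#     warmup = LinearWarmupScheduler(optimizer, target_lr=1e-3, num_iters=1000)
#     for _ in range(num_steps):
#         optimizer.step()
#         warmup.step()  # safe to call after the warm-up has finished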