from ..custom_types import *
from .. import constants
from tqdm import tqdm
from . import files_utils
import os
from .. import options
from ..models import models_utils, occ_gmm


# scalar-like values accepted by the Logger (tensors, floats or ints)
LI = Union[T, float, int]

Models = {'spaghetti': occ_gmm.Spaghetti}


def is_model_clean(model: nn.Module) -> bool:
    # a model is "clean" if none of its parameters contain NaNs
    for wh in model.parameters():
        if torch.isnan(wh).sum() > 0:
            return False
    return True


def model_factory(opt: options.Options, override_model: Optional[str], device: D) -> models_utils.Model:
    # instantiate the model named in opt (or the explicit override) and move it to the target device
    if override_model is None:
        return Models[opt.model_name](opt).to(device)
    return Models[override_model](opt).to(device)


def load_model(opt, device, suffix: str = '', override_model: Optional[str] = None) -> models_utils.Model:
    # build the model and restore its weights from the checkpoint folder if a saved state exists
    model_path = f'{opt.cp_folder}/model{"_" + suffix if suffix else ""}'
    model = model_factory(opt, override_model, device)
    name = opt.model_name if override_model is None else override_model
    if os.path.isfile(model_path):
        print(f'loading {name} model from {model_path}')
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        print(f'init {name} model')
    return model


def save_model(model, path):
    # skip saving in debug mode
    if constants.DEBUG:
        return False
    print(f'saving model in {path}')
    torch.save(model.state_dict(), path)
    return True


def model_lc(opt: options.Options, override_model: Optional[str] = None) -> Tuple[occ_gmm.Spaghetti, options.Options]:
    # load a model together with its pickled options and attach a checkpointing hook to it

    def save_model(model_: models_utils.Model, suffix: str = ''):
        nonlocal already_init
        if override_model is not None and suffix == '':
            suffix = override_model
        model_path = f'{opt.cp_folder}/model{"_" + suffix if suffix else ""}'
        if constants.DEBUG or 'debug' in opt.tag:
            return False
        if not already_init:
            files_utils.init_folders(model_path)
            files_utils.save_pickle(opt, params_path)
            already_init = True
        if is_model_clean(model_):
            print(f'saving {opt.model_name} model at {model_path}')
            torch.save(model_.state_dict(), model_path)
        elif os.path.isfile(model_path):
            # current weights contain NaNs: keep the last good checkpoint and reload it
            print('model is corrupted')
            print(f'loading {opt.model_name} model from {model_path}')
            model.load_state_dict(torch.load(model_path, map_location=opt.device))
        return True

    already_init = False
    params_path = f'{opt.cp_folder}/options.pkl'
    opt_ = files_utils.load_pickle(params_path)
    if opt_ is not None:
        # previously saved options take precedence, except for the device
        opt_.device = opt.device
        opt = opt_
        already_init = True
    model = load_model(opt, opt.device, override_model=override_model)
    model.save_model = save_model
    return model, opt


class Logger:
    # tqdm-based training logger that accumulates running averages per iteration and per level (e.g. per epoch)

    def __init__(self, level: int = 0):
        self.level_dictionary = dict()
        self.iter_dictionary = dict()
        self.level = level
        self.progress: Union[N, tqdm] = None
        self.iters = 0
        self.tag = ''

    @staticmethod
    def aggregate(dictionary: dict, parent_dictionary: Union[dict, N] = None) -> dict:
        # average every stashed value by its counter; optionally propagate the averages to a parent dictionary
        aggregate_dictionary = dict()
        for key in dictionary:
            if 'counter' not in key:
                aggregate_dictionary[key] = dictionary[key] / float(dictionary[f"{key}_counter"])
                if parent_dictionary is not None:
                    Logger.stash(parent_dictionary, (key, aggregate_dictionary[key]))
        return aggregate_dictionary

    @staticmethod
    def flatten(items: Tuple[Union[Dict[str, LI], str, LI], ...]) -> List[Union[str, LI]]:
        # flatten a mix of dicts and alternating key / value arguments into a single [key, value, ...] list
        flat_items = []
        for item in items:
            if type(item) is dict:
                for key, value in item.items():
                    flat_items.append(key)
                    flat_items.append(value)
            else:
                flat_items.append(item)
        return flat_items

    @staticmethod
    def stash(dictionary: Dict[str, LI], items: Tuple[Union[Dict[str, LI], str, LI], ...]) -> Dict[str, LI]:
        # accumulate each (key, value) pair into the dictionary together with an update counter
        flat_items = Logger.flatten(items)
        for i in range(0, len(flat_items), 2):
            key, item = flat_items[i], flat_items[i + 1]
            if type(item) is T:
                item = item.item()
            if key not in dictionary:
                dictionary[key] = 0
                dictionary[f"{key}_counter"] = 0
            dictionary[key] += item
            dictionary[f"{key}_counter"] += 1
        return dictionary

    def stash_iter(self, *items: Union[Dict[str, LI], str, LI]):
        self.iter_dictionary = self.stash(self.iter_dictionary, items)
        return self

    def stash_level(self, *items: Union[Dict[str, LI], str, LI]):
        self.level_dictionary = self.stash(self.level_dictionary, items)

    def reset_iter(self, *items: Union[Dict[str, LI], str, LI]):
        # aggregate the current iteration, push it to the level dictionary and advance the progress bar
        if len(items) > 0:
            self.stash_iter(*items)
        aggregate_dictionary = self.aggregate(self.iter_dictionary, self.level_dictionary)
        self.progress.set_postfix(aggregate_dictionary)
        self.progress.update()
        self.iter_dictionary = dict()
        return self

    def start(self, iters: int = -1, tag: str = ''):
        # open a new progress bar; negative iters / empty tag reuse the previous values
        if self.progress is not None:
            self.stop()
        if iters < 0:
            iters = self.iters
        if tag == '':
            tag = self.tag
        self.iters, self.tag = iters, tag
        self.progress = tqdm(total=self.iters, desc=f'{self.tag} {self.level}')
        return self

    def stop(self, aggregate: bool = True):
        # close the progress bar, optionally reporting the level averages, and move to the next level
        aggregate_dictionary = dict()
        if aggregate:
            aggregate_dictionary = self.aggregate(self.level_dictionary)
            self.progress.set_postfix(aggregate_dictionary)
        self.level_dictionary = dict()
        self.progress.close()
        self.progress = None
        self.level += 1
        return aggregate_dictionary

    def reset_level(self, aggregate: bool = True):
        self.stop(aggregate)
        self.start()


class LinearWarmupScheduler:
    # linearly ramps the learning rate of every parameter group from its base lr to target_lr over num_iters steps

    def __init__(self, optimizer, target_lr, num_iters):
        self.cur_iter = 0.
        self.target_lr = target_lr
        self.num_iters = num_iters
        self.finished = False
        self.optimizer = optimizer
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]
        self.delta_lrs = [target_lr - base_lr for base_lr in self.base_lrs]

    def get_lr(self):
        if self.cur_iter >= self.num_iters:
            return [self.target_lr] * len(self.base_lrs)
        alpha = self.cur_iter / self.num_iters
        return [base_lr + delta_lr * alpha for base_lr, delta_lr in zip(self.base_lrs, self.delta_lrs)]

    def step(self):
        # call once per optimization step; after num_iters steps the lr stays at target_lr
        if not self.finished:
            for group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                group['lr'] = lr
            self.cur_iter += 1.
            self.finished = self.cur_iter > self.num_iters
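
# --------------------------------------------------------------------------- #
# Usage sketch (illustration only, not part of the module's API): a minimal,
# assumed example of how Logger and LinearWarmupScheduler could cooperate in a
# training loop. The toy model, learning rates and iteration counts below are
# hypothetical placeholders.
#
#     demo_model = nn.Linear(8, 1)
#     optimizer = torch.optim.Adam(demo_model.parameters(), lr=1e-5)
#     warmup = LinearWarmupScheduler(optimizer, target_lr=1e-3, num_iters=100)
#     logger = Logger().start(100, tag='demo')
#     for _ in range(100):
#         loss = demo_model(torch.randn(4, 8)).pow(2).mean()
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         warmup.step()
#         logger.reset_iter('loss', loss)
#     logger.stop()
# --------------------------------------------------------------------------- #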