import copy
import glob
import inspect
import logging
import os
from typing import Optional

import torch

from mono.utils.comm import main_process

class LrUpdater():
    """Refer to LR Scheduler in MMCV.
    Args:
        @by_epoch (bool): LR changes epoch by epoch.
        @warmup (string): Type of warmup used. It can be None (no warmup),
            'constant', 'linear' or 'exp'.
        @warmup_iters (int): The number of iterations or epochs that warmup
            lasts. Note that when by_epoch == True, warmup_iters means the
            number of epochs that warmup lasts, otherwise the number of
            iterations that warmup lasts.
        @warmup_ratio (float): LR used at the beginning of warmup equals
            warmup_ratio * initial_lr.
        @runner (dict): Configs for running. Run by epochs or iters.
    """
    def __init__(self,
                 by_epoch: bool = True,
                 warmup: Optional[str] = None,
                 warmup_iters: int = 0,
                 warmup_ratio: float = 0.1,
                 runner: Optional[dict] = None):
        # validate the "warmup" argument
        if warmup is not None:
            if warmup not in ['constant', 'linear', 'exp']:
                raise ValueError(
                    f'"{warmup}" is not a supported type for warming up, valid'
                    ' types are "constant", "linear" and "exp"')
            assert warmup_iters > 0, \
                '"warmup_iters" must be a positive integer'
            assert 0 < warmup_ratio <= 1.0, \
                '"warmup_ratio" must be in range (0,1]'
        if runner is None:
            raise RuntimeError('runner should be set.')

        self.by_epoch = by_epoch
        self.warmup = warmup
        self.warmup_iters = warmup_iters
        self.warmup_ratio = warmup_ratio
        self.runner = runner

        self.max_iters = None
        self.max_epoches = None
        if 'IterBasedRunner' in self.runner.type:
            self.max_iters = self.runner.max_iters
            assert not self.by_epoch
            self.warmup_by_epoch = False
        elif 'EpochBasedRunner' in self.runner.type:
            self.max_epoches = self.runner.max_epoches
            assert self.by_epoch
            self.warmup_by_epoch = True
        else:
            raise ValueError(f'{self.runner.type} is not a supported type for running.')

        if self.warmup_by_epoch:
            self.warmup_epochs = self.warmup_iters
            self.warmup_iters = None
        else:
            self.warmup_epochs = None

        self.base_lr = []  # initial lr for all param groups
        self.regular_lr = []  # expected lr if no warming up is performed
        self._step_count = 0
    def _set_lr(self, optimizer, lr_groups):
        if isinstance(optimizer, dict):
            for k, optim in optimizer.items():
                for param_group, lr in zip(optim.param_groups, lr_groups[k]):
                    param_group['lr'] = lr
        else:
            for param_group, lr in zip(optimizer.param_groups, lr_groups):
                param_group['lr'] = lr
    def get_lr(self, _iter, max_iter, base_lr):
        raise NotImplementedError
    def get_regular_lr(self, _iter, optimizer):
        max_iters = self.max_iters if not self.by_epoch else self.max_epoches
        if isinstance(optimizer, dict):
            lr_groups = {}
            for k in optimizer.keys():
                _lr_group = [
                    self.get_lr(_iter, max_iters, _base_lr)
                    for _base_lr in self.base_lr[k]
                ]
                lr_groups.update({k: _lr_group})
            return lr_groups
        else:
            return [self.get_lr(_iter, max_iters, _base_lr) for _base_lr in self.base_lr]
    def get_warmup_lr(self, cur_iters):

        def _get_warmup_lr(cur_iters, regular_lr):
            # In epoch-based mode the warmup length lives in warmup_epochs
            # (warmup_iters is set to None in __init__).
            warmup_steps = (self.warmup_epochs if self.warmup_by_epoch
                            else self.warmup_iters)
            if self.warmup == 'constant':
                warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr]
            elif self.warmup == 'linear':
                k = (1 - cur_iters / warmup_steps) * (1 - self.warmup_ratio)
                warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
            elif self.warmup == 'exp':
                k = self.warmup_ratio**(1 - cur_iters / warmup_steps)
                warmup_lr = [_lr * k for _lr in regular_lr]
            return warmup_lr

        if isinstance(self.regular_lr, dict):
            lr_groups = {}
            for key, regular_lr in self.regular_lr.items():
                lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr)
            return lr_groups
        else:
            return _get_warmup_lr(cur_iters, self.regular_lr)
    def before_run(self, optimizer):
        # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
        # it will be set according to the optimizer params
        if isinstance(optimizer, dict):
            self.base_lr = {}
            for k, optim in optimizer.items():
                for group in optim.param_groups:
                    group.setdefault('initial_lr', group['lr'])
                _base_lr = [
                    group['initial_lr'] for group in optim.param_groups
                ]
                self.base_lr.update({k: _base_lr})
        else:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
            self.base_lr = [
                group['initial_lr'] for group in optimizer.param_groups
            ]
    def after_train_epoch(self, optimizer):
        self._step_count += 1
        curr_epoch = self._step_count
        self.regular_lr = self.get_regular_lr(curr_epoch, optimizer)
        if self.warmup is None or curr_epoch > self.warmup_epochs:
            self._set_lr(optimizer, self.regular_lr)
        else:
            warmup_lr = self.get_warmup_lr(curr_epoch)
            self._set_lr(optimizer, warmup_lr)
    def after_train_iter(self, optimizer):
        self._step_count += 1
        cur_iter = self._step_count
        self.regular_lr = self.get_regular_lr(cur_iter, optimizer)
        if self.warmup is None or cur_iter >= self.warmup_iters:
            self._set_lr(optimizer, self.regular_lr)
        else:
            warmup_lr = self.get_warmup_lr(cur_iter)
            self._set_lr(optimizer, warmup_lr)
    def get_curr_lr(self, cur_iter):
        # Compare against the warmup length of the active mode; in epoch-based
        # mode warmup_iters is None and warmup_epochs holds the length.
        warmup_steps = (self.warmup_epochs if self.warmup_by_epoch
                        else self.warmup_iters)
        if self.warmup is None or cur_iter >= warmup_steps:
            return self.regular_lr
        else:
            return self.get_warmup_lr(cur_iter)
    def state_dict(self):
        """
        Returns the state of the scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.
        Args:
            @state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)
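
# Expected call pattern (a sketch; the training loop that drives these hooks
# lives elsewhere in the framework):
#   schedule.before_run(optimizer)            # record initial lrs
#   for each training iteration:
#       ...train step...
#       schedule.after_train_iter(optimizer)  # iter-based runners
#   # epoch-based runners call after_train_epoch(optimizer) once per epoch.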
class PolyLrUpdater(LrUpdater):

    def __init__(self, power=1., min_lr=0., **kwargs):
        self.power = power
        self.min_lr = min_lr
        super().__init__(**kwargs)

    def get_lr(self, _iter, max_iters, base_lr):
        progress = _iter
        max_progress = max_iters
        coeff = (1 - progress / max_progress)**self.power
        return (base_lr - self.min_lr) * coeff + self.min_lr
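
# Worked example of the poly schedule (illustration only): with base_lr=0.01,
# min_lr=0 and power=0.9, at iteration 500 of 1000 the coefficient is
# (1 - 500/1000) ** 0.9 ≈ 0.536, giving lr ≈ 0.00536.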
def build_lr_schedule_with_cfg(cfg):
    # build the learning rate schedule from the config
    lr_config = copy.deepcopy(cfg.lr_config)
    policy = lr_config.pop('policy')
    if policy == 'poly':
        schedule = PolyLrUpdater(runner=cfg.runner, **lr_config)
    else:
        raise RuntimeError(f'{policy} is not supported in this framework.')
    return schedule
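
# Usage sketch (assumes an mmcv-style attribute-dict config; the field names
# mirror what this function reads, the values are illustrative):
#   cfg.lr_config = dict(policy='poly', power=0.9, min_lr=1e-6,
#                        by_epoch=False, warmup='linear',
#                        warmup_iters=500, warmup_ratio=1e-3)
#   schedule = build_lr_schedule_with_cfg(cfg)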
def register_torch_optimizers():
    torch_optimizers = {}
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer):
            torch_optimizers[module_name] = _optim
    return torch_optimizers


TORCH_OPTIMIZER = register_torch_optimizers()
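
# TORCH_OPTIMIZER maps class names to optimizer classes, e.g.
# TORCH_OPTIMIZER['SGD'] is torch.optim.SGD and TORCH_OPTIMIZER['AdamW'] is
# torch.optim.AdamW; the exact set depends on the installed torch version.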
def build_optimizer_with_cfg(cfg, model):
    """
    Build a torch optimizer with per-module parameter groups. Every key in
    cfg.optimizer other than 'type' and 'strict_match' names a group of
    parameters and carries its optimizer kwargs; parameters matching no key
    fall into the 'others' group.
    """
    optim_cfg = copy.deepcopy(cfg.optimizer)
    optim_type = optim_cfg.pop('type', None)
    if optim_type is None:
        raise RuntimeError('optimizer type is not set')
    if optim_type not in TORCH_OPTIMIZER:
        raise RuntimeError(f'{optim_type} is not supported in torch {torch.__version__}')
    # unmatched parameters fall back to the decoder settings
    if 'others' not in optim_cfg:
        optim_cfg['others'] = optim_cfg['decoder']

    def match(param_name, key_list, strict_match=False):
        # Loose matching: a group name may appear anywhere in the parameter
        # name. Strict matching: the group name must equal the first component
        # after the 'module.' prefix added by (Distributed)DataParallel.
        if not strict_match:
            for k in key_list:
                if k in param_name:
                    return k
        else:
            for k in key_list:
                if k == param_name.split('.')[1]:
                    return k
        return None

    optim_obj = TORCH_OPTIMIZER[optim_type]
    matching_type = optim_cfg.pop('strict_match', False)

    module_names = optim_cfg.keys()
    model_parameters = {i: [] for i in module_names}
    nongrad_parameters = []
    for key, value in dict(model.named_parameters()).items():
        if value.requires_grad:
            match_key = match(key, module_names, matching_type)
            if match_key is None:
                model_parameters['others'].append(value)
            else:
                model_parameters[match_key].append(value)
        else:
            nongrad_parameters.append(value)

    optims = [{'params': model_parameters[k], **optim_cfg[k]} for k in optim_cfg.keys()]
    optimizer = optim_obj(optims)
    return optimizer
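
# Usage sketch (assumes an attribute-dict config; 'encoder'/'decoder' are
# illustrative group names matched against parameter names, with 'others'
# falling back to the decoder settings):
#   cfg.optimizer = dict(type='AdamW',
#                        encoder=dict(lr=1e-5, weight_decay=0.01),
#                        decoder=dict(lr=1e-4, weight_decay=0.01))
#   optimizer = build_optimizer_with_cfg(cfg, model)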
def load_ckpt(load_path, model, optimizer=None, scheduler=None, strict_match=True, loss_scaler=None):
    """
    Load a checkpoint for resuming training or finetuning.
    """
    logger = logging.getLogger()
    if os.path.isfile(load_path):
        if main_process():
            logger.info(f"Loading weight '{load_path}'")
        checkpoint = torch.load(load_path, map_location="cpu")
        ckpt_state_dict = checkpoint['model_state_dict']
        model.module.load_state_dict(ckpt_state_dict, strict=strict_match)

        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler'])
        if loss_scaler is not None and 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
            logger.info('Loss scaler loaded')

        del ckpt_state_dict
        del checkpoint
        if main_process():
            logger.info(f"Successfully loaded weight: '{load_path}'")
            if scheduler is not None and optimizer is not None:
                logger.info(f"Resume training from: '{load_path}'")
    else:
        if main_process():
            raise RuntimeError(f"No weight found at '{load_path}'")
    return model, optimizer, scheduler, loss_scaler
def save_ckpt(cfg, model, optimizer, scheduler, curr_iter=0, curr_epoch=None, loss_scaler=None):
    """
    Save the model, optimizer, and lr scheduler.
    """
    logger = logging.getLogger()

    if 'IterBasedRunner' in cfg.runner.type:
        max_iters = cfg.runner.max_iters
    elif 'EpochBasedRunner' in cfg.runner.type:
        max_iters = cfg.runner.max_epoches
    else:
        raise TypeError(f'{cfg.runner.type} is not supported')

    ckpt = dict(
        model_state_dict=model.module.state_dict(),
        optimizer=optimizer.state_dict(),
        max_iter=max_iters,
        scheduler=scheduler.state_dict(),
    )
    if loss_scaler is not None:
        # amp state_dict
        ckpt.update(dict(scaler=loss_scaler.state_dict()))

    ckpt_dir = os.path.join(cfg.work_dir, 'ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)

    save_name = os.path.join(ckpt_dir, 'step%08d.pth' % curr_iter)
    saved_ckpts = glob.glob(ckpt_dir + '/step*.pth')
    torch.save(ckpt, save_name)

    # keep only the most recent 8 ckpts; drop the oldest one per save
    if len(saved_ckpts) >= 8:
        saved_ckpts.sort()
        os.remove(saved_ckpts.pop(0))

    logger.info(f'Save model: {save_name}')
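
# Note: the keys written here ('model_state_dict', 'optimizer', 'scheduler',
# 'scaler') are exactly what load_ckpt() reads back, so save_ckpt/load_ckpt
# round-trip for resuming.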
if __name__ == '__main__':
    print(TORCH_OPTIMIZER)
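
    # Minimal smoke test (a sketch, not part of the framework): drive a
    # PolyLrUpdater with a dummy iter-based runner config and a one-parameter
    # SGD optimizer. SimpleNamespace stands in for the attribute-style config
    # object this module expects.
    from types import SimpleNamespace
    runner_cfg = SimpleNamespace(type='IterBasedRunner', max_iters=100)
    param = torch.nn.Parameter(torch.zeros(1))
    optim = torch.optim.SGD([param], lr=0.1)
    schedule = PolyLrUpdater(power=0.9, by_epoch=False, warmup='linear',
                             warmup_iters=10, warmup_ratio=0.1, runner=runner_cfg)
    schedule.before_run(optim)
    for step in range(20):
        schedule.after_train_iter(optim)
        print(step, [g['lr'] for g in optim.param_groups])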