import copy

import torch
from torch import nn

__all__ = ['build_optimizer']


def param_groups_weight_decay(model: nn.Module,
                              weight_decay=1e-5,
                              no_weight_decay_list=()):
    """Split parameters into two groups: with and without weight decay.

    Biases, 1-D parameters (e.g. norm layers) and any parameter whose name
    matches an entry in ``no_weight_decay_list`` are excluded from decay.
    """
    no_weight_decay_list = set(no_weight_decay_list)
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or name.endswith('.bias') or any(
                nd in name for nd in no_weight_decay_list):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {
            'params': no_decay,
            'weight_decay': 0.0
        },
        {
            'params': decay,
            'weight_decay': weight_decay
        },
    ]


def build_optimizer(optim_config, lr_scheduler_config, epochs, step_each_epoch,
                    model):
    from . import lr
    config = copy.deepcopy(optim_config)

    if isinstance(model, nn.Module):
        # A model was passed in: extract its parameters and apply weight decay
        # only to the appropriate layers.
        weight_decay = config.get('weight_decay', 0.0)
        filter_bias_and_bn = config.pop('filter_bias_and_bn', False)
        if weight_decay > 0.0 and filter_bias_and_bn:
            no_weight_decay = {}
            if hasattr(model, 'no_weight_decay'):
                no_weight_decay = model.no_weight_decay()
            parameters = param_groups_weight_decay(model, weight_decay,
                                                   no_weight_decay)
            # Decay is already set per parameter group, so disable the global value.
            config['weight_decay'] = 0.0
        else:
            parameters = model.parameters()
    else:
        # An iterable of parameters or parameter groups was passed in.
        parameters = model

    optim = getattr(torch.optim, config.pop('name'))(params=parameters,
                                                     **config)

    lr_config = copy.deepcopy(lr_scheduler_config)
    lr_config.update({
        'epochs': epochs,
        'step_each_epoch': step_each_epoch,
        'lr': config['lr']
    })
    lr_scheduler = getattr(lr, lr_config.pop('name'))(**lr_config)(optimizer=optim)
    return optim, lr_scheduler
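

# Usage sketch (illustrative only): the optimizer name must be a class in
# ``torch.optim`` and the scheduler name must match a class defined in the
# local ``lr`` module; the 'CosineAnnealingLR' entry below is an assumption,
# not a guaranteed member of that module.
#
#     optim_config = {
#         'name': 'AdamW',
#         'lr': 1e-3,
#         'weight_decay': 0.05,
#         'filter_bias_and_bn': True,
#     }
#     lr_scheduler_config = {'name': 'CosineAnnealingLR'}
#     optimizer, scheduler = build_optimizer(optim_config, lr_scheduler_config,
#                                            epochs=100, step_each_epoch=500,
#                                            model=my_model)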