import copy
import torch
from torch import nn
__all__ = ['build_optimizer']
def param_groups_weight_decay(model: nn.Module,
                              weight_decay=1e-5,
                              no_weight_decay_list=()):
    """Split trainable parameters into two groups: one with weight decay and
    one without (biases, 1-D parameters such as norm weights, and any
    parameter whose name matches an entry in ``no_weight_decay_list``)."""
    no_weight_decay_list = set(no_weight_decay_list)
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or name.endswith('.bias') or any(
                nd in name for nd in no_weight_decay_list):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {
            'params': no_decay,
            'weight_decay': 0.0
        },
        {
            'params': decay,
            'weight_decay': weight_decay
        },
    ]
def build_optimizer(optim_config, lr_scheduler_config, epochs, step_each_epoch,
                    model):
    """Build a torch optimizer and a learning-rate scheduler from config dicts."""
    from . import lr  # imported here to avoid a circular import at module load
    config = copy.deepcopy(optim_config)
    if isinstance(model, nn.Module):
        # a model was passed in: extract its parameters and apply weight decay
        # only to the appropriate layers
        weight_decay = config.get('weight_decay', 0.0)
        filter_bias_and_bn = config.pop('filter_bias_and_bn', False)
        if weight_decay > 0.0 and filter_bias_and_bn:
            no_weight_decay = {}
            if hasattr(model, 'no_weight_decay'):
                no_weight_decay = model.no_weight_decay()
            parameters = param_groups_weight_decay(model, weight_decay,
                                                   no_weight_decay)
            # weight decay is now set per parameter group, so disable the
            # optimizer-level default
            config['weight_decay'] = 0.0
        else:
            parameters = model.parameters()
    else:
        # an iterable of parameters or parameter groups was passed in
        parameters = model
    optim = getattr(torch.optim, config.pop('name'))(params=parameters,
                                                     **config)
    lr_config = copy.deepcopy(lr_scheduler_config)
    lr_config.update({
        'epochs': epochs,
        'step_each_epoch': step_each_epoch,
        'lr': config['lr']
    })
    lr_scheduler = getattr(lr,
                           lr_config.pop('name'))(**lr_config)(optimizer=optim)
    return optim, lr_scheduler
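# Usage sketch (illustrative only, not shipped config): ``name`` in
# ``optim_config`` must be a class in ``torch.optim``; ``name`` in
# ``lr_scheduler_config`` must be a class defined in the local ``lr`` module
# (the scheduler name below is an assumed placeholder).
#
#   optim_config = {
#       'name': 'AdamW',             # any torch.optim optimizer class name
#       'lr': 1e-3,
#       'weight_decay': 0.05,
#       'filter_bias_and_bn': True,  # route weight decay per parameter group
#   }
#   lr_scheduler_config = {'name': 'CosineLR'}  # assumed to exist in ./lr.py
#   optimizer, lr_scheduler = build_optimizer(
#       optim_config, lr_scheduler_config,
#       epochs=20, step_each_epoch=1000, model=model)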