pengdadaaa's picture
Upload 741 files
786f6a6 verified
raw
history blame
15.9 kB
""" Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
from itertools import islice
from typing import Optional, Callable, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
from timm.models import group_parameters
from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adahessian import Adahessian
from .adamp import AdamP
from .adan import Adan
from .lamb import Lamb
from .lars import Lars
from .lion import Lion
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .nadamw import NAdamW
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from .sgdw import SGDW
_logger = logging.getLogger(__name__)
# optimizers to default to multi-tensor
_DEFAULT_FOREACH = {
'lion',
}
def param_groups_weight_decay(
model: nn.Module,
weight_decay=1e-5,
no_weight_decay_list=()
):
no_weight_decay_list = set(no_weight_decay_list)
decay = []
no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
no_decay.append(param)
else:
decay.append(param)
return [
{'params': no_decay, 'weight_decay': 0.},
{'params': decay, 'weight_decay': weight_decay}]
def _group(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def _layer_map(model, layers_per_group=12, num_groups=None):
def _in_head(n, hp):
if not hp:
return True
elif isinstance(hp, (tuple, list)):
return any([n.startswith(hpi) for hpi in hp])
else:
return n.startswith(hp)
head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)
names_trunk = []
names_head = []
for n, _ in model.named_parameters():
names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)
# group non-head layers
num_trunk_layers = len(names_trunk)
if num_groups is not None:
layers_per_group = -(num_trunk_layers // -num_groups)
names_trunk = list(_group(names_trunk, layers_per_group))
num_trunk_groups = len(names_trunk)
layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
layer_map.update({n: num_trunk_groups for n in names_head})
return layer_map
def param_groups_layer_decay(
model: nn.Module,
weight_decay: float = 0.05,
no_weight_decay_list: Tuple[str] = (),
layer_decay: float = .75,
end_layer_decay: Optional[float] = None,
verbose: bool = False,
):
"""
Parameter groups for layer-wise lr decay & weight decay
Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
"""
no_weight_decay_list = set(no_weight_decay_list)
param_group_names = {} # NOTE for debugging
param_groups = {}
if hasattr(model, 'group_matcher'):
# FIXME interface needs more work
layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
else:
# fallback
layer_map = _layer_map(model)
num_layers = max(layer_map.values()) + 1
layer_max = num_layers - 1
layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
for name, param in model.named_parameters():
if not param.requires_grad:
continue
# no decay: all 1D parameters and model specific ones
if param.ndim == 1 or name in no_weight_decay_list:
g_decay = "no_decay"
this_decay = 0.
else:
g_decay = "decay"
this_decay = weight_decay
layer_id = layer_map.get(name, layer_max)
group_name = "layer_%d_%s" % (layer_id, g_decay)
if group_name not in param_groups:
this_scale = layer_scales[layer_id]
param_group_names[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"param_names": [],
}
param_groups[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"params": [],
}
param_group_names[group_name]["param_names"].append(name)
param_groups[group_name]["params"].append(param)
if verbose:
import json
_logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
return list(param_groups.values())
def optimizer_kwargs(cfg):
""" cfg/argparse to kwargs helper
Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
"""
kwargs = dict(
opt=cfg.opt,
lr=cfg.lr,
weight_decay=cfg.weight_decay,
momentum=cfg.momentum,
)
if getattr(cfg, 'opt_eps', None) is not None:
kwargs['eps'] = cfg.opt_eps
if getattr(cfg, 'opt_betas', None) is not None:
kwargs['betas'] = cfg.opt_betas
if getattr(cfg, 'layer_decay', None) is not None:
kwargs['layer_decay'] = cfg.layer_decay
if getattr(cfg, 'opt_args', None) is not None:
kwargs.update(cfg.opt_args)
if getattr(cfg, 'opt_foreach', None) is not None:
kwargs['foreach'] = cfg.opt_foreach
return kwargs
def create_optimizer(args, model, filter_bias_and_bn=True):
""" Legacy optimizer factory for backwards compatibility.
NOTE: Use create_optimizer_v2 for new code.
"""
return create_optimizer_v2(
model,
**optimizer_kwargs(cfg=args),
filter_bias_and_bn=filter_bias_and_bn,
)
def create_optimizer_v2(
model_or_params,
opt: str = 'sgd',
lr: Optional[float] = None,
weight_decay: float = 0.,
momentum: float = 0.9,
foreach: Optional[bool] = None,
filter_bias_and_bn: bool = True,
layer_decay: Optional[float] = None,
param_group_fn: Optional[Callable] = None,
**kwargs,
):
""" Create an optimizer.
TODO currently the model is passed in and all parameters are selected for optimization.
For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
* a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
* expose the parameters interface and leave it up to caller
Args:
model_or_params (nn.Module): model containing parameters to optimize
opt: name of optimizer to create
lr: initial learning rate
weight_decay: weight decay to apply in optimizer
momentum: momentum for momentum based optimizers (others may use betas via kwargs)
foreach: Enable / disable foreach (multi-tensor) operation if True / False. Choose safe default if None
filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
**kwargs: extra optimizer specific kwargs to pass through
Returns:
Optimizer
"""
if isinstance(model_or_params, nn.Module):
# a model was passed in, extract parameters and add weight decays to appropriate layers
no_weight_decay = {}
if hasattr(model_or_params, 'no_weight_decay'):
no_weight_decay = model_or_params.no_weight_decay()
if param_group_fn:
parameters = param_group_fn(model_or_params)
elif layer_decay is not None:
parameters = param_groups_layer_decay(
model_or_params,
weight_decay=weight_decay,
layer_decay=layer_decay,
no_weight_decay_list=no_weight_decay,
)
weight_decay = 0.
elif weight_decay and filter_bias_and_bn:
parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
weight_decay = 0.
else:
parameters = model_or_params.parameters()
else:
# iterable of parameters or param groups passed in
parameters = model_or_params
opt_lower = opt.lower()
opt_split = opt_lower.split('_')
opt_lower = opt_split[-1]
if opt_lower.startswith('fused'):
try:
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
has_apex = True
except ImportError:
has_apex = False
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
if opt_lower.startswith('bnb'):
try:
import bitsandbytes as bnb
has_bnb = True
except ImportError:
has_bnb = False
assert has_bnb and torch.cuda.is_available(), 'bitsandbytes and CUDA required for bnb optimizers'
opt_args = dict(weight_decay=weight_decay, **kwargs)
if lr is not None:
opt_args.setdefault('lr', lr)
if foreach is None:
if opt in _DEFAULT_FOREACH:
opt_args.setdefault('foreach', True)
else:
opt_args['foreach'] = foreach
# basic SGD & related
if opt_lower == 'sgd' or opt_lower == 'nesterov':
# NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
opt_args.pop('eps', None)
optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'momentum':
opt_args.pop('eps', None)
optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
elif opt_lower == 'sgdp':
optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'sgdw' or opt_lower == 'nesterovw':
# NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
opt_args.pop('eps', None)
optimizer = SGDW(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'momentumw':
opt_args.pop('eps', None)
optimizer = SGDW(parameters, momentum=momentum, nesterov=False, **opt_args)
# adaptive
elif opt_lower == 'adam':
optimizer = optim.Adam(parameters, **opt_args)
elif opt_lower == 'adamw':
optimizer = optim.AdamW(parameters, **opt_args)
elif opt_lower == 'adamp':
optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
elif opt_lower == 'nadam':
try:
# NOTE PyTorch >= 1.10 should have native NAdam
optimizer = optim.Nadam(parameters, **opt_args)
except AttributeError:
optimizer = Nadam(parameters, **opt_args)
elif opt_lower == 'nadamw':
optimizer = NAdamW(parameters, **opt_args)
elif opt_lower == 'radam':
optimizer = RAdam(parameters, **opt_args)
elif opt_lower == 'adamax':
optimizer = optim.Adamax(parameters, **opt_args)
elif opt_lower == 'adabelief':
optimizer = AdaBelief(parameters, rectify=False, **opt_args)
elif opt_lower == 'radabelief':
optimizer = AdaBelief(parameters, rectify=True, **opt_args)
elif opt_lower == 'adadelta':
optimizer = optim.Adadelta(parameters, **opt_args)
elif opt_lower == 'adagrad':
opt_args.setdefault('eps', 1e-8)
optimizer = optim.Adagrad(parameters, **opt_args)
elif opt_lower == 'adafactor':
optimizer = Adafactor(parameters, **opt_args)
elif opt_lower == 'adanp':
optimizer = Adan(parameters, no_prox=False, **opt_args)
elif opt_lower == 'adanw':
optimizer = Adan(parameters, no_prox=True, **opt_args)
elif opt_lower == 'lamb':
optimizer = Lamb(parameters, **opt_args)
elif opt_lower == 'lambc':
optimizer = Lamb(parameters, trust_clip=True, **opt_args)
elif opt_lower == 'larc':
optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
elif opt_lower == 'lars':
optimizer = Lars(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'nlarc':
optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
elif opt_lower == 'nlars':
optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'madgrad':
optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'madgradw':
optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
optimizer = NvNovoGrad(parameters, **opt_args)
elif opt_lower == 'rmsprop':
optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
elif opt_lower == 'rmsproptf':
optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
elif opt_lower == 'lion':
opt_args.pop('eps', None)
optimizer = Lion(parameters, **opt_args)
# second order
elif opt_lower == 'adahessian':
optimizer = Adahessian(parameters, **opt_args)
# NVIDIA fused optimizers, require APEX to be installed
elif opt_lower == 'fusedsgd':
opt_args.pop('eps', None)
optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'fusedmomentum':
opt_args.pop('eps', None)
optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
elif opt_lower == 'fusedadam':
optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
elif opt_lower == 'fusedadamw':
optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
elif opt_lower == 'fusedlamb':
optimizer = FusedLAMB(parameters, **opt_args)
elif opt_lower == 'fusednovograd':
opt_args.setdefault('betas', (0.95, 0.98))
optimizer = FusedNovoGrad(parameters, **opt_args)
# bitsandbytes optimizers, require bitsandbytes to be installed
elif opt_lower == 'bnbsgd':
opt_args.pop('eps', None)
optimizer = bnb.optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'bnbsgd8bit':
opt_args.pop('eps', None)
optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'bnbmomentum':
opt_args.pop('eps', None)
optimizer = bnb.optim.SGD(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'bnbmomentum8bit':
opt_args.pop('eps', None)
optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'bnbadam':
optimizer = bnb.optim.Adam(parameters, **opt_args)
elif opt_lower == 'bnbadam8bit':
optimizer = bnb.optim.Adam8bit(parameters, **opt_args)
elif opt_lower == 'bnbadamw':
optimizer = bnb.optim.AdamW(parameters, **opt_args)
elif opt_lower == 'bnbadamw8bit':
optimizer = bnb.optim.AdamW8bit(parameters, **opt_args)
elif opt_lower == 'bnblamb':
optimizer = bnb.optim.LAMB(parameters, **opt_args)
elif opt_lower == 'bnblamb8bit':
optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
elif opt_lower == 'bnblars':
optimizer = bnb.optim.LARS(parameters, **opt_args)
elif opt_lower == 'bnblarsb8bit':
optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
elif opt_lower == 'bnblion':
optimizer = bnb.optim.Lion(parameters, **opt_args)
elif opt_lower == 'bnblion8bit':
optimizer = bnb.optim.Lion8bit(parameters, **opt_args)
else:
assert False and "Invalid optimizer"
raise ValueError
if len(opt_split) > 1:
if opt_split[0] == 'lookahead':
optimizer = Lookahead(optimizer)
return optimizer