Spaces:
Runtime error
Runtime error
File size: 7,409 Bytes
cc0dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import Callable, List, Optional
from mmengine.logging import MMLogger
from mmengine.optim import DefaultOptimWrapperConstructor
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from torch import nn
from torch.nn import GroupNorm, LayerNorm
from mmpretrain.registry import OPTIM_WRAPPER_CONSTRUCTORS
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Different learning rates are set for different layers of backbone.

    By default, each parameter shares the same optimizer settings, and we
    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
    It is a dict and may contain the following fields:

    - ``layer_decay_rate`` (float): The base learning rate of a parameter is
      multiplied by ``layer_decay_rate ** (max_id - layer_id - 1)``, where
      ``(layer_id, max_id)`` is returned by ``get_layer_depth`` for that
      parameter. Usually, it's less than 1, so that the earlier layers will
      have a lower learning rate. Defaults to 1.
    - ``bias_decay_mult`` (float): It will be multiplied to the weight
      decay for all bias parameters (except for those in normalization
      layers).
    - ``norm_decay_mult`` (float): It will be multiplied to the weight
      decay for all weight and bias parameters of normalization layers.
    - ``flat_decay_mult`` (float): It will be multiplied to the weight
      decay for all one-dimensional parameters.
    - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
      one of the keys in ``custom_keys`` is a substring of the name of one
      parameter, then the setting of the parameter will be specified by
      ``custom_keys[key]`` and other settings like ``bias_decay_mult`` will
      be ignored. It should be a dict and may contain fields ``decay_mult``.
      (The ``lr_mult`` is disabled in this constructor).

    Example:
        In the config file, you can use this constructor as below:

        .. code:: python

            optim_wrapper = dict(
                optimizer=dict(
                    type='AdamW',
                    lr=4e-3,
                    weight_decay=0.05,
                    eps=1e-8,
                    betas=(0.9, 0.999)),
                constructor='LearningRateDecayOptimWrapperConstructor',
                paramwise_cfg=dict(
                    layer_decay_rate=0.75,  # layer-wise lr decay factor
                    norm_decay_mult=0.,
                    flat_decay_mult=0.,
                    custom_keys={
                        '.cls_token': dict(decay_mult=0.0),
                        '.pos_embed': dict(decay_mult=0.0)
                    }))
    """

    def add_params(self,
                   params: List[dict],
                   module: nn.Module,
                   prefix: str = '',
                   get_layer_depth: Optional[Callable] = None,
                   **kwargs) -> None:
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg. The method
        recurses into every child module, and the outermost call (the one
        with an empty ``prefix``) finally emits a per-layer debug summary.

        Args:
            params (List[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module. Defaults to ``''``.
            get_layer_depth (Callable, optional): A callable which takes the
                full parameter name and returns ``(layer_id, max_id)``. If
                None, ``module.get_layer_depth`` is used. Defaults to None.
            **kwargs: Accepted for signature compatibility and ignored.

        Raises:
            NotImplementedError: If ``get_layer_depth`` is not given and
                ``module`` has no ``get_layer_depth`` attribute.
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        # first sort with alphabet order and then sort with reversed len of
        # str, so that the longest matching key wins in the lookup below.
        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)

        # The model should have `get_layer_depth` method
        if get_layer_depth is None and not hasattr(module, 'get_layer_depth'):
            raise NotImplementedError('The layer-wise learning rate decay need'
                                      f' the model {type(module)} has'
                                      ' `get_layer_depth` method.')
        get_layer_depth = get_layer_depth or module.get_layer_depth

        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
        flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
        decay_rate = self.paramwise_cfg.get('layer_decay_rate', 1.0)

        # special rules for norm layers
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))

        for name, param in module.named_parameters(recurse=False):
            # Frozen parameters never reach the optimizer.
            if not param.requires_grad:
                continue

            param_group = {'params': [param]}
            param_name = prefix + name

            # Weight decay is only customized when a base weight decay is
            # configured; otherwise the 'weight_decay' key is left unset.
            if self.base_wd is not None:
                base_wd = self.base_wd
                custom_key = next(
                    filter(lambda k: k in param_name, sorted_keys), None)
                # custom parameters decay
                if custom_key is not None:
                    custom_cfg = custom_keys[custom_key].copy()
                    decay_mult = custom_cfg.pop('decay_mult', 1.)

                    param_group['weight_decay'] = base_wd * decay_mult
                    # add custom settings to param_group
                    param_group.update(custom_cfg)
                # norm decay
                elif is_norm and norm_decay_mult is not None:
                    param_group['weight_decay'] = base_wd * norm_decay_mult
                # bias decay
                elif name == 'bias' and bias_decay_mult is not None:
                    param_group['weight_decay'] = base_wd * bias_decay_mult
                # flatten parameters decay
                elif param.ndim == 1 and flat_decay_mult is not None:
                    param_group['weight_decay'] = base_wd * flat_decay_mult
                else:
                    param_group['weight_decay'] = base_wd

            # Layer-wise lr decay: scale is decay_rate raised to the distance
            # of this layer from the last one. Assigned after the custom
            # settings so any 'lr' given in ``custom_keys`` is overridden.
            layer_id, max_id = get_layer_depth(param_name)
            scale = decay_rate**(max_id - layer_id - 1)
            param_group['lr'] = self.base_lr * scale
            param_group['lr_scale'] = scale
            param_group['layer_id'] = layer_id
            param_group['param_name'] = param_name

            params.append(param_group)

        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}{child_name}.'
            self.add_params(
                params,
                child_mod,
                prefix=child_prefix,
                get_layer_depth=get_layer_depth,
            )

        # Only the outermost call sees the complete list of param groups, so
        # the summary is logged there.
        if prefix == '':
            logger = MMLogger.get_current_instance()
            groups_by_layer = defaultdict(list)
            for group in params:
                groups_by_layer[group['layer_id']].append(group)

            for layer_id, groups in groups_by_layer.items():
                lr_scale = groups[0]['lr_scale']
                lr = groups[0]['lr']
                msg = [
                    f'layer {layer_id} params '
                    f'(lr={lr:.3g}, lr_scale={lr_scale:.3g}):'
                ]
                for group in groups:
                    # 'weight_decay' is absent when no base weight decay is
                    # configured; avoid a KeyError in that case.
                    weight_decay = group.get('weight_decay')
                    if weight_decay is None:
                        wd_text = 'None'
                    else:
                        wd_text = f'{weight_decay:.3g}'
                    msg.append(f'\t{group["param_name"]}: '
                               f'weight_decay={wd_text}')
                logger.debug('\n'.join(msg))
|