# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import Callable, List, Optional

from mmengine.logging import MMLogger
from mmengine.optim import DefaultOptimWrapperConstructor
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from torch import nn
from torch.nn import GroupNorm, LayerNorm

from mmpretrain.registry import OPTIM_WRAPPER_CONSTRUCTORS


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Different learning rates are set for different layers of backbone.

    By default, each parameter share the same optimizer settings, and we
    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
    It is a dict and may contain the following fields:

    - ``layer_decay_rate`` (float): The learning rate of a parameter will
      multiply it by multiple times according to the layer depth of the
      parameter. Usually, it's less than 1, so that the earlier layers will
      have a lower learning rate. Defaults to 1.
    - ``bias_decay_mult`` (float): It will be multiplied to the weight
      decay for all bias parameters (except for those in normalization layers).
    - ``norm_decay_mult`` (float): It will be multiplied to the weight
      decay for all weight and bias parameters of normalization layers.
    - ``flat_decay_mult`` (float): It will be multiplied to the weight
      decay for all one-dimensional parameters
    - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
      one of the keys in ``custom_keys`` is a substring of the name of one
      parameter, then the setting of the parameter will be specified by
      ``custom_keys[key]`` and other setting like ``bias_decay_mult`` will be
      ignored. It should be a dict and may contain fields ``decay_mult``.
      (The ``lr_mult`` is disabled in this constructor).

    Example:

    In the config file, you can use this constructor as below:

    .. code:: python

        optim_wrapper = dict(
            optimizer=dict(
                type='AdamW',
                lr=4e-3,
                weight_decay=0.05,
                eps=1e-8,
                betas=(0.9, 0.999)),
            constructor='LearningRateDecayOptimWrapperConstructor',
            paramwise_cfg=dict(
                layer_decay_rate=0.75,  # layer-wise lr decay factor
                norm_decay_mult=0.,
                flat_decay_mult=0.,
                custom_keys={
                    '.cls_token': dict(decay_mult=0.0),
                    '.pos_embed': dict(decay_mult=0.0)
                }))
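
    Note:
        The final learning rate of every parameter is computed as
        ``base_lr * layer_decay_rate ** (num_layers - layer_id - 1)``, where
        ``layer_id`` and ``num_layers`` come from the ``get_layer_depth``
        method of the model.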
    """

    def add_params(self,
                   params: List[dict],
                   module: nn.Module,
                   prefix: str = '',
                   get_layer_depth: Optional[Callable] = None,
                   **kwargs) -> None:
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (List[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module. Defaults to ''.
            get_layer_depth (Callable, optional): The function to get the
                layer id and the maximum layer depth of a parameter. If not
                specified, the ``get_layer_depth`` method of ``module`` is
                used. Defaults to None.
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        # Sort keys alphabetically, then by descending length, so that the
        # longest matching key takes precedence.
        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
        logger = MMLogger.get_current_instance()

        # The model should have a `get_layer_depth` method
        if get_layer_depth is None and not hasattr(module, 'get_layer_depth'):
            raise NotImplementedError(
                'Layer-wise learning rate decay requires the model '
                f'{type(module)} to have a `get_layer_depth` method.')
        else:
            get_layer_depth = get_layer_depth or module.get_layer_depth
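        # `get_layer_depth(param_name)` is expected to return a tuple
        # ``(layer_id, num_layers)`` for the given parameter name.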

        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
        flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
        decay_rate = self.paramwise_cfg.get('layer_decay_rate', 1.0)

        # special rules for parameters of normalization layers
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))

        for name, param in module.named_parameters(recurse=False):
            param_group = {'params': [param]}
            param_name = prefix + name
            if not param.requires_grad:
                continue

            if self.base_wd is not None:
                base_wd = self.base_wd
                custom_key = next(
                    filter(lambda k: k in param_name, sorted_keys), None)
                # custom parameters decay
                if custom_key is not None:
                    custom_cfg = custom_keys[custom_key].copy()
                    decay_mult = custom_cfg.pop('decay_mult', 1.)

                    param_group['weight_decay'] = base_wd * decay_mult
                    # add custom settings to param_group
                    param_group.update(custom_cfg)
                # norm decay
                elif is_norm and norm_decay_mult is not None:
                    param_group['weight_decay'] = base_wd * norm_decay_mult
                # bias decay
                elif name == 'bias' and bias_decay_mult is not None:
                    param_group['weight_decay'] = base_wd * bias_decay_mult
                # flatten parameters decay
                elif param.ndim == 1 and flat_decay_mult is not None:
                    param_group['weight_decay'] = base_wd * flat_decay_mult
                else:
                    param_group['weight_decay'] = base_wd

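            # Scale the learning rate according to the layer depth: the
            # closer a layer is to the output (the larger its ``layer_id``),
            # the larger the scale ``decay_rate ** (max_id - layer_id - 1)``.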
            layer_id, max_id = get_layer_depth(param_name)
            scale = decay_rate**(max_id - layer_id - 1)
            param_group['lr'] = self.base_lr * scale
            param_group['lr_scale'] = scale
            param_group['layer_id'] = layer_id
            param_group['param_name'] = param_name

            params.append(param_group)

        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}{child_name}.'
            self.add_params(
                params,
                child_mod,
                prefix=child_prefix,
                get_layer_depth=get_layer_depth,
            )

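        # Once the recursion has unwound back to the top-level call, log a
        # per-layer summary of the generated param groups at debug level.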
        if prefix == '':
            layer_params = defaultdict(list)
            for param_group in params:
                layer_params[param_group['layer_id']].append(param_group)
            for layer_id, groups in layer_params.items():
                lr_scale = groups[0]['lr_scale']
                lr = groups[0]['lr']
                msg = [
                    f'layer {layer_id} params '
                    f'(lr={lr:.3g}, lr_scale={lr_scale:.3g}):'
                ]
                for param_group in groups:
                    # `weight_decay` is absent when the optimizer config
                    # defines no base weight decay.
                    wd = param_group.get('weight_decay', 0.)
                    msg.append(f'\t{param_group["param_name"]}: '
                               f'weight_decay={wd:.3g}')
                logger.debug('\n'.join(msg))
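

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API).
# It assumes a toy backbone that implements the `get_layer_depth` hook; real
# mmpretrain backbones ship their own implementations with model-specific
# layer mappings, so the names below (`ToyBackbone`, its stages, and the
# depth mapping) are made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    class ToyBackbone(nn.Module):
        """A tiny 3-layer model exposing the ``get_layer_depth`` hook."""

        def __init__(self):
            super().__init__()
            self.stem = nn.Linear(4, 8)
            self.stage1 = nn.Linear(8, 8)
            self.head = nn.Linear(8, 2)

        def get_layer_depth(self, param_name: str):
            # Return ``(layer_id, num_layers)``. With num_layers = 3, the
            # head keeps a scale of decay_rate ** (3 - 2 - 1) = 1, while the
            # stem is scaled by decay_rate ** 2.
            depths = {'stem': 0, 'stage1': 1, 'head': 2}
            return depths[param_name.split('.')[0]], 3

    constructor = LearningRateDecayOptimWrapperConstructor(
        optim_wrapper_cfg=dict(
            type='OptimWrapper',
            optimizer=dict(type='AdamW', lr=4e-3, weight_decay=0.05)),
        paramwise_cfg=dict(layer_decay_rate=0.75, flat_decay_mult=0.))
    optim_wrapper = constructor(ToyBackbone())

    # Each param group carries its own lr (layer-wise decayed) and
    # weight decay (zeroed for one-dimensional parameters here).
    for group in optim_wrapper.optimizer.param_groups:
        print(group['param_name'], group['lr'], group['weight_decay'])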