File size: 7,828 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from typing import Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn

from .optimizer_wrapper import OptimWrapper


class OptimWrapperDict(OptimWrapper):
    """A dictionary container of :obj:`OptimWrapper`.

    If runner is training with multiple optimizers, all optimizer wrappers
    should be managed by :obj:`OptimWrapperDict` which is built by
    ``CustomOptimWrapperConstructor``. ``OptimWrapperDict`` will load and save
    the state dictionary of all optimizer wrappers.

    Considering the semantic ambiguity of calling :meth:`update_params`,
    :meth:`backward` and :meth:`step` on all optimizer wrappers at once,
    ``OptimWrapperDict`` does not implement these methods; they raise
    ``NotImplementedError``.

    Examples:
        >>> import torch.nn as nn
        >>> from torch.optim import SGD
        >>> from mmengine.optim import OptimWrapperDict, OptimWrapper
        >>> model1 = nn.Linear(1, 1)
        >>> model2 = nn.Linear(1, 1)
        >>> optim_wrapper1 = OptimWrapper(SGD(model1.parameters(), lr=0.1))
        >>> optim_wrapper2 = OptimWrapper(SGD(model2.parameters(), lr=0.1))
        >>> optim_wrapper_dict = OptimWrapperDict(model1=optim_wrapper1,
        ...                                       model2=optim_wrapper2)

    Note:
        The optimizer wrapper contained in ``OptimWrapperDict`` can be accessed
        in the same way as `dict` (``[]``, ``in``, ``len``, iteration over
        names, ``keys``/``values``/``items``).

    Args:
        **optim_wrapper_dict: Name-to-``OptimWrapper`` mapping given as
            keyword arguments. Every value must be an ``OptimWrapper``
            instance.
    """

    def __init__(self, **optim_wrapper_dict: OptimWrapper):
        for key, value in optim_wrapper_dict.items():
            assert isinstance(value, OptimWrapper), (
                '`OptimWrapperDict` only accept OptimWrapper instance, '
                f'but got {key}: {type(value)}')
        # NOTE(review): the parent ``__init__`` is not invoked here;
        # presumably because this container owns no single optimizer of its
        # own -- confirm against ``OptimWrapper``.
        self.optim_wrappers = optim_wrapper_dict

    def update_params(  # type: ignore
            self,
            loss: torch.Tensor,
            step_kwargs: Optional[Dict] = None,
            zero_kwargs: Optional[Dict] = None) -> None:
        """Update all optimizer wrappers would lead to a duplicate backward
        errors, and OptimWrapperDict does not know which optimizer wrapper
        should be updated.

        Therefore, this method is not implemented. The optimizer wrapper of
        OptimWrapperDict should be accessed and call its `update_params`.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError('`update_params` should be called by each '
                                  'optimizer separately')

    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Since OptimWrapperDict doesn't know which optimizer wrapper's
        backward method should be called (``loss_scaler`` maybe different in
        different :obj:`AmpOptimWrapper`), this method is not implemented.

        The optimizer wrapper of OptimWrapperDict should be accessed and call
        its `backward`.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError('`backward` should be called by each '
                                  'optimizer separately')

    def step(self, **kwargs) -> None:
        """Since the backward method is not implemented, the step should not be
        implemented either.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError('`step` should be called by each '
                                  'optimizer separately')

    def zero_grad(self, **kwargs) -> None:
        """Set the gradients of all optimizer wrappers to zero.

        Args:
            **kwargs: Keyword arguments forwarded unchanged to the
                ``zero_grad`` method of every managed optimizer wrapper.
        """
        for optim_wrapper in self.optim_wrappers.values():
            # Forward the caller's options instead of silently dropping them.
            optim_wrapper.zero_grad(**kwargs)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """``optim_context`` should be called by each optimizer separately.

        Raises:
            NotImplementedError: Always. Because the body contains no
                ``yield``, the error is raised as soon as the method is
                called, not when the ``with`` block is entered.
        """
        raise NotImplementedError(
            '`optim_context` should be called by each optimizer separately')

    def initialize_count_status(self, model: nn.Module, cur_iter,
                                max_iters) -> None:
        """Do nothing but provide unified interface for :obj:`OptimWrapper`

        Since ``OptimWrapperDict`` does not know the correspondence between
        model and optimizer wrapper, ``initialize_count_status`` does nothing
        here and each optimizer wrapper should call its own
        ``initialize_count_status`` separately.

        Args:
            model (nn.Module): Unused.
            cur_iter: Unused.
            max_iters: Unused.
        """
        return

    @property
    def param_groups(self):
        """Returns the parameter groups of each OptimWrapper.

        Returns:
            dict: Maps each wrapper name to that wrapper's ``param_groups``.
        """
        return {
            name: optim_wrapper.param_groups
            for name, optim_wrapper in self.optim_wrappers.items()
        }

    def get_lr(self) -> Dict[str, List[float]]:
        """Get the learning rate of all optimizers.

        Returns:
            Dict[str, List[float]]: Learning rate of all optimizers. Keys are
            ``'{name}.lr'`` and, when the wrapper reports one,
            ``'{name}.base_lr'``.
        """
        lr_dict = dict()
        for name, optim_wrapper in self.optim_wrappers.items():
            inner_lr_dict = optim_wrapper.get_lr()
            # ``base_lr`` is optional in the per-wrapper result.
            if 'base_lr' in inner_lr_dict:
                lr_dict[f'{name}.base_lr'] = inner_lr_dict['base_lr']
            lr_dict[f'{name}.lr'] = inner_lr_dict['lr']
        return lr_dict

    def get_momentum(self) -> Dict[str, List[float]]:
        """Get the momentum of all optimizers.

        Returns:
            Dict[str, List[float]]: momentum of all optimizers, keyed by
            ``'{name}.momentum'``.
        """
        return {
            f'{name}.momentum': optim_wrapper.get_momentum()['momentum']
            for name, optim_wrapper in self.optim_wrappers.items()
        }

    def state_dict(self) -> dict:
        """Get the state dictionary of all optimizer wrappers.

        Returns:
            dict: Each key-value pair in the dictionary represents the name
            and state dictionary of corresponding :obj:`OptimWrapper`.
        """
        return {
            name: optim_wrapper.state_dict()
            for name, optim_wrapper in self.optim_wrappers.items()
        }

    def load_state_dict(self, state_dict: dict) -> None:
        """Load the state dictionary from the ``state_dict``.

        Args:
            state_dict (dict): Each key-value pair in `state_dict` represents
                the name and the state dictionary of corresponding
                :obj:`OptimWrapper`.

        Raises:
            AssertionError: If ``state_dict`` contains a name that is not
                managed by this ``OptimWrapperDict``.
        """
        for name, _state_dict in state_dict.items():
            assert name in self.optim_wrappers, (
                f'Mismatched `state_dict`! cannot find {name} in '
                'OptimWrapperDict')
            self.optim_wrappers[name].load_state_dict(_state_dict)

    def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
        """A generator to get the name and corresponding
        :obj:`OptimWrapper`"""
        yield from self.optim_wrappers.items()

    def values(self) -> Iterator[OptimWrapper]:
        """A generator to get :obj:`OptimWrapper`"""
        yield from self.optim_wrappers.values()

    def keys(self) -> Iterator[str]:
        """A generator to get the name of :obj:`OptimWrapper`"""
        yield from self.optim_wrappers.keys()

    def __iter__(self) -> Iterator[str]:
        """Iterate over the names of the optimizer wrappers, like ``dict``.

        Without an explicit ``__iter__``, ``for name in optim_wrapper_dict``
        would fall back to the legacy sequence protocol and call
        ``__getitem__(0)``, which fails the key assertion.
        """
        return iter(self.optim_wrappers)

    def __getitem__(self, key: str) -> OptimWrapper:
        """Return the optimizer wrapper registered under ``key``.

        Raises:
            AssertionError: If ``key`` is not a managed wrapper name.
        """
        assert key in self.optim_wrappers, (
            f'Cannot find {key} in OptimWrapperDict, please check '
            'your optimizer constructor.')
        return self.optim_wrappers[key]

    def __contains__(self, key: str) -> bool:
        """Return whether a wrapper named ``key`` is managed."""
        return key in self.optim_wrappers

    def __len__(self) -> int:
        """Return the number of managed optimizer wrappers."""
        return len(self.optim_wrappers)

    def __repr__(self) -> str:
        """Return ``name: <name>`` followed by each wrapper's ``repr``."""
        return ''.join(f'name: {name}\n{optim_wrapper!r}'
                       for name, optim_wrapper in self.optim_wrappers.items())