File size: 6,732 Bytes
cc0dd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from copy import deepcopy

from mmcv.transforms import Compose
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmpretrain.models.utils import RandomBatchAugment
from mmpretrain.registry import HOOKS, MODEL_WRAPPERS, MODELS


@HOOKS.register_module()
class SwitchRecipeHook(Hook):
    """Switch recipe during the training loop, including train pipeline, batch
    augments and loss currently.

    Args:
        schedule (list): Every item of the schedule list should be a dict, and
            the dict should have ``action_epoch`` and some of
            ``train_pipeline``, ``train_augments`` and ``loss`` keys:

            - ``action_epoch`` (int): switch training recipe at which epoch.
            - ``train_pipeline`` (list, optional): The new data pipeline of the
              train dataset. If not specified, keep the original settings.
            - ``batch_augments`` (dict | None, optional): The new batch
              augmentations of during training. See :mod:`Batch Augmentations
              <mmpretrain.models.utils.batch_augments>` for more details.
              If None, disable batch augmentations. If not specified, keep the
              original settings.
            - ``loss`` (dict, optional): The new loss module config. If not
              specified, keep the original settings.

    Example:
        To use this hook in config files.

        .. code:: python

            custom_hooks = [
                dict(
                    type='SwitchRecipeHook',
                    schedule=[
                        dict(
                            action_epoch=30,
                            train_pipeline=pipeline_after_30e,
                            batch_augments=batch_augments_after_30e,
                            loss=loss_after_30e,
                        ),
                        dict(
                            action_epoch=60,
                            # Disable batch augmentations after 60e
                            # and keep other settings.
                            batch_augments=None,
                        ),
                    ]
                )
            ]
    """
    priority = 'NORMAL'

    def __init__(self, schedule):
        recipes = {}
        for recipe in schedule:
            assert 'action_epoch' in recipe, \
                'Please set `action_epoch` in every item ' \
                'of the `schedule` in the SwitchRecipeHook.'
            # Deep-copy so that building objects below does not mutate the
            # user's config dicts in place.
            recipe = deepcopy(recipe)
            if 'train_pipeline' in recipe:
                # Pre-build the pipeline once; switching later is a cheap
                # attribute assignment.
                recipe['train_pipeline'] = Compose(recipe['train_pipeline'])
            if 'batch_augments' in recipe:
                batch_augments = recipe['batch_augments']
                # A dict config is built eagerly; ``None`` is kept as-is to
                # mean "disable batch augmentations".
                if isinstance(batch_augments, dict):
                    batch_augments = RandomBatchAugment(**batch_augments)
                recipe['batch_augments'] = batch_augments
            if 'loss' in recipe:
                loss = recipe['loss']
                # Accept either an already-built module or a config dict.
                if isinstance(loss, dict):
                    loss = MODELS.build(loss)
                recipe['loss'] = loss

            action_epoch = recipe.pop('action_epoch')
            assert action_epoch not in recipes, \
                f'The `action_epoch` {action_epoch} is repeated ' \
                'in the SwitchRecipeHook.'
            recipes[action_epoch] = recipe
        # Keep the schedule sorted by epoch so that ``before_train`` can
        # replay all earlier switches in order when resuming.
        self.schedule = OrderedDict(sorted(recipes.items()))

    def before_train(self, runner) -> None:
        """before run setting. If resume from a checkpoint, do all switch
        before the current epoch.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
        """
        # NOTE(review): relies on the private ``runner._resume`` flag; there
        # is no public accessor for it in mmengine at the time of writing.
        if runner._resume:
            for action_epoch, recipe in self.schedule.items():
                if action_epoch >= runner.epoch + 1:
                    # Schedule is sorted, so no later recipe applies either.
                    break
                self._do_switch(runner, recipe,
                                f' (resume recipe of epoch {action_epoch})')

    def before_train_epoch(self, runner):
        """do before train epoch."""
        # ``runner.epoch`` is the number of finished epochs, so the upcoming
        # epoch index is ``runner.epoch + 1``.
        recipe = self.schedule.get(runner.epoch + 1, None)
        if recipe is not None:
            self._do_switch(runner, recipe, f' at epoch {runner.epoch + 1}')

    def _do_switch(self, runner, recipe, extra_info=''):
        """do the switch aug process."""
        # Use membership tests (not ``.get``) so that an explicit ``None``
        # value (e.g. disabling batch augments) still triggers a switch.
        if 'batch_augments' in recipe:
            self._switch_batch_augments(runner, recipe['batch_augments'])
            runner.logger.info(f'Switch batch augments{extra_info}.')

        if 'train_pipeline' in recipe:
            self._switch_train_pipeline(runner, recipe['train_pipeline'])
            runner.logger.info(f'Switch train pipeline{extra_info}.')

        if 'loss' in recipe:
            self._switch_loss(runner, recipe['loss'])
            runner.logger.info(f'Switch loss{extra_info}.')

    @staticmethod
    def _switch_batch_augments(runner, batch_augments):
        """switch the train augments."""
        model = runner.model
        # Pass MODEL_WRAPPERS explicitly (consistent with ``_switch_loss``)
        # so project-registered custom wrappers are unwrapped as well.
        if is_model_wrapper(model, MODEL_WRAPPERS):
            model = model.module

        model.data_preprocessor.batch_augments = batch_augments

    @staticmethod
    def _switch_train_pipeline(runner, train_pipeline):
        """switch the train loader dataset pipeline."""

        def switch_pipeline(dataset, pipeline):
            if hasattr(dataset, 'pipeline'):
                # for usual dataset
                dataset.pipeline = pipeline
            elif hasattr(dataset, 'datasets'):
                # for concat dataset wrapper
                for ds in dataset.datasets:
                    switch_pipeline(ds, pipeline)
            elif hasattr(dataset, 'dataset'):
                # for other dataset wrappers
                switch_pipeline(dataset.dataset, pipeline)
            else:
                raise RuntimeError(
                    'Cannot access the `pipeline` of the dataset.')

        train_loader = runner.train_loop.dataloader
        switch_pipeline(train_loader.dataset, train_pipeline)

        # To restart the iterator of dataloader when `persistent_workers=True`
        # (otherwise the alive workers keep using the old pipeline).
        train_loader._iterator = None

    @staticmethod
    def _switch_loss(runner, loss_module):
        """switch the loss module."""
        model = runner.model
        if is_model_wrapper(model, MODEL_WRAPPERS):
            model = model.module

        if hasattr(model, 'loss_module'):
            model.loss_module = loss_module
        elif hasattr(model, 'head') and hasattr(model.head, 'loss_module'):
            model.head.loss_module = loss_module
        else:
            raise RuntimeError('Cannot access the `loss_module` of the model.')