# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from copy import deepcopy

from mmcv.transforms import Compose
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmpretrain.models.utils import RandomBatchAugment
from mmpretrain.registry import HOOKS, MODEL_WRAPPERS, MODELS


@HOOKS.register_module()
class SwitchRecipeHook(Hook):
    """Switch the training recipe during the training loop, currently
    covering the train pipeline, batch augmentations and loss.

    Args:
        schedule (list): Every item of the schedule list should be a dict,
            and the dict should have ``action_epoch`` and some of the
            ``train_pipeline``, ``batch_augments`` and ``loss`` keys:

            - ``action_epoch`` (int): The epoch at which to switch to the
              new training recipe.
            - ``train_pipeline`` (list, optional): The new data pipeline of
              the train dataset. If not specified, keep the original
              settings.
            - ``batch_augments`` (dict | None, optional): The new batch
              augmentations used during training. See :mod:`Batch
              Augmentations <mmpretrain.models.utils.batch_augments>` for
              more details. If None, disable batch augmentations. If not
              specified, keep the original settings.
            - ``loss`` (dict, optional): The new loss module config. If not
              specified, keep the original settings.

    Example:
        To use this hook in config files:

        .. code:: python

            custom_hooks = [
                dict(
                    type='SwitchRecipeHook',
                    schedule=[
                        dict(
                            action_epoch=30,
                            train_pipeline=pipeline_after_30e,
                            batch_augments=batch_augments_after_30e,
                            loss=loss_after_30e,
                        ),
                        dict(
                            action_epoch=60,
                            # Disable batch augmentations after 60e
                            # and keep other settings.
                            batch_augments=None,
                        ),
                    ]
                )
            ]
    """
    priority = 'NORMAL'

    def __init__(self, schedule):
        recipes = {}
        for recipe in schedule:
            assert 'action_epoch' in recipe, \
                'Please set `action_epoch` in every item ' \
                'of the `schedule` in the SwitchRecipeHook.'

            # Build the new pipeline, batch augments and loss up front,
            # so the actual switch at the epoch boundary is cheap.
            recipe = deepcopy(recipe)
            if 'train_pipeline' in recipe:
                recipe['train_pipeline'] = Compose(recipe['train_pipeline'])
            if 'batch_augments' in recipe:
                batch_augments = recipe['batch_augments']
                if isinstance(batch_augments, dict):
                    batch_augments = RandomBatchAugment(**batch_augments)
                recipe['batch_augments'] = batch_augments
            if 'loss' in recipe:
                loss = recipe['loss']
                if isinstance(loss, dict):
                    loss = MODELS.build(loss)
                recipe['loss'] = loss

            action_epoch = recipe.pop('action_epoch')
            assert action_epoch not in recipes, \
                f'The `action_epoch` {action_epoch} is repeated ' \
                'in the SwitchRecipeHook.'
            recipes[action_epoch] = recipe
        # Sort the recipes by epoch so they can be replayed in order.
        self.schedule = OrderedDict(sorted(recipes.items()))
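        # For illustration (not part of the upstream code): with the
        # docstring's example schedule, the attribute becomes roughly
        #     OrderedDict({
        #         30: {'train_pipeline': <Compose>,
        #              'batch_augments': <RandomBatchAugment>,
        #              'loss': <built loss module>},
        #         60: {'batch_augments': None},
        #     })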

    def before_train(self, runner) -> None:
        """Set up before the training run. If resuming from a checkpoint,
        replay all switches scheduled before the current epoch.

        Args:
            runner (Runner): The runner of the training, validation or
                testing process.
        """
        if runner._resume:
            for action_epoch, recipe in self.schedule.items():
                if action_epoch >= runner.epoch + 1:
                    break
                self._do_switch(runner, recipe,
                                f' (resume recipe of epoch {action_epoch})')

    def before_train_epoch(self, runner):
        """Switch the recipe if one is scheduled for the coming epoch."""
        recipe = self.schedule.get(runner.epoch + 1, None)
        if recipe is not None:
            self._do_switch(runner, recipe, f' at epoch {runner.epoch + 1}')

    def _do_switch(self, runner, recipe, extra_info=''):
        """Do the switch according to the specified recipe."""
        if 'batch_augments' in recipe:
            self._switch_batch_augments(runner, recipe['batch_augments'])
            runner.logger.info(f'Switch batch augments{extra_info}.')

        if 'train_pipeline' in recipe:
            self._switch_train_pipeline(runner, recipe['train_pipeline'])
            runner.logger.info(f'Switch train pipeline{extra_info}.')

        if 'loss' in recipe:
            self._switch_loss(runner, recipe['loss'])
            runner.logger.info(f'Switch loss{extra_info}.')

    @staticmethod
    def _switch_batch_augments(runner, batch_augments):
        """Switch the batch augmentations of the data preprocessor."""
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        model.data_preprocessor.batch_augments = batch_augments

    @staticmethod
    def _switch_train_pipeline(runner, train_pipeline):
        """Switch the pipeline of the dataset of the train dataloader."""

        def switch_pipeline(dataset, pipeline):
            if hasattr(dataset, 'pipeline'):
                # for usual datasets
                dataset.pipeline = pipeline
            elif hasattr(dataset, 'datasets'):
                # for concat dataset wrappers
                for ds in dataset.datasets:
                    switch_pipeline(ds, pipeline)
            elif hasattr(dataset, 'dataset'):
                # for other dataset wrappers, like repeat dataset
                switch_pipeline(dataset.dataset, pipeline)
            else:
                raise RuntimeError(
                    'Cannot access the `pipeline` of the dataset.')

        train_loader = runner.train_loop.dataloader
        switch_pipeline(train_loader.dataset, train_pipeline)

        # Restart the iterator of the dataloader when
        # `persistent_workers=True`, so that the worker processes pick up
        # the new pipeline.
        train_loader._iterator = None

    @staticmethod
    def _switch_loss(runner, loss_module):
        """Switch the loss module of the model (or of its head)."""
        model = runner.model
        if is_model_wrapper(model, MODEL_WRAPPERS):
            model = model.module

        if hasattr(model, 'loss_module'):
            model.loss_module = loss_module
        elif hasattr(model, 'head') and hasattr(model.head, 'loss_module'):
            model.head.loss_module = loss_module
        else:
            raise RuntimeError('Cannot access the `loss_module` of the model.')
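

# A minimal usage sketch (illustrative, not part of the upstream module).
# It assumes mmpretrain is installed; the transform, batch-augment and loss
# configs below are hypothetical placeholders, and `register_all_modules`
# is the usual mmpretrain helper to populate the registries before building
# components from configs.
if __name__ == '__main__':
    from mmpretrain.utils import register_all_modules
    register_all_modules()

    hook = SwitchRecipeHook(schedule=[
        dict(
            action_epoch=30,
            # Hypothetical new pipeline and loss, for illustration only.
            train_pipeline=[dict(type='RandomFlip', prob=0.5)],
            batch_augments=dict(augments=[dict(type='Mixup', alpha=0.8)]),
            loss=dict(type='LabelSmoothLoss', label_smooth_val=0.1),
        ),
        # From epoch 60 on, disable batch augmentations only.
        dict(action_epoch=60, batch_augments=None),
    ])

    # The schedule is normalized into an OrderedDict keyed by action epoch.
    assert list(hook.schedule) == [30, 60]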