# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional

from mmengine.hooks import ParamSchedulerHook
from mmengine.runner import Runner

from mmyolo.registry import HOOKS


@HOOKS.register_module()
class PPYOLOEParamSchedulerHook(ParamSchedulerHook):
    """A hook to update learning rate and momentum in optimizer of PPYOLOE. We
    use this hook to implement adaptive computation for `warmup_total_iters`,
    which is not possible with the built-in ParamScheduler in mmyolo.

    Args:
        warmup_min_iter (int): Minimum warmup iters. Defaults to 1000.
        start_factor (float): The factor the learning rate is multiplied by
            at the first warmup iteration. The factor increases linearly
            towards 1.0 over the warmup iterations. Defaults to 0.
        warmup_epochs (int): Epochs for warmup. Defaults to 5.
        min_lr_ratio (float): Ratio of the minimum learning rate to the
            initial learning rate. Defaults to 0.0.
        total_epochs (int): In PPYOLOE, `total_epochs` is set to
            training_epochs x 1.2. Defaults to 360.
    """
    priority = 9

    def __init__(self,
                 warmup_min_iter: int = 1000,
                 start_factor: float = 0.,
                 warmup_epochs: int = 5,
                 min_lr_ratio: float = 0.0,
                 total_epochs: int = 360):

        self.warmup_min_iter = warmup_min_iter
        self.start_factor = start_factor
        self.warmup_epochs = warmup_epochs
        self.min_lr_ratio = min_lr_ratio
        self.total_epochs = total_epochs

        self._warmup_end = False
        self._base_lr = None

    def before_train(self, runner: Runner):
        """Operations before train.

        Args:
            runner (Runner): The runner of the training process.
        """
        optimizer = runner.optim_wrapper.optimizer
        for group in optimizer.param_groups:
            # If the param group has never been scheduled, record the current
            # value as the initial value.
            group.setdefault('initial_lr', group['lr'])

        self._base_lr = [
            group['initial_lr'] for group in optimizer.param_groups
        ]
        self._min_lr = [i * self.min_lr_ratio for i in self._base_lr]

    def before_train_iter(self,
                          runner: Runner,
                          batch_idx: int,
                          data_batch: Optional[dict] = None):
        """Operations before each training iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict, optional): Data from the dataloader.
                Defaults to None.
        """
        cur_iters = runner.iter
        optimizer = runner.optim_wrapper.optimizer
        dataloader_len = len(runner.train_dataloader)

        # The warmup lasts for at least `self.warmup_min_iter` iterations
        warmup_total_iters = max(
            round(self.warmup_epochs * dataloader_len), self.warmup_min_iter)

        if cur_iters <= warmup_total_iters:
            # linear warmup: the factor ramps from `start_factor` to 1.0
            alpha = cur_iters / warmup_total_iters
            factor = self.start_factor * (1 - alpha) + alpha

            for group_idx, param in enumerate(optimizer.param_groups):
                param['lr'] = self._base_lr[group_idx] * factor
        else:
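            # Cosine annealing: decay the learning rate from the initial lr
            # towards `min_lr` over the remaining iterations after warmup.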
            for group_idx, param in enumerate(optimizer.param_groups):
                total_iters = self.total_epochs * dataloader_len
                lr = self._min_lr[group_idx] + (
                    self._base_lr[group_idx] -
                    self._min_lr[group_idx]) * 0.5 * (
                        math.cos((cur_iters - warmup_total_iters) * math.pi /
                                 (total_iters - warmup_total_iters)) + 1.0)
                param['lr'] = lr
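

# A minimal illustrative sketch (not part of the hook; all names below are
# hypothetical): a standalone helper that mirrors the schedule implemented in
# `before_train_iter`, handy for sanity-checking or plotting the LR curve
# offline, e.g. _ppyoloe_lr_at_iter(500, base_lr=0.01, dataloader_len=200).
def _ppyoloe_lr_at_iter(cur_iter: int,
                        base_lr: float,
                        dataloader_len: int,
                        warmup_min_iter: int = 1000,
                        start_factor: float = 0.,
                        warmup_epochs: int = 5,
                        min_lr_ratio: float = 0.0,
                        total_epochs: int = 360) -> float:
    """Return the learning rate at ``cur_iter`` under the schedule above."""
    # Warmup runs for at least `warmup_min_iter` iterations.
    warmup_total_iters = max(
        round(warmup_epochs * dataloader_len), warmup_min_iter)
    min_lr = base_lr * min_lr_ratio
    if cur_iter <= warmup_total_iters:
        # Linear warmup: the factor ramps from `start_factor` to 1.0.
        alpha = cur_iter / warmup_total_iters
        return base_lr * (start_factor * (1 - alpha) + alpha)
    # Cosine annealing from `base_lr` down to `min_lr`.
    total_iters = total_epochs * dataloader_len
    return min_lr + (base_lr - min_lr) * 0.5 * (
        math.cos((cur_iter - warmup_total_iters) * math.pi /
                 (total_iters - warmup_total_iters)) + 1.0)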