File size: 4,023 Bytes
a256709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
""" Plateau Scheduler

Adapts PyTorch plateau scheduler and allows application of noise, warmup.

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch

from .scheduler import Scheduler


class PlateauLRScheduler(Scheduler):
    """Decay the LR by a factor every time the monitored metric plateaus.

    Wraps ``torch.optim.lr_scheduler.ReduceLROnPlateau`` and layers two extras
    on top of it: a linear warmup over the first ``warmup_t`` epochs and an
    optional, per-epoch random perturbation ("noise") of the learning rate.

    Args:
        optimizer: Optimizer whose ``lr`` param-group entries are scheduled.
        decay_rate: Forwarded to ``ReduceLROnPlateau`` as ``factor``.
        patience_t: Forwarded to ``ReduceLROnPlateau`` as ``patience``.
        verbose: Forwarded to ``ReduceLROnPlateau`` as ``verbose``.
        threshold: Forwarded to ``ReduceLROnPlateau`` as ``threshold``.
        cooldown_t: Forwarded to ``ReduceLROnPlateau`` as ``cooldown``.
        warmup_t: Number of warmup epochs; ``0`` disables warmup.
        warmup_lr_init: LR used at epoch 0 of the warmup ramp.
        lr_min: Forwarded to ``ReduceLROnPlateau`` as ``min_lr``.
        mode: Forwarded to ``ReduceLROnPlateau`` as ``mode``.
        noise_range_t: ``None`` for no noise, a scalar ``t`` to apply noise
            for every ``epoch >= t``, or a ``(start, end)`` pair to apply it
            for ``start <= epoch < end``.
        noise_type: ``"normal"`` draws from a standard normal and rejects
            samples whose magnitude is not below ``noise_pct``; any other
            value draws uniformly from ``[-noise_pct, noise_pct)``.
        noise_pct: Bound on the relative LR perturbation.
        noise_std: Stored on the instance; not used by this class.
        noise_seed: Base seed for the noise generator (42 when ``None``).
        initialize: Forwarded to the base ``Scheduler`` constructor.
    """

    def __init__(
        self,
        optimizer,
        decay_rate=0.1,
        patience_t=10,
        verbose=True,
        threshold=1e-4,
        cooldown_t=0,
        warmup_t=0,
        warmup_lr_init=0,
        lr_min=0,
        mode="max",
        noise_range_t=None,
        noise_type="normal",
        noise_pct=0.67,
        noise_std=1.0,
        noise_seed=None,
        initialize=True,
    ):
        super().__init__(optimizer, "lr", initialize=initialize)

        # NOTE(review): `verbose` is documented as deprecated in recent PyTorch
        # releases - confirm it is still accepted by the targeted version.
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            patience=patience_t,
            factor=decay_rate,
            verbose=verbose,
            threshold=threshold,
            cooldown=cooldown_t,
            mode=mode,
            min_lr=lr_min,
        )

        self.noise_range = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = noise_seed if noise_seed is not None else 42
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        if self.warmup_t:
            # Per-epoch LR increment for each base value, so that the ramp
            # reaches the base value exactly at epoch == warmup_t.
            # (base_values is provided by the base Scheduler - presumably the
            # optimizer's initial LRs; verify against the base class.)
            self.warmup_steps = [
                (v - warmup_lr_init) / self.warmup_t for v in self.base_values
            ]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]
        # Pre-noise LRs cached by _apply_noise; None when no noise is pending.
        self.restore_lr = None

    def state_dict(self):
        """Return the wrapped plateau scheduler's ``best`` and ``last_epoch``.

        Only these two fields are captured; no other state of the wrapped
        scheduler or of this wrapper is included.
        """
        return {
            "best": self.lr_scheduler.best,
            "last_epoch": self.lr_scheduler.last_epoch,
        }

    def load_state_dict(self, state_dict):
        """Restore ``best`` (required) and ``last_epoch`` (if present)."""
        self.lr_scheduler.best = state_dict["best"]
        if "last_epoch" in state_dict:
            self.lr_scheduler.last_epoch = state_dict["last_epoch"]

    # override the base class step fn completely
    def step(self, epoch, metric=None):
        """Advance the schedule for ``epoch``.

        During warmup (``warmup_t`` non-zero and ``epoch <= warmup_t``) the LR
        follows the linear ramp and ``metric`` is ignored. Afterwards any
        pending noise is undone, the plateau scheduler is stepped with
        ``metric`` (skipped when ``metric`` is ``None``), and noise is
        re-applied if ``epoch`` falls inside the configured noise range.

        Args:
            epoch: Current epoch index.
            metric: Value monitored by the plateau scheduler, or ``None`` to
                leave the plateau scheduler untouched for this call.
        """
        # Only take the warmup branch when warmup is enabled; otherwise
        # epoch 0 would overwrite the LR with warmup_lr_init and swallow the
        # metric.
        if self.warmup_t and epoch <= self.warmup_t:
            lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
            super().update_groups(lrs)
            return

        if self.restore_lr is not None:
            # restore actual LR from before our last noise perturbation before stepping base
            for param_group, lr in zip(self.optimizer.param_groups, self.restore_lr):
                param_group["lr"] = lr
            self.restore_lr = None

        # The plateau scheduler needs a real metric value; without one there
        # is nothing to compare against, so leave it untouched.
        if metric is not None:
            self.lr_scheduler.step(metric, epoch)  # step the base scheduler

        if self.noise_range is not None:
            if isinstance(self.noise_range, (list, tuple)):
                apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
            else:
                apply_noise = epoch >= self.noise_range
            if apply_noise:
                self._apply_noise(epoch)

    def _apply_noise(self, epoch):
        """Perturb every param group's LR by a seeded relative noise factor.

        The generator is seeded with ``noise_seed + epoch``. The pre-noise
        LRs are cached in ``self.restore_lr`` so that ``step`` can undo the
        perturbation before stepping the plateau scheduler again.
        """
        g = torch.Generator()
        g.manual_seed(self.noise_seed + epoch)
        if self.noise_type == "normal":
            if self.noise_pct <= 0:
                # No sample can satisfy abs(noise) < noise_pct; the rejection
                # loop below would never terminate, so apply no perturbation.
                noise = 0.0
            else:
                while True:
                    # resample if noise out of percent limit, brute force but shouldn't spin much
                    noise = torch.randn(1, generator=g).item()
                    if abs(noise) < self.noise_pct:
                        break
        else:
            noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct

        # apply the noise on top of previous LR, cache the old value so we can restore for normal
        # stepping of base scheduler
        restore_lr = []
        for param_group in self.optimizer.param_groups:
            old_lr = float(param_group["lr"])
            restore_lr.append(old_lr)
            param_group["lr"] = old_lr + old_lr * noise
        self.restore_lr = restore_lr