from torch.optim.lr_scheduler import ReduceLROnPlateau


class WarmupThenReduceLROnPlateau(ReduceLROnPlateau):
    """Linearly warms up the learning rate, then defers to ReduceLROnPlateau."""

    def __init__(self, optimizer, warmup_steps, *args, **kwargs):
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_steps (int): Number of steps over which the learning rate
                ramps linearly from zero to the base learning rate.
            *args: Positional arguments forwarded to ReduceLROnPlateau.
            **kwargs: Keyword arguments forwarded to ReduceLROnPlateau.
        """
        super().__init__(optimizer, *args, **kwargs)
        self.warmup_steps = warmup_steps
        self.steps_taken = 0
        # Remember each param group's base learning rate before warmup rescales it.
        self.base_lrs = [group["lr"] for group in optimizer.param_groups]
        if self.warmup_steps > 0:
            # Begin the warmup immediately, i.e. at a learning rate of zero.
            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                param_group["lr"] = lr

    def get_lr(self):
        # Linear warmup: scale each base learning rate by the fraction of
        # warmup completed so far.
        assert self.steps_taken <= self.warmup_steps
        return [
            base_lr * (self.steps_taken / self.warmup_steps)
            for base_lr in self.base_lrs
        ]

    def step(self, metrics=None):
        self.steps_taken += 1
        if self.steps_taken <= self.warmup_steps:
            # Still warming up: apply the linearly scaled learning rates.
            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                param_group["lr"] = lr
        elif metrics is not None:
            # Warmup finished: fall back to ReduceLROnPlateau, which only
            # acts when given a metric to monitor.
            super().step(metrics)
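

# A minimal usage sketch (assumptions: a per-step training loop; `model`,
# `train_loader`, `num_epochs`, and `validate` are hypothetical placeholders,
# not part of this module):
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     scheduler = WarmupThenReduceLROnPlateau(optimizer, warmup_steps=500, patience=5)
#     for epoch in range(num_epochs):
#         for batch in train_loader:
#             ...  # forward / backward / optimizer.step()
#             scheduler.step()  # advances the warmup; a no-op on the LR afterwards
#         scheduler.step(validate(model))  # plateau logic only runs with a metric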