File size: 4,217 Bytes
78e32cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
from torch.optim.optimizer import Optimizer
import pytorch_lightning as pl
from torch.optim.lr_scheduler import _LRScheduler


class BaseScheduler(object):
    """Base class for the step-wise scheduler logic.

    Args:
        optimizer (Optimizer): Optimizer instance to apply lr schedule on.

    Subclass this and overwrite ``_get_lr`` to write your own step-wise scheduler.
    """

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.step_num = 0  # number of ``step`` calls performed so far

    def zero_grad(self):
        """Forward to the wrapped optimizer's ``zero_grad``."""
        self.optimizer.zero_grad()

    def _get_lr(self):
        """Return the learning rate for the current ``step_num``. Override in subclasses."""
        raise NotImplementedError

    def _set_lr(self, lr):
        # The same lr is applied to every param group.
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def step(self, metrics=None, epoch=None):
        """Update step-wise learning rate before optimizer.step.

        ``metrics`` and ``epoch`` are accepted (and ignored) only for call-site
        compatibility with torch/Lightning scheduler ``step`` signatures.
        """
        self.step_num += 1
        lr = self._get_lr()
        self._set_lr(lr)

    def load_state_dict(self, state_dict):
        """Restore scheduler state previously produced by ``state_dict``."""
        self.__dict__.update(state_dict)

    def state_dict(self):
        """Return scheduler state, excluding the (non-serializable) optimizer."""
        return {key: value for key, value in self.__dict__.items() if key != "optimizer"}

    def as_tensor(self, start=0, stop=100_000):
        """Return the scheduler lr values for steps ``start + 1 .. stop`` as a 1-D tensor.

        The current ``step_num`` is saved and restored, so this can safely be
        called mid-training. (Previously ``step_num`` was reset to 0 afterwards,
        clobbering training progress, and ``start`` only affected the number of
        values returned rather than the starting step.)
        """
        saved_step_num = self.step_num
        self.step_num = start
        lr_list = []
        for _ in range(start, stop):
            self.step_num += 1
            lr_list.append(self._get_lr())
        self.step_num = saved_step_num
        return torch.tensor(lr_list)

    def plot(self, start=0, stop=100_000):  # noqa
        """Plot the scheduler values from start to stop."""
        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt

        all_lr = self.as_tensor(start=start, stop=stop)
        plt.plot(all_lr.numpy())
        plt.show()

class DPTNetScheduler(BaseScheduler):
    """Dual Path Transformer Scheduler used in [1]

    Args:
        optimizer (Optimizer): Optimizer instance to apply lr schedule on.
        steps_per_epoch (int): Number of steps per epoch.
        d_model (int): The number of units in the layer output.
        warmup_steps (int): The number of steps in the warmup stage of training.
        noam_scale (float): Linear increase rate in first phase.
        exp_max (float): Max learning rate in second phase.
        exp_base (float): Exp learning rate base in second phase.

    Schedule:
        This scheduler increases the learning rate linearly for the first
        ``warmup_steps``, and then decay it by 0.98 for every two epochs.

    References
        [1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct Context-
        Aware Modeling for End-to-End Monaural Speech Separation" Interspeech 2020.
    """

    def __init__(
        self,
        optimizer,
        steps_per_epoch,
        d_model,
        warmup_steps=4000,
        noam_scale=1.0,
        exp_max=0.0004,
        exp_base=0.98,
    ):
        super().__init__(optimizer)
        self.steps_per_epoch = steps_per_epoch
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.noam_scale = noam_scale
        self.exp_max = exp_max
        self.exp_base = exp_base
        self.epoch = 0  # advanced by ``_get_lr`` every ``steps_per_epoch`` steps

    def _get_lr(self):
        # Side effect: bump the epoch counter each time a full epoch's worth
        # of steps has elapsed (``step_num`` is a multiple of steps_per_epoch).
        if self.step_num % self.steps_per_epoch == 0:
            self.epoch += 1

        if self.step_num <= self.warmup_steps:
            # Noam warmup: linear growth, capped by the inverse-sqrt curve.
            linear_term = self.step_num * self.warmup_steps ** (-1.5)
            inv_sqrt_term = self.step_num ** (-0.5)
            return self.noam_scale * self.d_model ** (-0.5) * min(inv_sqrt_term, linear_term)

        # Post-warmup: one ``exp_base`` decay factor for every two epochs.
        return self.exp_max * self.exp_base ** ((self.epoch - 1) // 2)

class CustomExponentialLR(_LRScheduler):
    """Multiply the learning rate by ``gamma`` once every ``step_size`` epochs.

    Decay happens on epochs where ``(last_epoch + 1)`` is a multiple of
    ``step_size`` (i.e. epochs ``step_size - 1``, ``2 * step_size - 1``, ...).

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative decay factor.
        step_size (int): Number of epochs between two consecutive decays.
        last_epoch (int): Index of the last epoch (-1 means start fresh).
    """

    def __init__(self, optimizer, gamma, step_size, last_epoch=-1):
        self.gamma = gamma
        self.step_size = step_size
        # ``_LRScheduler.__init__`` records ``base_lrs`` itself, so the manual
        # (and redundant) ``base_lrs`` assignment was dropped.
        super(CustomExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch == 0 or (self.last_epoch + 1) % self.step_size != 0:
            return [group['lr'] for group in self.optimizer.param_groups]
        # Bug fix: compound the decay from the *current* lr. The original
        # returned ``base_lr * gamma`` at every boundary, so the lr froze at
        # ``base_lr * gamma`` forever instead of decaying exponentially.
        return [group['lr'] * self.gamma for group in self.optimizer.param_groups]


# Backward compat: old private name for ``BaseScheduler``; kept so external
# code importing ``_BaseScheduler`` keeps working.
_BaseScheduler = BaseScheduler