File size: 4,309 Bytes
1ea89dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import weakref

import numpy as np


class PlainCosineScheduler(object):
    """Warmup + cosine schedule for a single attribute of an arbitrary object.

    The precomputed schedule ramps linearly from ``init_value`` to
    ``base_value`` over ``warmup_iters`` entries, then follows a half cosine
    from ``base_value`` to ``final_value``, reaching it at index
    ``total_iters`` (when ``total_iters >= warmup_iters``). Each call to
    :meth:`step` advances the iteration counter and writes the scheduled
    value onto ``klass`` via ``setattr(klass, key, value)``.

    Args:
        klass: Object whose attribute ``key`` is updated by :meth:`step`.
        key: Name of the attribute to set on ``klass``.
        warmup_iters: Number of linear warmup entries.
        total_iters: Last schedule index; later iterations are clamped to it.
        overwrite: Stored for signature parity with ``CosineScheduler``;
            not used by this class.
        init_value: Warmup start value; defaults to ``base_value``.
        base_value: Value at the end of warmup / start of the cosine.
        final_value: Value at the end of the cosine segment.
        step_init: Initial counter; the first :meth:`step` uses
            ``step_init + 1``.
    """

    def __init__(
        self,
        klass,
        key,
        warmup_iters,
        total_iters,
        overwrite=False,
        init_value=None,
        base_value=None,
        final_value=None,
        step_init=-1,
    ):
        super().__init__()
        self.iter = step_init
        self.overwrite = overwrite
        self.base_value = base_value
        self.init_value = init_value if init_value is not None else base_value
        self.final_value = final_value
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.key = key
        self.klass = klass
        # Kept as a list so __getitem__ mirrors CosineScheduler's per-group API.
        self.schedulers = [self.get_scheduler()]

    def get_scheduler(self):
        """Return the full schedule as a 1-D numpy array.

        The array is the concatenation of ``warmup_iters`` linear warmup
        values and ``max(total_iters - warmup_iters + 1, 0)`` cosine values.
        """
        init_value = self.init_value
        base_value = self.base_value
        final_value = self.final_value
        warmup_iters = self.warmup_iters
        total_iters = self.total_iters

        # normalize in 0,1, then apply function (power) and denormalize
        normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True)
        normalized_schedule = np.power(normalized_schedule, 1)
        warmup_schedule = (base_value - init_value) * normalized_schedule + init_value

        # main scheduling
        iters = np.arange(total_iters - warmup_iters + 1)
        # A single-entry cosine segment would otherwise divide 0 by 0 (nan);
        # clamping the denominator makes that entry cos(0) -> base_value.
        denom = max(len(iters) - 1, 1)
        schedule = final_value + 0.5 * (base_value - final_value) * (
            1 + np.cos(np.pi * iters / denom)
        )
        return np.concatenate((warmup_schedule, schedule))

    def step(self):
        """Advance one iteration and write the new value onto ``klass.key``."""
        self.iter = self.iter + 1
        for val in self[self.iter]:
            setattr(self.klass, self.key, val)

    def __getitem__(self, it):
        """Return ``[value]`` for iteration ``it``, clamped to ``total_iters``."""
        it = min(it, self.total_iters)
        return [scheduler[it] for scheduler in self.schedulers]


class CosineScheduler(object):
    """Warmup + flat + cosine-decay schedule applied to optimizer param groups.

    For every entry of ``optimizer.param_groups`` a schedule is precomputed:
    ``warmup_iters`` linear values from the init value to the base value,
    ``flat_iters`` values held at the base value, then a half cosine from the
    base value to the final value. Each call to :meth:`step` advances the
    counter and writes the scheduled value into ``group[key]``.

    Per-group overrides are read from ``group[key + "_init"]``,
    ``group[key + "_base"]`` and ``group[key + "_final"]``, falling back to
    the corresponding constructor arguments.

    Args:
        optimizer: Object exposing ``param_groups`` (a sequence of dicts).
        warmup_iters: Number of linear warmup entries.
        total_iters: Last schedule index; later iterations are clamped to it.
        key: Param-group entry to schedule (e.g. ``"lr"``). If the current
            value is a tuple/list, only its first element is replaced.
        overwrite: If True, the constructor's ``final_value`` replaces any
            per-group ``<key>_final`` override.
        init_value: Warmup start value; when it resolves to None it defaults
            to the group's base value.
        base_value: Value at the end of warmup and during the flat phase.
        final_value: Value at the end of the cosine decay.
        flat_iters: Number of entries held at the base value after warmup.
        step_init: Initial counter; the first :meth:`step` uses
            ``step_init + 1``.
    """

    def __init__(
        self,
        optimizer,
        warmup_iters,
        total_iters,
        key,
        overwrite=False,
        init_value=None,
        base_value=None,
        final_value=None,
        flat_iters=0,
        step_init=-1,
    ):
        super().__init__()
        self.iter = step_init
        self.overwrite = overwrite
        self.optimizer = optimizer
        self.base_value = base_value
        self.init_value = init_value
        self.final_value = final_value
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.flat_iters = flat_iters
        self.key = key
        # One precomputed schedule per param group, in param_groups order.
        self.schedulers = [
            self.get_schedulers(group) for group in optimizer.param_groups
        ]

    def get_schedulers(self, group):
        """Return the full schedule (1-D numpy array) for one param group."""
        init_value = group.get(self.key + "_init", self.init_value)
        base_value = group.get(self.key + "_base", self.base_value)
        final_value = group.get(self.key + "_final", self.final_value)
        if init_value is None:
            # Without this fallback `base_value - None` raises TypeError,
            # even when warmup_iters == 0 (matches PlainCosineScheduler).
            init_value = base_value
        warmup_iters = self.warmup_iters
        total_iters = self.total_iters
        flat_iters = self.flat_iters
        if self.overwrite:
            final_value = self.final_value

        # normalize in 0,1, then apply function (power) and denormalize
        normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True)
        normalized_schedule = np.power(normalized_schedule, 1)
        warmup_schedule = (base_value - init_value) * normalized_schedule + init_value

        # flat scheduling
        flat_schedule = np.ones(flat_iters) * base_value

        # decay scheduling
        decay_iters = np.arange(total_iters - warmup_iters - flat_iters + 1)
        # A single-entry decay segment would otherwise divide 0 by 0 (nan);
        # clamping the denominator makes that entry cos(0) -> base_value.
        denom = max(len(decay_iters) - 1, 1)
        decay_schedule = final_value + 0.5 * (base_value - final_value) * (
            1 + np.cos(np.pi * decay_iters / denom)
        )
        return np.concatenate((warmup_schedule, flat_schedule, decay_schedule))

    def step(self):
        """Advance one iteration and write the new values into each group."""
        self.iter = self.iter + 1
        vals = self[self.iter]
        for group, val in zip(self.optimizer.param_groups, vals):
            # Tuple-valued entries keep their trailing elements; only the
            # first element is scheduled.
            if isinstance(group[self.key], (tuple, list)):
                val = (val, *group[self.key][1:])
            group[self.key] = val

    def __getitem__(self, it):
        """Return per-group values for iteration ``it``, clamped to ``total_iters``."""
        it = min(it, self.total_iters)
        return [scheduler[it] for scheduler in self.schedulers]

    def get(self):
        """Return the current ``group[key]`` of every param group."""
        return [group[self.key] for group in self.optimizer.param_groups]