File size: 1,227 Bytes
8044721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#!/usr/bin/env python3
# coding=utf-8

import math


class LinearLr:
    """Learning-rate schedule: optional zero-LR delay, linear warmup, then cosine decay.

    Despite the class name, the decay phase is a half-cosine from
    ``learning_rate`` down to ``learning_rate / multiplier``. Each call to the
    instance advances one step and writes the new value into
    ``param_group["lr"]``.

    Phases, in terms of the 1-based call count ``steps`` and ``T = total_steps``:

    * ``steps < delay_steps``: lr is 0. ``delay_steps`` is ``T / 20`` when
      ``delay`` is true, otherwise 0.
    * ``steps < T / 10``: lr rises linearly from 0 at ``delay_steps`` toward
      ``learning_rate`` at ``T / 10``.
    * Otherwise: cosine decay over the remaining ``0.9 * T`` steps. Once
      ``T`` steps have elapsed, lr is held at the floor.
    """

    def __init__(self, param_group, learning_rate: float, total_steps: int, delay: bool, multiplier: int):
        """
        Args:
            param_group: Mapping whose ``"lr"`` entry is overwritten on every
                call (presumably an optimizer param group; confirm against
                the caller).
            learning_rate: Peak learning rate, reached at the end of warmup.
            total_steps: Planned number of calls to the schedule.
            delay: If true, hold lr at 0 for the first ``total_steps / 20`` steps.
            multiplier: The decay floor is ``learning_rate / multiplier``;
                it must be non-zero.
        """
        self.total_steps = total_steps
        self.delay_steps = total_steps / 20 if delay else 0
        self.max_lr = learning_rate
        self.steps = 0
        self.param_group = param_group
        self.decay_multiplier = multiplier

    def __call__(self, _):
        """Advance the schedule by one step and store the new lr in ``param_group``.

        The single positional argument is accepted but ignored.
        """
        self.steps += 1

        warmup_end = self.total_steps / 10
        if self.steps < self.delay_steps:
            lr = 0.0
        elif self.steps < warmup_end:
            lr = self.max_lr * (self.steps - self.delay_steps) / (warmup_end - self.delay_steps)
        else:
            min_lr = self.max_lr / self.decay_multiplier
            amplitude = self.max_lr - min_lr
            decay_span = self.total_steps * 9 / 10
            if decay_span > 0:
                # Clamp to 1 so the cosine does not swing back up once the
                # schedule runs past total_steps.
                progress = min((self.steps - warmup_end) / decay_span, 1.0)
            else:
                # Degenerate schedule (total_steps <= 0): no decay span, so
                # go straight to the floor instead of dividing by zero.
                progress = 1.0
            lr = amplitude * (math.cos(math.pi * progress) + 1) / 2 + min_lr

        # Safety first! The floor can be negative, e.g. with a negative multiplier.
        if lr < 0.0:
            lr = 0.0

        self.param_group["lr"] = lr

    def lr(self) -> float:
        """Return the learning rate most recently written to ``param_group``."""
        return self.param_group["lr"]