Spaces:
Runtime error
Runtime error
File size: 1,227 Bytes
8044721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
#!/usr/bin/env python3
# coding=utf-8
import math
class LinearLr:
    """Step-driven learning-rate schedule that writes into ``param_group["lr"]``.

    Each call advances an internal step counter and sets the rate in three
    phases:

    1. Optional delay: while ``steps < total_steps / 20`` (only if ``delay``
       is true) the rate is ``0.0``.
    2. Linear warmup: until ``steps`` reaches ``total_steps / 10`` the rate
       rises linearly from ``0`` to ``learning_rate``.
    3. Cosine decay: over the remaining ``9 / 10`` of ``total_steps`` the rate
       follows a half cosine from ``learning_rate`` down to
       ``learning_rate / multiplier`` and then stays at that floor.

    NOTE: despite the class name, the decay phase is cosine-shaped, not linear.
    """

    def __init__(self, param_group, learning_rate: float, total_steps: int, delay: bool, multiplier: int):
        """Store the schedule configuration.

        Args:
            param_group: Mutable mapping whose ``"lr"`` entry is overwritten on
                every call (presumably an optimizer param group; confirm with
                the caller).
            learning_rate: Peak rate reached at the end of warmup.
            total_steps: Number of steps the whole schedule is spread over.
            delay: If true, hold the rate at zero for the first
                ``total_steps / 20`` steps.
            multiplier: The final rate is ``learning_rate / multiplier``; must
                be non-zero.
        """
        self.total_steps = total_steps
        self.delay_steps = total_steps / 20 if delay else 0
        self.max_lr = learning_rate
        self.steps = 0
        self.param_group = param_group
        self.decay_multiplier = multiplier

    def __call__(self, _):
        """Advance one step and store the new rate in ``param_group["lr"]``.

        The positional argument is ignored; the schedule is driven purely by
        the internal call counter. Nothing is returned.
        """
        self.steps += 1
        warmup_end = self.total_steps / 10

        if self.steps < self.delay_steps:
            lr = 0.0
        elif self.steps < warmup_end:
            lr = self.max_lr * (self.steps - self.delay_steps) / (warmup_end - self.delay_steps)
        else:
            min_lr = self.max_lr / self.decay_multiplier
            amplitude = self.max_lr - min_lr
            decay_span = self.total_steps * 9 / 10
            # Clamp progress to 1.0: without the bound, stepping past
            # total_steps pushes the cosine argument beyond pi and the rate
            # starts rising again. A zero-length span is treated as finished.
            if decay_span > 0:
                progress = min((self.steps - warmup_end) / decay_span, 1.0)
            else:
                progress = 1.0
            lr = amplitude * (math.cos(math.pi * progress) + 1) / 2 + min_lr

        # Never hand a negative rate to the optimizer.
        lr = max(lr, 0.0)
        self.param_group["lr"] = lr

    def lr(self) -> float:
        """Return the rate currently stored in ``param_group["lr"]``."""
        return self.param_group["lr"]
|