Spaces:
Build error
Build error
""" | |
@Date: 2021/09/14 | |
@description: | |
""" | |
class WarmupScheduler: | |
def __init__(self, optimizer, lr_pow, init_lr, warmup_lr, warmup_step, max_step, **kwargs): | |
self.lr_pow = lr_pow | |
self.init_lr = init_lr | |
self.running_lr = init_lr | |
self.warmup_lr = warmup_lr | |
self.warmup_step = warmup_step | |
self.max_step = max_step | |
self.optimizer = optimizer | |
def step_update(self, cur_step): | |
if cur_step < self.warmup_step: | |
frac = cur_step / self.warmup_step | |
step = self.warmup_lr - self.init_lr | |
self.running_lr = self.init_lr + step * frac | |
else: | |
frac = (float(cur_step) - self.warmup_step) / (self.max_step - self.warmup_step) | |
scale_running_lr = max((1. - frac), 0.) ** self.lr_pow | |
self.running_lr = self.warmup_lr * scale_running_lr | |
if self.optimizer is not None: | |
for param_group in self.optimizer.param_groups: | |
param_group['lr'] = self.running_lr | |
if __name__ == '__main__': | |
import matplotlib.pyplot as plt | |
scheduler = WarmupScheduler(optimizer=None, | |
lr_pow=4, | |
init_lr=0.0000003, | |
warmup_lr=0.00003, | |
warmup_step=10000, | |
max_step=100000) | |
x = [] | |
y = [] | |
for i in range(100000): | |
if i == 10000-1: | |
print() | |
scheduler.step_update(i) | |
x.append(i) | |
y.append(scheduler.running_lr) | |
plt.plot(x, y, linewidth=1) | |
plt.show() | |