File size: 1,327 Bytes
991f07c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
#!/usr/bin/env python3
# coding=utf-8
from utility.schedule.linear_lr import LinearLr
def multi_scheduler_wrapper(optimizer, args, steps_per_epoch):
    """Build a MultiScheduler with one LinearLr per optimizer parameter group.

    The parameter groups are read positionally:

    * the first ``2 * n_layers`` groups, where
      ``n_layers = (len(optimizer.param_groups) - 2) // 2``, are treated as two
      consecutive runs of ``n_layers`` groups. In each run the group at
      position ``i`` gets ``args.encoder_learning_rate *
      args.layerwise_lr_decay ** i``;
    * the last two groups both get ``args.decoder_learning_rate``.

    Every LinearLr additionally receives ``args.epochs * steps_per_epoch``,
    the literal ``False`` and ``args.lr_decay_multiplier``; the meaning of
    those arguments is defined by LinearLr (presumably total step count, a
    flag, and a decay factor -- confirm against LinearLr).

    Args:
        optimizer: object exposing an indexable ``param_groups`` sequence.
        args: object providing ``encoder_learning_rate``,
            ``decoder_learning_rate``, ``layerwise_lr_decay``, ``epochs`` and
            ``lr_decay_multiplier``.
        steps_per_epoch: multiplied by ``args.epochs`` before being passed on.

    Returns:
        A MultiScheduler wrapping the schedulers in group order.
    """
    groups = optimizer.param_groups
    n_layers = (len(groups) - 2) // 2
    total_steps = args.epochs * steps_per_epoch

    def build(group, base_lr):
        # All schedulers share everything except the group and its base rate.
        return LinearLr(group, base_lr, total_steps, False, args.lr_decay_multiplier)

    schedulers = []

    # Two back-to-back runs of n_layers groups share the same per-depth rates.
    for offset in (0, n_layers):
        for depth in range(n_layers):
            layer_lr = args.encoder_learning_rate * (args.layerwise_lr_decay ** depth)
            schedulers.append(build(groups[offset + depth], layer_lr))

    # The trailing two groups use the decoder rate without layer-wise decay.
    schedulers.append(build(groups[-2], args.decoder_learning_rate))
    schedulers.append(build(groups[-1], args.decoder_learning_rate))

    return MultiScheduler(schedulers)
class MultiScheduler:
    """Fan a single scheduler call out to a collection of schedulers.

    Each wrapped scheduler must be callable with one argument and expose an
    ``lr()`` method; nothing else is required of it.
    """

    def __init__(self, schedulers):
        """Store the schedulers to drive.

        Args:
            schedulers: iterable of scheduler objects; it is kept by
                reference and iterated on every call.
        """
        self.schedulers = schedulers

    def __call__(self, epoch):
        """Invoke every wrapped scheduler, in order, with ``epoch``.

        Args:
            epoch: value forwarded unchanged to each scheduler (whether it is
                an epoch or a step index depends on the wrapped schedulers).
        """
        for scheduler in self.schedulers:
            scheduler(epoch)

    def lr(self) -> list:
        """Return the ``lr()`` result of every wrapped scheduler.

        Returns:
            A list with one entry per scheduler, in the same order as
            ``self.schedulers`` (not a single float).
        """
        return [scheduler.lr() for scheduler in self.schedulers]
|