# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""lr_schedule.py"""
import torch
from typing import Dict


def get_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_name: str, base_lr: float, scheduler_cfg: Dict):
    """Build an LR scheduler ('cosine', 'legacy', or 'constant') for the given optimizer."""
    if scheduler_name.lower() == 'cosine':
        from torch.optim.lr_scheduler import (
            SequentialLR,
            LinearLR,
            CosineAnnealingLR,
        )

        # Linear warmup from 0.5 * base_lr to base_lr over the warmup steps,
        # then cosine annealing down to `final_cosine` for the remaining steps.
        scheduler1 = LinearLR(
            optimizer,
            start_factor=0.5,
            end_factor=1.0,
            total_iters=scheduler_cfg["warmup_steps"],
            last_epoch=-1,
        )
        scheduler2 = CosineAnnealingLR(
            optimizer,
            T_max=scheduler_cfg["total_steps"] - scheduler_cfg["warmup_steps"],
            eta_min=scheduler_cfg["final_cosine"],
        )
        lr_scheduler = SequentialLR(
            optimizer,
            schedulers=[scheduler1, scheduler2],
            milestones=[scheduler_cfg["warmup_steps"]],
        )
    elif scheduler_name.lower() == 'legacy':
        import math
        from torch.optim.lr_scheduler import (
            SequentialLR,
            LinearLR,
            LambdaLR,
        )

        msg = "You are using the T5 legacy LR schedule; it is independent of optim.base_lr."
        print(msg)

        # Inverse square-root decay for the first 90% of training, followed by a
        # linear decay to zero over the remaining 10% of steps.
        num_steps_optimizer1 = math.ceil(scheduler_cfg["total_steps"] * 0.9)
        iters_left_for_optimizer2 = scheduler_cfg["total_steps"] - num_steps_optimizer1

        scheduler1 = LambdaLR(
            optimizer,
            lambda step: min(base_lr, 1.0 / math.sqrt(step)) / base_lr if step else base_lr / base_lr,
        )
        scheduler2 = LinearLR(
            optimizer,
            start_factor=(min(base_lr, 1.0 / math.sqrt(num_steps_optimizer1)) / base_lr),
            end_factor=0,
            total_iters=iters_left_for_optimizer2,
            last_epoch=-1,
        )
        lr_scheduler = SequentialLR(
            optimizer,
            schedulers=[scheduler1, scheduler2],
            milestones=[num_steps_optimizer1],
        )
    elif scheduler_name.lower() == 'constant':
        from transformers import get_scheduler

        # Constant LR via the Hugging Face `transformers` scheduler factory.
        lr_scheduler = get_scheduler(
            name=scheduler_name.lower(),
            optimizer=optimizer,
        )
    else:
        raise NotImplementedError(f"Unknown scheduler name: {scheduler_name}")
    return lr_scheduler
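
# Usage sketch for get_lr_scheduler above (an illustrative assumption, not code from
# this repo): with the 'cosine' schedule, `scheduler_cfg` must provide "warmup_steps",
# "total_steps", and "final_cosine", and the scheduler is stepped once per optimizer
# step:
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)
#   lr_scheduler = get_lr_scheduler(optimizer, 'cosine', base_lr=base_lr,
#                                   scheduler_cfg={"warmup_steps": 1000,
#                                                  "total_steps": 100000,
#                                                  "final_cosine": 1e-5})
#   ...
#   optimizer.step()
#   lr_scheduler.step()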


def extra_stats(args, model, optimizer):
    """Collect optional training statistics: global weight L2 norm and current LR."""
    stats = {}
    if args.logging.weights_l2:
        # Global L2 norm over all parameters: sqrt of the sum of squared per-parameter norms.
        weights_l2 = sum(p.detach().norm(2).item()**2 for p in model.parameters())**0.5
        stats['weights_l2'] = weights_l2
    cur_lr = optimizer.param_groups[0]['lr']
    stats['lr'] = cur_lr
    return stats
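

# Minimal smoke test (a hedged sketch for illustration; the real training loop and
# config objects of YourMT3 are not shown here). For extra_stats(), `args` only needs
# a `logging.weights_l2` flag, so a SimpleNamespace stands in for the actual config.
if __name__ == "__main__":
    from types import SimpleNamespace

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = get_lr_scheduler(optimizer, 'cosine', base_lr=1e-3,
                                 scheduler_cfg={"warmup_steps": 10,
                                                "total_steps": 100,
                                                "final_cosine": 1e-6})

    for _ in range(5):
        optimizer.step()  # no gradients here, so parameters are simply left unchanged
        scheduler.step()

    args = SimpleNamespace(logging=SimpleNamespace(weights_l2=True))
    print(extra_stats(args, model, optimizer))  # {'weights_l2': ..., 'lr': ...}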