whisper-small-belarusian/run_3/src/custom_trainer.py
import logging

import torch
import transformers
from transformers import Seq2SeqTrainer

logger = logging.getLogger('custom_trainer')
logger.setLevel(logging.INFO)


class Seq2SeqTrainerCustomLinearScheduler(Seq2SeqTrainer):
    """
    Seq2SeqTrainer that builds a linear learning-rate scheduler whose decay is
    stretched so that the learning rate reaches `args.learning_rate_end` at the
    end of training instead of decaying all the way to 0.
    """

    @staticmethod
    def scheduler_n_steps_for_fixed_lr_in_end(lr_max, lr_end, num_train_steps, num_warmup_steps) -> int:
        """Total scheduler steps so that a linear decay from `lr_max` reaches `lr_end`
        (rather than 0) after `num_train_steps` actual training steps."""
        assert lr_end < lr_max, 'lr_end must be smaller than lr_max'
        # Stretch the decay phase by lr_max / (lr_max - lr_end) and round to a whole step count.
        return num_warmup_steps + round((num_train_steps - num_warmup_steps) * lr_max / (lr_max - lr_end))
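
    # Illustrative check (hypothetical numbers, not taken from this training run):
    # with lr_max=1e-5, lr_end=1e-6, num_train_steps=1000, num_warmup_steps=100 the
    # method above returns 100 + 900 * 1e-5 / 9e-6 = 1100, and a linear schedule with
    # 1100 total steps sits at 1e-5 * (1100 - 1000) / (1100 - 100) = 1e-6 at real step 1000.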

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        # `learning_rate_end` is an extra attribute expected on the training arguments;
        # when it is absent, fall back to the default Seq2SeqTrainer scheduler.
        learning_rate_end = getattr(self.args, 'learning_rate_end', None)

        if learning_rate_end is not None:
            logger.info('Seq2SeqTrainerCustomLinearScheduler.create_scheduler(): '
                        f'initializing a custom linear scheduler with learning_rate_end={learning_rate_end}')
            # Stretch the decay so the learning rate equals learning_rate_end at the true end of training.
            scheduler_num_steps = self.scheduler_n_steps_for_fixed_lr_in_end(
                lr_max=self.args.learning_rate,
                lr_end=learning_rate_end,
                num_train_steps=num_training_steps,
                num_warmup_steps=self.args.warmup_steps,
            )
            self.lr_scheduler = transformers.get_scheduler(
                'linear',
                optimizer=optimizer,
                num_warmup_steps=self.args.warmup_steps,
                num_training_steps=scheduler_num_steps,
            )
            return self.lr_scheduler

        logger.info('Seq2SeqTrainerCustomLinearScheduler.create_scheduler(): '
                    'learning_rate_end was not set. Falling back to the default behavior.')
        return super().create_scheduler(num_training_steps, optimizer)
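

# Usage sketch (assumptions for illustration, not part of the original run_3
# training script): `learning_rate_end` is not a standard Seq2SeqTrainingArguments
# field, so the training script has to attach it to the args object the trainer
# receives, e.g. via a small dataclass subclass as below. Model, dataset and
# collator wiring is omitted; only the scheduler-related pieces are shown.
if __name__ == '__main__':
    from dataclasses import dataclass, field

    from transformers import Seq2SeqTrainingArguments

    @dataclass
    class Seq2SeqTrainingArgumentsWithLrEnd(Seq2SeqTrainingArguments):
        # Hypothetical extra field read by Seq2SeqTrainerCustomLinearScheduler.
        learning_rate_end: float = field(default=1e-7)

    args = Seq2SeqTrainingArgumentsWithLrEnd(
        output_dir='out',
        learning_rate=1e-5,
        warmup_steps=500,
        max_steps=6000,
        learning_rate_end=1e-6,
    )

    # Number of decay steps the custom scheduler would be built with for this run.
    n_steps = Seq2SeqTrainerCustomLinearScheduler.scheduler_n_steps_for_fixed_lr_in_end(
        lr_max=args.learning_rate,
        lr_end=args.learning_rate_end,
        num_train_steps=args.max_steps,
        num_warmup_steps=args.warmup_steps,
    )
    print(f'linear scheduler stretched to {n_steps} steps '
          f'to reach lr={args.learning_rate_end} at step {args.max_steps}')
    # A full run would pass `args` to Seq2SeqTrainerCustomLinearScheduler together
    # with the model, datasets and data collator, and then call trainer.train().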