from abc import abstractmethod

import torch
from numpy import inf

from hw_asr.base import BaseModel
from hw_asr.logger import get_visualizer


class BaseTrainer:
    """
    Base class for all trainers
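
    Example (an illustrative sketch only: ``MyTrainer`` and the returned
    metric names are hypothetical, not part of this template)::

        class MyTrainer(BaseTrainer):
            def _train_epoch(self, epoch):
                ...  # run one epoch of training and evaluation
                return {"loss": 0.5, "val_loss": 0.7}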
    """

    def __init__(self, model: BaseModel, criterion, metrics, optimizer, config, device):
        self.device = device
        self.config = config
        self.logger = config.get_logger("trainer", config["trainer"]["verbosity"])

        self.model = model
        self.criterion = criterion
        self.metrics = metrics
        self.optimizer = optimizer

        self._last_epoch = 0

        cfg_trainer = config["trainer"]
        self.epochs = cfg_trainer["epochs"]
        self.save_period = cfg_trainer["save_period"]
        self.monitor = cfg_trainer.get("monitor", "off")
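
        # An example of the config["trainer"] section these fields come from
        # (all values are illustrative, not defaults):
        #
        #     "trainer": {
        #         "epochs": 100,
        #         "save_period": 5,
        #         "verbosity": 2,
        #         "monitor": "min val_loss",
        #         "early_stop": 10,
        #         "visualize": "wandb"
        #     }
        #
        # "monitor" is a mode ("min"/"max") followed by the name of a metric
        # from the epoch log; "off" disables monitoring and early stopping.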
        if self.monitor == "off":
            self.mnt_mode = "off"
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ["min", "max"]

            self.mnt_best = inf if self.mnt_mode == "min" else -inf
            self.early_stop = cfg_trainer.get("early_stop", inf)
            if self.early_stop <= 0:
                self.early_stop = inf

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        self.writer = get_visualizer(
            config, self.logger, cfg_trainer["visualize"]
        )

        if config.resume is not None:
            # _load_model restores model weights only; _resume_checkpoint
            # would additionally restore the optimizer state and epoch counter
            self._load_model(config.resume)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
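        :return: dict of metric values computed over the epoch; it is merged
            into the epoch log, so it should include the monitored metric
            (self.mnt_metric) when monitoring is enabled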
        """
        raise NotImplementedError()

    def train(self):
        try:
            self._train_process()
        except KeyboardInterrupt as e:
            self.logger.info("Saving model on keyboard interrupt")
            self._save_checkpoint(self._last_epoch, save_best=False)
            raise e

    def _train_process(self):
        """
        Full training logic
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            self._last_epoch = epoch
            result = self._train_epoch(epoch)

            # save logged information into the log dict
            log = {"epoch": epoch}
            log.update(result)

            # print logged information to the screen
            for key, value in log.items():
                self.logger.info("    {:15s}: {}".format(str(key), value))

            # evaluate model performance according to the configured metric,
            # and save the best checkpoint as model_best
            best = False
            if self.mnt_mode != "off":
                try:
                    # check whether model performance improved or not,
                    # according to the specified metric (mnt_metric)
                    if self.mnt_mode == "min":
                        improved = log[self.mnt_metric] <= self.mnt_best
                    elif self.mnt_mode == "max":
                        improved = log[self.mnt_metric] >= self.mnt_best
                    else:
                        improved = False
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric
                        )
                    )
                    self.mnt_mode = "off"
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn't improve for {} epochs. "
                        "Training stops.".format(self.early_stop)
                    )
                    break

            if epoch % self.save_period == 0 or best:
                self._save_checkpoint(epoch, save_best=best, only_best=True)

    def _save_checkpoint(self, epoch, save_best=False, only_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, also save the checkpoint as 'model_best.pth'
        :param only_best: if True (together with save_best), skip the periodic
            'checkpoint-epoch{N}.pth' file and keep only 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            "arch": arch,
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "monitor_best": self.mnt_best,
            "config": self.config,
        }
        filename = str(self.checkpoint_dir / "checkpoint-epoch{}.pth".format(epoch))
        if not (only_best and save_best):
            torch.save(state, filename)
            self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / "model_best.pth")
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _load_model(self, resume_path):
        """
        Load model weights from a saved checkpoint. Unlike _resume_checkpoint,
        this restores neither the optimizer state nor the epoch counter.

        :param resume_path: Checkpoint path to load the model from
        """
        resume_path = str(resume_path)
        self.logger.info("Loading model: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["state_dict"])

        self.logger.info("Model loaded")

    def _resume_checkpoint(self, resume_path):
        """
        Resume training from a saved checkpoint, restoring model weights, the
        epoch counter, and (when the configurations match) the optimizer state.

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=self.device)
        self.start_epoch = checkpoint["epoch"] + 1
        self.mnt_best = checkpoint["monitor_best"]

        # warn if the checkpoint was trained with a different architecture
        if checkpoint["config"]["arch"] != self.config["arch"]:
            self.logger.warning(
                "Warning: Architecture configuration given in config file is different from that "
                "of checkpoint. This may yield an exception while state_dict is being loaded."
            )
        self.model.load_state_dict(checkpoint["state_dict"])

        # load the optimizer state only when the optimizer and lr_scheduler
        # configurations have not changed
        if (
            checkpoint["config"]["optimizer"] != self.config["optimizer"] or
            checkpoint["config"]["lr_scheduler"] != self.config["lr_scheduler"]
        ):
            self.logger.warning(
                "Warning: Optimizer or lr_scheduler given in config file is different "
                "from that of checkpoint. Optimizer parameters are not being resumed."
            )
        else:
            self.optimizer.load_state_dict(checkpoint["optimizer"])

        self.logger.info(
            "Checkpoint loaded. Resuming training from epoch {}".format(self.start_epoch)
        )
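

# Illustrative sketch (not part of the template): checkpoint files written by
# BaseTrainer._save_checkpoint can be inspected directly. The path below is
# hypothetical.
#
#     checkpoint = torch.load("saved/checkpoint-epoch5.pth", map_location="cpu")
#     print(checkpoint["arch"], checkpoint["epoch"], checkpoint["monitor_best"])
#     model.load_state_dict(checkpoint["state_dict"])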
|