import random
from pathlib import Path
from random import shuffle

import PIL
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torchvision.transforms import ToTensor
from tqdm import tqdm

from hw_asr.base import BaseTrainer
from hw_asr.base.base_text_encoder import BaseTextEncoder
from hw_asr.logger.utils import plot_spectrogram_to_buf
from hw_asr.metric.utils import calc_cer, calc_wer
from hw_asr.utils import inf_loop, MetricTracker


class Trainer(BaseTrainer):
    """
    Trainer class
    """

    def __init__(
            self,
            model,
            criterion,
            metrics,
            optimizer,
            config,
            device,
            dataloaders,
            text_encoder,
            lr_scheduler=None,
            len_epoch=None,
            skip_oom=True,
    ):
        super().__init__(model, criterion, metrics, optimizer, config, device)
        self.skip_oom = skip_oom
        self.text_encoder = text_encoder
        self.config = config
        self.train_dataloader = dataloaders["train"]
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.train_dataloader)
        else:
            # iteration-based training
            self.train_dataloader = inf_loop(self.train_dataloader)
            self.len_epoch = len_epoch
        self.evaluation_dataloaders = {k: v for k, v in dataloaders.items() if k != "train"}
        self.lr_scheduler = lr_scheduler
        self.log_step = 50

        self.train_metrics = MetricTracker(
            "loss", "grad norm", *[m.name for m in self.metrics], writer=self.writer
        )
        self.evaluation_metrics = MetricTracker(
            "loss", *[m.name for m in self.metrics], writer=self.writer
        )

    @staticmethod
    def move_batch_to_device(batch, device: torch.device):
        """
        Move all necessary tensors to the device (GPU)
        """
        for tensor_for_gpu in ["spectrogram", "text_encoded"]:
            batch[tensor_for_gpu] = batch[tensor_for_gpu].to(device)
        return batch

    def _clip_grad_norm(self):
        if self.config["trainer"].get("grad_norm_clip", None) is not None:
            clip_grad_norm_(
                self.model.parameters(), self.config["trainer"]["grad_norm_clip"]
            )

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        self.writer.add_scalar("epoch", epoch)
        for batch_idx, batch in enumerate(
                tqdm(self.train_dataloader, desc="train", total=self.len_epoch - 1)
        ):
            try:
                batch = self.process_batch(
                    batch,
                    is_train=True,
                    metrics=self.train_metrics,
                )
            except RuntimeError as e:
                if "out of memory" in str(e) and self.skip_oom:
                    self.logger.warning("OOM on batch. Skipping batch.")
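                    # Release the gradients and the cached CUDA blocks held by the
                    # failed batch so that the next batch has a chance to fit in memory.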
                    for p in self.model.parameters():
                        if p.grad is not None:
                            del p.grad  # free some memory
                    torch.cuda.empty_cache()
                    continue
                else:
                    raise e
            self.train_metrics.update("grad norm", self.get_grad_norm())
            if batch_idx % self.log_step == 0:
                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
                self.logger.debug(
                    "Train Epoch: {} {} Loss: {:.6f}".format(
                        epoch, self._progress(batch_idx), batch["loss"].item()
                    )
                )
                if self.lr_scheduler is not None:
                    self.writer.add_scalar("learning rate", self.lr_scheduler.get_last_lr()[0])
                self._log_predictions(**batch)
                self._log_spectrogram(batch["spectrogram"])
                self._log_scalars(self.train_metrics)
                # we don't want to reset train metrics at the start of every epoch
                # because we are interested in recent train metrics
                last_train_metrics = self.train_metrics.result()
                self.train_metrics.reset()
            if batch_idx + 1 >= self.len_epoch:
                break
        log = last_train_metrics

        for part, dataloader in self.evaluation_dataloaders.items():
            val_log = self._evaluation_epoch(epoch, part, dataloader)
            log.update(**{f"{part}_{name}": value for name, value in val_log.items()})

        return log

    def process_batch(self, batch, is_train: bool, metrics: MetricTracker, part: str = None, epoch: int = None):
        batch = self.move_batch_to_device(batch, self.device)
        if is_train:
            self.optimizer.zero_grad()
        outputs = self.model(**batch)
        if isinstance(outputs, dict):
            batch.update(outputs)
        else:
            batch["logits"] = outputs

        batch["log_probs"] = F.log_softmax(batch["logits"], dim=-1)
        batch["log_probs_length"] = self.model.transform_input_lengths(batch["spectrogram_length"])
        batch["loss"] = self.criterion(**batch)
        if is_train:
            batch["loss"].backward()
            self._clip_grad_norm()
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

        metrics.update("loss", batch["loss"].item())
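        # Beam-search and LM-rescored metrics are expensive to compute, so they are
        # skipped during training and validation and evaluated on test parts only
        # every 25th epoch.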
        for met in self.metrics:
            is_not_test = is_train or ("val" in part)
            is_test = not is_not_test
            hard_to_calc_metric = "beam search" in met.name or "LM" in met.name
            if hard_to_calc_metric and (is_not_test or (is_test and (epoch % 25) != 0)):
                continue
            metrics.update(met.name, met(**batch))
        return batch

    def _evaluation_epoch(self, epoch, part, dataloader):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.evaluation_metrics.reset()
        with torch.no_grad():
            for batch_idx, batch in tqdm(
                    enumerate(dataloader),
                    desc=part,
                    total=len(dataloader),
            ):
                batch = self.process_batch(
                    batch,
                    is_train=False,
                    metrics=self.evaluation_metrics,
                    part=part,
                    epoch=epoch,
                )
            self.writer.set_step(epoch * self.len_epoch, part)
            self._log_predictions(**batch)
            self._log_spectrogram(batch["spectrogram"])
            self._log_scalars(self.evaluation_metrics)

        # add histogram of model parameters to the tensorboard
        # for name, p in self.model.named_parameters():
        #     self.writer.add_histogram(name, p, bins="auto")
        return self.evaluation_metrics.result()

    def _progress(self, batch_idx):
        base = "[{}/{} ({:.0f}%)]"
        if hasattr(self.train_dataloader, "n_samples"):
            current = batch_idx * self.train_dataloader.batch_size
            total = self.train_dataloader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _log_predictions(
            self,
            text,
            logits,
            log_probs,
            log_probs_length,
            audio_path,
            audio,
            examples_to_log=10,
            *args,
            **kwargs,
    ):
        if self.writer is None:
            return

        examples_to_log = min(examples_to_log, len(text))
        ids = np.random.choice(len(text), examples_to_log, replace=False)
        text = [text[i] for i in ids]
        logits = logits[ids]
        log_probs = log_probs[ids]
        log_probs_length = log_probs_length[ids]
        audio_path = [audio_path[i] for i in ids]
        audio = [audio[i] for i in ids]

        argmax_inds = log_probs.cpu().argmax(-1).numpy()
        argmax_inds = [
            inds[: int(ind_len)]
            for inds, ind_len in zip(argmax_inds, log_probs_length.numpy())
        ]
        argmax_texts_raw = [self.text_encoder.decode(inds) for inds in argmax_inds]
        argmax_texts = [self.text_encoder.ctc_decode(inds) for inds in argmax_inds]

        probs = np.exp(log_probs.detach().cpu().numpy())
        probs_length = log_probs_length.detach().cpu().numpy()
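        # Decoding with beam search (width 4) and with the LM-fused beam search is
        # slow, so it is run only for the handful of examples selected for logging.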
        bs_preds = [
            self.text_encoder.ctc_beam_search(prob[:prob_length], 4)
            for prob, prob_length in zip(probs, probs_length)
        ]
        logits = logits.detach().cpu().numpy()
        lm_preds = [
            self.text_encoder.ctc_lm_beam_search(logit[:length])
            for logit, length in zip(logits, probs_length)
        ]

        tuples = list(zip(argmax_texts, bs_preds, lm_preds, text, argmax_texts_raw, audio_path, audio))

        rows = {}
        for pred, bs_pred, lm_pred, target, raw_pred, audio_path, audio in tuples:
            target = BaseTextEncoder.normalize_text(target)
            wer = calc_wer(target, pred) * 100
            cer = calc_cer(target, pred) * 100
            bs_wer = calc_wer(target, bs_pred) * 100
            bs_cer = calc_cer(target, bs_pred) * 100
            lm_wer = calc_wer(target, lm_pred) * 100
            lm_cer = calc_cer(target, lm_pred) * 100

            rows[Path(audio_path).name] = {
                # inaccurate, but no changes in the template
                "orig_audio": self.writer.wandb.Audio(audio_path),
                # inaccurate, but no changes in the template
                "augm_audio": self.writer.wandb.Audio(audio.squeeze().numpy(), sample_rate=16000),
                "target": target,
                "raw pred": raw_pred,
                "pred": pred,
                "bs pred": bs_pred,
                "lm pred": lm_pred,
                "wer": wer,
                "cer": cer,
                "bs wer": bs_wer,
                "bs cer": bs_cer,
                "lm wer": lm_wer,
                "lm cer": lm_cer,
            }
        self.writer.add_table("predictions", pd.DataFrame.from_dict(rows, orient="index"))

    def _log_spectrogram(self, spectrogram_batch):
        spectrogram = random.choice(spectrogram_batch.cpu())
        image = PIL.Image.open(plot_spectrogram_to_buf(spectrogram))
        self.writer.add_image("spectrogram", ToTensor()(image))

    @torch.no_grad()
    def get_grad_norm(self, norm_type=2):
        parameters = self.model.parameters()
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        parameters = [p for p in parameters if p.grad is not None]
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), norm_type).cpu() for p in parameters]),
            norm_type,
        )
        return total_norm.item()

    def _log_scalars(self, metric_tracker: MetricTracker):
        if self.writer is None:
            return
        for metric_name in metric_tracker.keys():
            self.writer.add_scalar(f"{metric_name}", metric_tracker.avg(metric_name))
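

# A minimal usage sketch (illustrative only: `model`, `criterion`, `metrics`,
# `optimizer`, `config`, `device`, `dataloaders`, `text_encoder` and `lr_scheduler`
# are placeholders normally constructed in the project's train script, and
# `trainer.train()` assumes the training-loop entry point provided by BaseTrainer):
#
#     trainer = Trainer(
#         model, criterion, metrics, optimizer,
#         config=config,
#         device=device,
#         dataloaders=dataloaders,
#         text_encoder=text_encoder,
#         lr_scheduler=lr_scheduler,
#         len_epoch=config["trainer"].get("len_epoch", None),
#     )
#     trainer.train()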