import random
from pathlib import Path

import PIL.Image
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torchvision.transforms import ToTensor
from tqdm import tqdm

from hw_asr.base import BaseTrainer
from hw_asr.base.base_text_encoder import BaseTextEncoder
from hw_asr.logger.utils import plot_spectrogram_to_buf
from hw_asr.metric.utils import calc_cer, calc_wer
from hw_asr.utils import inf_loop, MetricTracker


class Trainer(BaseTrainer):
    """
    Trainer for ASR models: runs training and evaluation epochs, tracks metrics,
    and logs predictions, audio and spectrograms to the experiment writer.
    """
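
    # Typical usage (a sketch, not exercised in this file): the objects below are
    # assumed to be built by the project's config parser, and `train()` is assumed
    # to be inherited from BaseTrainer.
    #
    #   trainer = Trainer(model, criterion, metrics, optimizer, config, device,
    #                     dataloaders, text_encoder, lr_scheduler=scheduler)
    #   trainer.train()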

    def __init__(
            self,
            model,
            criterion,
            metrics,
            optimizer,
            config,
            device,
            dataloaders,
            text_encoder,
            lr_scheduler=None,
            len_epoch=None,
            skip_oom=True,
    ):
        super().__init__(model, criterion, metrics, optimizer, config, device)
        self.skip_oom = skip_oom
        self.text_encoder = text_encoder
        self.config = config
        self.train_dataloader = dataloaders["train"]
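        # If len_epoch is not given, one epoch is a full pass over the train
        # dataloader; otherwise the dataloader is wrapped in inf_loop so that an
        # "epoch" is simply len_epoch iterations of an endless stream of batches.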
        if len_epoch is None:
            self.len_epoch = len(self.train_dataloader)
        else:
            self.train_dataloader = inf_loop(self.train_dataloader)
            self.len_epoch = len_epoch
        self.evaluation_dataloaders = {k: v for k, v in dataloaders.items() if k != "train"}
        self.lr_scheduler = lr_scheduler
        self.log_step = 50

        self.train_metrics = MetricTracker("loss", "grad norm", *[m.name for m in self.metrics], writer=self.writer)
        self.evaluation_metrics = MetricTracker("loss", *[m.name for m in self.metrics], writer=self.writer)

    @staticmethod
    def move_batch_to_device(batch, device: torch.device):
        """
        Move all necessary tensors to the device (GPU)
        """
        for tensor_for_gpu in ["spectrogram", "text_encoded"]:
            batch[tensor_for_gpu] = batch[tensor_for_gpu].to(device)
        return batch

    def _clip_grad_norm(self):
        if self.config["trainer"].get("grad_norm_clip", None) is not None:
            clip_grad_norm_(self.model.parameters(), self.config["trainer"]["grad_norm_clip"])

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metrics in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        self.writer.add_scalar("epoch", epoch)
        for batch_idx, batch in enumerate(tqdm(self.train_dataloader, desc="train", total=self.len_epoch)):
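            # A CUDA OOM on a single (e.g. unusually long) batch should not kill the
            # run: drop its gradients, empty the cache and move on to the next batch.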
            try:
                batch = self.process_batch(
                    batch,
                    is_train=True,
                    metrics=self.train_metrics,
                )
            except RuntimeError as e:
                if "out of memory" in str(e) and self.skip_oom:
                    self.logger.warning("OOM on batch. Skipping batch.")
                    for p in self.model.parameters():
                        if p.grad is not None:
                            del p.grad
                    torch.cuda.empty_cache()
                    continue
                else:
                    raise e
            self.train_metrics.update("grad norm", self.get_grad_norm())
            if batch_idx % self.log_step == 0:
                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
                self.logger.debug(
                    "Train Epoch: {} {} Loss: {:.6f}".format(
                        epoch, self._progress(batch_idx), batch["loss"].item()
                    )
                )
                if self.lr_scheduler is not None:
                    last_lr = self.lr_scheduler.get_last_lr()[0]
                else:
                    last_lr = self.optimizer.param_groups[0]["lr"]
                self.writer.add_scalar("learning rate", last_lr)
                self._log_predictions(**batch)
                self._log_spectrogram(batch["spectrogram"])
                self._log_scalars(self.train_metrics)
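                # Snapshot the averages accumulated since the previous log step and
                # reset the tracker, so logged train metrics reflect recent batches only.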
                last_train_metrics = self.train_metrics.result()
                self.train_metrics.reset()
            if batch_idx + 1 >= self.len_epoch:
                break
        log = last_train_metrics

        for part, dataloader in self.evaluation_dataloaders.items():
            val_log = self._evaluation_epoch(epoch, part, dataloader)
            log.update(**{f"{part}_{name}": value for name, value in val_log.items()})

        return log

    def process_batch(self, batch, is_train: bool, metrics: MetricTracker, part: str = None, epoch: int = None):
        batch = self.move_batch_to_device(batch, self.device)
        if is_train:
            self.optimizer.zero_grad()
        outputs = self.model(**batch)
        if type(outputs) is dict:
            batch.update(outputs)
        else:
            batch["logits"] = outputs
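        # The (CTC-style) criterion consumes per-frame log-probabilities together with
        # the output lengths after the model's time-axis downsampling.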
batch["log_probs"] = F.log_softmax(batch["logits"], dim=-1) |
|
batch["log_probs_length"] = self.model.transform_input_lengths(batch["spectrogram_length"]) |
|
batch["loss"] = self.criterion(**batch) |
|
if is_train: |
|
batch["loss"].backward() |
|
self._clip_grad_norm() |
|
self.optimizer.step() |
|
if self.lr_scheduler is not None: |
|
self.lr_scheduler.step() |
|
|
|
metrics.update("loss", batch["loss"].item()) |
|
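        # Beam-search and LM-based metrics are expensive, so they are computed only on
        # test (non-val) partitions and only on every 25th epoch.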
        for met in self.metrics:
            is_not_test = is_train or ("val" in part)
            is_test = not is_not_test
            hard_to_calc_metric = "beam search" in met.name or "LM" in met.name
            if hard_to_calc_metric and (is_not_test or (is_test and (epoch % 25) != 0)):
                continue
            metrics.update(met.name, met(**batch))
        return batch

    def _evaluation_epoch(self, epoch, part, dataloader):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.evaluation_metrics.reset()
        with torch.no_grad():
            for batch_idx, batch in tqdm(
                    enumerate(dataloader),
                    desc=part,
                    total=len(dataloader),
            ):
                batch = self.process_batch(batch, is_train=False, metrics=self.evaluation_metrics, part=part, epoch=epoch)
            self.writer.set_step(epoch * self.len_epoch, part)
            self._log_predictions(**batch)
            self._log_spectrogram(batch["spectrogram"])
            self._log_scalars(self.evaluation_metrics)

        return self.evaluation_metrics.result()

    def _progress(self, batch_idx):
        base = "[{}/{} ({:.0f}%)]"
        if hasattr(self.train_dataloader, "n_samples"):
            current = batch_idx * self.train_dataloader.batch_size
            total = self.train_dataloader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _log_predictions(
            self,
            text,
            logits,
            log_probs,
            log_probs_length,
            audio_path,
            audio,
            examples_to_log=10,
            *args,
            **kwargs,
    ):
        if self.writer is None:
            return

        examples_to_log = min(examples_to_log, len(text))
        ids = np.random.choice(len(text), examples_to_log, replace=False)
        text = [text[i] for i in ids]
        logits = logits[ids]
        log_probs = log_probs[ids]
        log_probs_length = log_probs_length[ids]
        audio_path = [audio_path[i] for i in ids]
        audio = [audio[i] for i in ids]
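        # Greedy (argmax) decoding: `decode` yields the raw per-frame transcription,
        # while `ctc_decode` applies CTC collapsing of repeats and blanks.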
        argmax_inds = log_probs.cpu().argmax(-1).numpy()
        argmax_inds = [inds[: int(ind_len)] for inds, ind_len in zip(argmax_inds, log_probs_length.numpy())]
        argmax_texts_raw = [self.text_encoder.decode(inds) for inds in argmax_inds]
        argmax_texts = [self.text_encoder.ctc_decode(inds) for inds in argmax_inds]
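        # Stronger decoders for comparison: CTC beam search (beam width 4) over
        # probabilities, and an LM-fused beam search over the raw logits.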
        probs = np.exp(log_probs.detach().cpu().numpy())
        probs_length = log_probs_length.detach().cpu().numpy()
        bs_preds = [self.text_encoder.ctc_beam_search(prob[:prob_length], 4) for prob, prob_length in zip(probs, probs_length)]

        logits = logits.detach().cpu().numpy()
        lm_preds = [self.text_encoder.ctc_lm_beam_search(logit[:length]) for logit, length in zip(logits, probs_length)]

        tuples = list(zip(argmax_texts, bs_preds, lm_preds, text, argmax_texts_raw, audio_path, audio))
        rows = {}
        for pred, bs_pred, lm_pred, target, raw_pred, audio_path, audio in tuples:
            target = BaseTextEncoder.normalize_text(target)
            wer = calc_wer(target, pred) * 100
            cer = calc_cer(target, pred) * 100

            bs_wer = calc_wer(target, bs_pred) * 100
            bs_cer = calc_cer(target, bs_pred) * 100

            lm_wer = calc_wer(target, lm_pred) * 100
            lm_cer = calc_cer(target, lm_pred) * 100

            rows[Path(audio_path).name] = {
                "orig_audio": self.writer.wandb.Audio(audio_path),
                "augm_audio": self.writer.wandb.Audio(audio.squeeze().numpy(), sample_rate=16000),
                "target": target,
                "raw pred": raw_pred,
                "pred": pred,
                "bs pred": bs_pred,
                "lm pred": lm_pred,
                "wer": wer,
                "cer": cer,
                "bs wer": bs_wer,
                "bs cer": bs_cer,
                "lm wer": lm_wer,
                "lm cer": lm_cer,
            }
        self.writer.add_table("predictions", pd.DataFrame.from_dict(rows, orient="index"))

    def _log_spectrogram(self, spectrogram_batch):
        spectrogram = random.choice(spectrogram_batch.cpu())
        image = PIL.Image.open(plot_spectrogram_to_buf(spectrogram))
        self.writer.add_image("spectrogram", ToTensor()(image))

    @torch.no_grad()
    def get_grad_norm(self, norm_type=2):
        parameters = self.model.parameters()
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        parameters = [p for p in parameters if p.grad is not None]
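        # The total gradient norm is the norm of the stacked per-parameter gradient
        # norms, which equals the global p-norm over all gradient entries.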
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), norm_type).cpu() for p in parameters]),
            norm_type,
        )
        return total_norm.item()

    def _log_scalars(self, metric_tracker: MetricTracker):
        if self.writer is None:
            return
        for metric_name in metric_tracker.keys():
            self.writer.add_scalar(f"{metric_name}", metric_tracker.avg(metric_name))