import sys
import torch
import speechbrain as sb
from speechbrain.utils.distributed import run_on_main
from hyperpyyaml import load_hyperpyyaml
class ASR(sb.core.Brain):
    """Transformer ASR Brain trained with joint CTC + attention and a two-stage optimizer schedule."""

    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batch to the CTC and seq2seq output log-probabilities."""
batch = batch.to(self.device)
wavs, wavs_len = batch.sig
tokens_bos, _ = batch.tokens_bos
feats = self.hparams.compute_features(wavs)
current_epoch = self.hparams.epoch_counter.current
feats = self.hparams.normalize(feats, wavs_len, epoch=current_epoch)
# if stage == sb.Stage.TRAIN:
# if hasattr(self.modules, "augmentation"):
# feats = self.hparams.augmentation(feats)
src = self.modules.CNN(feats)
enc_out, pred = self.modules.Transformer(
src, tokens_bos, wavs_len, pad_idx=self.hparams.pad_index
)
logits = self.modules.ctc_lin(enc_out)
p_ctc = self.hparams.log_softmax(logits)
pred = self.hparams.seq_lin(pred)
p_seq = self.hparams.log_softmax(pred)
hyps = None
if stage == sb.Stage.TRAIN:
hyps = None
elif stage == sb.Stage.VALID:
hyps = None
current_epoch = self.hparams.epoch_counter.current
if current_epoch % self.hparams.valid_search_interval == 0:
                # for the sake of efficiency, we only perform beam search with limited
                # capacity and no LM, to give the user some idea of how the AM is doing
hyps, _ = self.hparams.valid_search(enc_out.detach(), wavs_len)
elif stage == sb.Stage.TEST:
hyps, _ = self.hparams.test_search(enc_out.detach(), wavs_len)
return p_ctc, p_seq, wavs_len, hyps
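    # Rough shape walk-through for the forward pass above (an illustrative note, not
    # taken from the recipe; exact sizes depend on hyperparams.yaml, which is not shown
    # here): wavs is [batch, time]; compute_features typically yields
    # [batch, frames, n_mels]; the CNN front-end subsamples the frame axis before the
    # Transformer, so enc_out is roughly [batch, frames', d_model]; p_ctc is then
    # [batch, frames', n_tokens] and p_seq is [batch, target_len, n_tokens].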
    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches."""
with torch.no_grad():
predictions = self.compute_forward(batch, stage=stage)
loss = self.compute_objectives(predictions, batch, stage=stage)
        # the base-class implementation returns loss.detach().cpu();
        # here the loss is kept on the device instead
        return loss.detach()
    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch; sets up metric trackers for validation/test."""
if stage != sb.Stage.TRAIN:
self.acc_metric = self.hparams.acc_computer()
self.cer_metric = self.hparams.cer_computer()
    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch; logs statistics and saves checkpoints."""
stage_stats = {"loss": stage_loss}
if stage == sb.Stage.TRAIN:
self.train_stats = stage_stats
else:
stage_stats["ACC"] = self.acc_metric.summarize()
current_epoch = self.hparams.epoch_counter.current
valid_search_interval = self.hparams.valid_search_interval
if (
current_epoch % valid_search_interval == 0
or stage == sb.Stage.TEST
):
stage_stats["CER"] = self.cer_metric.summarize("error_rate")
if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
current_epoch = self.hparams.epoch_counter.current
if current_epoch <= self.hparams.stage_one_epochs:
lr = self.hparams.noam_annealing.current_lr
steps = self.hparams.noam_annealing.n_steps
optimizer = self.optimizer.__class__.__name__
else:
lr = self.hparams.lr_sgd
steps = -1
optimizer = self.optimizer.__class__.__name__
epoch_stats = {
"epoch": epoch,
"lr": lr,
"steps": steps,
"optimizer": optimizer,
}
self.hparams.train_logger.log_stats(
stats_meta=epoch_stats,
train_stats=self.train_stats,
valid_stats=stage_stats,
)
self.checkpointer.save_and_keep_only(
meta={"ACC": stage_stats["ACC"], "epoch": epoch},
max_keys=["ACC"],
num_to_keep=10,
)
elif stage == sb.Stage.TEST:
self.hparams.train_logger.log_stats(
stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
test_stats=stage_stats,
)
with open(self.hparams.cer_file, "w") as cer_file:
self.cer_metric.write_stats(cer_file)
            # save a single checkpoint at the end of evaluation; ACC is set to 1.1
            # (above any real accuracy) so the checkpointer keeps only this one
            self.checkpointer.save_and_keep_only(
                meta={"ACC": 1.1, "epoch": epoch},
                max_keys=["ACC"],
                num_to_keep=1,
            )
    def check_and_reset_optimizer(self):
        """Reset the optimizer once training moves past stage one (switch to SGD)."""
current_epoch = self.hparams.epoch_counter.current
if not hasattr(self, "switched"):
self.switched = False
if isinstance(self.optimizer, torch.optim.SGD):
self.switched = True
if self.switched is True:
return
if current_epoch > self.hparams.stage_one_epochs:
self.optimizer = self.hparams.SGD(self.modules.parameters())
if self.checkpointer is not None:
self.checkpointer.add_recoverable("optimizer", self.optimizer)
self.switched = True
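    # Note (not part of the original comments): training is assumed to run in two
    # stages -- the Noam-scheduled Adam optimizer for the first `stage_one_epochs`
    # epochs, then plain SGD afterwards. check_and_reset_optimizer() performs the
    # switch exactly once and registers the new optimizer with the checkpointer so
    # that it can be recovered on resume.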
    def on_fit_start(self):
        """Initialize the right optimizer when training starts."""
super().on_fit_start()
current_epoch = self.hparams.epoch_counter.current
current_optimizer = self.optimizer
if current_epoch > self.hparams.stage_one_epochs:
del self.optimizer
self.optimizer = self.hparams.SGD(self.modules.parameters())
if self.checkpointer is not None:
group = current_optimizer.param_groups[0]
if "momentum" not in group:
return
self.checkpointer.recover_if_possible(
device=torch.device(self.device)
)
    def on_evaluate_start(self, max_key=None, min_key=None):
        """Average the best checkpoints and load the averaged model before evaluation."""
        super().on_evaluate_start()
        ckpts = self.checkpointer.find_checkpoints(
            max_key=max_key, min_key=min_key
        )
        ckpt = sb.utils.checkpoints.average_checkpoints(
            ckpts, recoverable_name="model", device=self.device
        )
        self.hparams.model.load_state_dict(ckpt, strict=True)
self.hparams.model.eval()
if __name__ == "__main__":
hparams_file_path = "hyperparams.yaml"
run_opts = {"device": "cuda", "distributed_launch": False}
with open(hparams_file_path) as hparams_file:
hparams = load_hyperpyyaml(hparams_file)
asr_brain = ASR(
modules=hparams["modules"],
opt_class=hparams["Adam"],
hparams=hparams,
run_opts=run_opts,
checkpointer=hparams["checkpointer"],
)
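    # What usually follows in a SpeechBrain recipe (a hedged sketch, not part of this
    # script as shown): build train/valid/test datasets from the data manifests, then
    # train and evaluate the brain. The helper and hparams keys below are hypothetical
    # placeholders, so the snippet is left commented out.
    #
    #     train_data, valid_data, test_data = dataio_prepare(hparams)  # hypothetical helper
    #     asr_brain.fit(
    #         asr_brain.hparams.epoch_counter,
    #         train_data,
    #         valid_data,
    #         train_loader_kwargs=hparams["train_dataloader_opts"],
    #         valid_loader_kwargs=hparams["valid_dataloader_opts"],
    #     )
    #     asr_brain.evaluate(
    #         test_data,
    #         max_key="ACC",
    #         test_loader_kwargs=hparams["test_dataloader_opts"],
    #     )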