import torch
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
class ASR(sb.core.Brain):
    def compute_forward(self, batch, stage):
        """Forward pass: features -> CNN front-end -> Transformer -> CTC and seq2seq log-probabilities."""
        batch = batch.to(self.device)
        wavs, wavs_len = batch.sig
        tokens_bos, _ = batch.tokens_bos

        feats = self.hparams.compute_features(wavs)
        current_epoch = self.hparams.epoch_counter.current
        feats = self.hparams.normalize(feats, wavs_len, epoch=current_epoch)

        # Feature augmentation is disabled here; re-enable if needed.
        # if stage == sb.Stage.TRAIN:
        #     if hasattr(self.modules, "augmentation"):
        #         feats = self.hparams.augmentation(feats)

        src = self.modules.CNN(feats)
        enc_out, pred = self.modules.Transformer(
            src, tokens_bos, wavs_len, pad_idx=self.hparams.pad_index
        )

        # CTC branch over the encoder output.
        logits = self.modules.ctc_lin(enc_out)
        p_ctc = self.hparams.log_softmax(logits)

        # Seq2seq branch over the decoder output.
        pred = self.hparams.seq_lin(pred)
        p_seq = self.hparams.log_softmax(pred)

        # Decode hypotheses only when needed: never during training, and only
        # every valid_search_interval epochs during validation.
        hyps = None
        if stage == sb.Stage.VALID:
            current_epoch = self.hparams.epoch_counter.current
            if current_epoch % self.hparams.valid_search_interval == 0:
                # For efficiency, validation uses a reduced-capacity beam
                # search without a language model, just to indicate how the
                # acoustic model is doing.
                hyps, _ = self.hparams.valid_search(enc_out.detach(), wavs_len)
        elif stage == sb.Stage.TEST:
            hyps, _ = self.hparams.test_search(enc_out.detach(), wavs_len)

        return p_ctc, p_seq, wavs_len, hyps
    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches."""
        with torch.no_grad():
            predictions = self.compute_forward(batch, stage=stage)
            loss = self.compute_objectives(predictions, batch, stage=stage)
        # The base class implementation returns loss.detach().cpu(); the
        # .cpu() call is skipped here to avoid a device-to-host copy per batch.
        return loss.detach()
    def on_stage_start(self, stage, epoch):
        """Set up accuracy and CER trackers at the start of validation/test."""
        if stage != sb.Stage.TRAIN:
            self.acc_metric = self.hparams.acc_computer()
            self.cer_metric = self.hparams.cer_computer()
    def on_stage_end(self, stage, stage_loss, epoch):
        """Log statistics and save checkpoints at the end of each stage."""
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["ACC"] = self.acc_metric.summarize()
            current_epoch = self.hparams.epoch_counter.current
            valid_search_interval = self.hparams.valid_search_interval
            if (
                current_epoch % valid_search_interval == 0
                or stage == sb.Stage.TEST
            ):
                stage_stats["CER"] = self.cer_metric.summarize("error_rate")

        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
            current_epoch = self.hparams.epoch_counter.current
            if current_epoch <= self.hparams.stage_one_epochs:
                # Stage one: Noam (warmup) schedule on top of Adam.
                lr = self.hparams.noam_annealing.current_lr
                steps = self.hparams.noam_annealing.n_steps
            else:
                # Stage two: plain SGD with a fixed learning rate.
                lr = self.hparams.lr_sgd
                steps = -1
            optimizer = self.optimizer.__class__.__name__

            epoch_stats = {
                "epoch": epoch,
                "lr": lr,
                "steps": steps,
                "optimizer": optimizer,
            }
            self.hparams.train_logger.log_stats(
                stats_meta=epoch_stats,
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"ACC": stage_stats["ACC"], "epoch": epoch},
                max_keys=["ACC"],
                num_to_keep=10,
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            with open(self.hparams.cer_file, "w") as cer_file:
                self.cer_metric.write_stats(cer_file)
            # Save the model used for testing; ACC=1.1 outranks every real
            # checkpoint, so only this one is kept.
            self.checkpointer.save_and_keep_only(
                meta={"ACC": 1.1, "epoch": epoch},
                max_keys=["ACC"],
                num_to_keep=1,
            )
    def check_and_reset_optimizer(self):
        """Switch from Adam to SGD once training enters stage two."""
        current_epoch = self.hparams.epoch_counter.current
        if not hasattr(self, "switched"):
            self.switched = False
            if isinstance(self.optimizer, torch.optim.SGD):
                self.switched = True

        if self.switched is True:
            return

        if current_epoch > self.hparams.stage_one_epochs:
            self.optimizer = self.hparams.SGD(self.modules.parameters())
            if self.checkpointer is not None:
                self.checkpointer.add_recoverable("optimizer", self.optimizer)
            self.switched = True
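
    # NOTE: nothing in this file calls check_and_reset_optimizer(). In
    # comparable SpeechBrain transformer recipes it runs at the start of
    # every training batch so the Adam -> SGD switch can take effect
    # mid-training. A minimal sketch of such an override (hypothetical,
    # not part of the original file):
    #
    # def fit_batch(self, batch):
    #     self.check_and_reset_optimizer()  # switch optimizer if in stage two
    #     return super().fit_batch(batch)   # default forward/backward/step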
    def on_fit_start(self):
        """Initialize the correct optimizer when training starts."""
        super().on_fit_start()

        # If training resumes in stage two, replace Adam with SGD.
        current_epoch = self.hparams.epoch_counter.current
        current_optimizer = self.optimizer
        if current_epoch > self.hparams.stage_one_epochs:
            del self.optimizer
            self.optimizer = self.hparams.SGD(self.modules.parameters())

            # Reload the latest checkpoint to resume interrupted training.
            if self.checkpointer is not None:
                # Skip recovery if training was interrupted right before stage
                # two: the saved optimizer is Adam, whose param groups carry
                # no "momentum" key.
                group = current_optimizer.param_groups[0]
                if "momentum" not in group:
                    return
                self.checkpointer.recover_if_possible(
                    device=torch.device(self.device)
                )
    def on_evaluate_start(self, max_key=None, min_key=None):
        """Average the best checkpoints parameter-wise before evaluation."""
        super().on_evaluate_start()

        ckpts = self.checkpointer.find_checkpoints(
            max_key=max_key, min_key=min_key
        )
        ckpt = sb.utils.checkpoints.average_checkpoints(
            ckpts, recoverable_name="model", device=self.device
        )
        self.hparams.model.load_state_dict(ckpt, strict=True)
        self.hparams.model.eval()
if __name__ == "__main__":
    # The hyperparameters file defines the modules, searchers, epoch counter,
    # and checkpointer referenced through self.hparams above.
    hparams_file_path = "hyperparams.yaml"
    run_opts = {"device": "cuda", "distributed_launch": False}
    with open(hparams_file_path) as hparams_file:
        hparams = load_hyperpyyaml(hparams_file)

    asr_brain = ASR(
        modules=hparams["modules"],
        opt_class=hparams["Adam"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )
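
    # The script above only constructs the Brain; no training or evaluation
    # is launched. A minimal sketch of how a run would typically proceed
    # (dataio_prepare and the dataloader option keys are hypothetical -- the
    # data pipeline is not shown in this file):
    #
    # train_data, valid_data, test_data = dataio_prepare(hparams)
    # asr_brain.fit(
    #     asr_brain.hparams.epoch_counter,
    #     train_data,
    #     valid_data,
    #     train_loader_kwargs=hparams["train_dataloader_opts"],
    #     valid_loader_kwargs=hparams["valid_dataloader_opts"],
    # )
    # asr_brain.evaluate(test_data, max_key="ACC")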