|
import sys |
|
|
|
import torch |
|
import speechbrain as sb |
|
from speechbrain.utils.distributed import run_on_main |
|
from hyperpyyaml import load_hyperpyyaml |
|
|
|
|
|
class ASR(sb.core.Brain):
    """Brain subclass for a CNN + Transformer recognizer trained with a
    weighted sum of a CTC loss and an attention (sequence) loss.

    Training uses two optimizer stages: the optimizer created by the base
    class is used while ``epoch_counter.current <= stage_one_epochs``; once
    the epoch counter exceeds that value it is replaced by ``hparams.SGD``
    (see ``check_and_reset_optimizer`` and ``on_fit_start``).
    """

    def _decoding_enabled(self, stage):
        """Return True when search hypotheses are produced for ``stage``.

        Search runs for every TEST batch, and for VALID batches only on
        epochs that are a multiple of ``hparams.valid_search_interval``.
        It never runs during TRAIN.
        """
        if stage == sb.Stage.TEST:
            return True
        if stage == sb.Stage.VALID:
            interval = self.hparams.valid_search_interval
            return self.hparams.epoch_counter.current % interval == 0
        return False

    def compute_forward(self, batch, stage):
        """Run feature extraction, the CNN/Transformer and, when enabled,
        the search for ``stage``.

        Arguments
        ---------
        batch
            Batch object providing ``to()``, ``sig`` and ``tokens_bos``.
        stage : sb.Stage
            Current stage (TRAIN, VALID or TEST).

        Returns
        -------
        tuple
            ``(p_ctc, p_seq, wavs_len, hyps)`` where ``p_ctc`` and ``p_seq``
            are the outputs of ``hparams.log_softmax`` for the CTC and
            sequence heads, and ``hyps`` is ``None`` whenever no search was
            run.
        """
        batch = batch.to(self.device)
        wavs, wavs_len = batch.sig
        tokens_bos, _ = batch.tokens_bos

        # Features, normalized with the current epoch passed to the normalizer.
        feats = self.hparams.compute_features(wavs)
        current_epoch = self.hparams.epoch_counter.current
        feats = self.modules.normalize(feats, wavs_len, epoch=current_epoch)

        # Encoder / decoder pass.
        src = self.modules.CNN(feats)
        enc_out, pred = self.modules.Transformer(
            src, tokens_bos, wavs_len, pad_idx=self.hparams.pad_index
        )

        # CTC head on the encoder output.
        logits = self.modules.ctc_lin(enc_out)
        p_ctc = self.hparams.log_softmax(logits)

        # Sequence (attention) head on the decoder output.
        pred = self.modules.seq_lin(pred)
        p_seq = self.hparams.log_softmax(pred)

        # Search is only run when its hypotheses will be scored.
        hyps = None
        if self._decoding_enabled(stage):
            if stage == sb.Stage.TEST:
                searcher = self.hparams.test_search
            else:
                searcher = self.hparams.valid_search
            hyps, _ = searcher(enc_out.detach(), wavs_len)

        return p_ctc, p_seq, wavs_len, hyps

    def compute_objectives(self, predictions, batch, stage):
        """Compute the joint loss and, outside TRAIN, update the metrics.

        Arguments
        ---------
        predictions : tuple
            The ``(p_ctc, p_seq, wavs_len, hyps)`` tuple returned by
            ``compute_forward``.
        batch
            Batch object providing ``id``, ``tokens_eos``, ``tokens`` and
            ``transcription``.
        stage : sb.Stage
            Current stage.

        Returns
        -------
        The loss
        ``ctc_weight * ctc_loss + (1 - ctc_weight) * attention_loss``.
        """
        p_ctc, p_seq, wavs_len, hyps = predictions

        ids = batch.id
        tokens_eos, tokens_eos_len = batch.tokens_eos
        tokens, tokens_len = batch.tokens

        attention_loss = self.hparams.seq_cost(
            p_seq, tokens_eos, length=tokens_eos_len
        )
        ctc_loss = self.hparams.ctc_cost(p_ctc, tokens, wavs_len, tokens_len)
        loss = (
            self.hparams.ctc_weight * ctc_loss
            + (1 - self.hparams.ctc_weight) * attention_loss
        )

        if stage != sb.Stage.TRAIN:
            if self._decoding_enabled(stage):
                # Use the tokenizer held by this Brain's hparams rather than
                # a module-level global, so the class also works when it is
                # imported instead of being run as a script.
                tokenizer = self.hparams.tokenizer
                predicted_words = [
                    tokenizer.decode_ids(utt_seq).split(" ")
                    for utt_seq in hyps
                ]
                target_words = [
                    transcription.split(" ")
                    for transcription in batch.transcription
                ]
                if self.hparams.remove_spaces:
                    # Score on the space-free strings instead of word lists.
                    predicted_words = [
                        "".join(words) for words in predicted_words
                    ]
                    target_words = ["".join(words) for words in target_words]
                self.cer_metric.append(ids, predicted_words, target_words)

            # The accuracy metric is updated on every non-training batch.
            self.acc_metric.append(p_seq, tokens_eos, tokens_eos_len)

        return loss

    def fit_batch(self, batch):
        """Train on one batch with gradient accumulation.

        Returns the detached (un-normalized) loss of this batch.
        """
        # Switch to the stage-two optimizer first if the epoch requires it.
        self.check_and_reset_optimizer()

        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)

        # Scale the loss so accumulated gradients average over the
        # accumulation window.
        (loss / self.hparams.gradient_accumulation).backward()

        if self.step % self.hparams.gradient_accumulation == 0:
            self.check_gradients(loss)

            self.optimizer.step()
            self.optimizer.zero_grad()

            # The scheduler is invoked once per optimizer update.
            self.hparams.noam_annealing(self.optimizer)

        return loss.detach()

    def evaluate_batch(self, batch, stage):
        """Compute the detached loss for one batch without gradients."""
        with torch.no_grad():
            predictions = self.compute_forward(batch, stage=stage)
            loss = self.compute_objectives(predictions, batch, stage=stage)

        return loss.detach()

    def on_stage_start(self, stage, epoch):
        """Create fresh accuracy and CER metric objects for VALID/TEST."""
        if stage != sb.Stage.TRAIN:
            self.acc_metric = self.hparams.acc_computer()
            self.cer_metric = self.hparams.cer_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Summarize the stage, log statistics and save checkpoints.

        Arguments
        ---------
        stage : sb.Stage
            Stage that just finished.
        stage_loss
            Loss value reported for the stage.
        epoch
            Epoch value stored in the logged stats and checkpoint meta.
        """
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            # Kept so it can be logged together with the validation stats.
            self.train_stats = stage_stats
        else:
            stage_stats["ACC"] = self.acc_metric.summarize()
            # CER is only available when hypotheses were produced.
            if self._decoding_enabled(stage):
                stage_stats["CER"] = self.cer_metric.summarize("error_rate")

        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
            # Report the learning rate of whichever optimizer stage is active.
            current_epoch = self.hparams.epoch_counter.current
            if current_epoch <= self.hparams.stage_one_epochs:
                lr = self.hparams.noam_annealing.current_lr
                steps = self.hparams.noam_annealing.n_steps
            else:
                lr = self.hparams.lr_sgd
                steps = -1

            epoch_stats = {
                "epoch": epoch,
                "lr": lr,
                "steps": steps,
                "optimizer": self.optimizer.__class__.__name__,
            }
            self.hparams.train_logger.log_stats(
                stats_meta=epoch_stats,
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"ACC": stage_stats["ACC"], "epoch": epoch},
                max_keys=["ACC"],
                num_to_keep=10,
            )

        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            # Explicit UTF-8 so non-ASCII transcripts are written the same
            # way regardless of the platform's default encoding.
            with open(self.hparams.cer_file, "w", encoding="utf-8") as cer_file:
                self.cer_metric.write_stats(cer_file)

            # Save the current (evaluated) model and keep only one checkpoint.
            # NOTE(review): ACC=1.1 presumably exceeds any real accuracy so
            # this checkpoint is the one retained -- confirm ACC is in [0, 1].
            self.checkpointer.save_and_keep_only(
                meta={"ACC": 1.1, "epoch": epoch},
                max_keys=["ACC"],
                num_to_keep=1,
            )

    def check_and_reset_optimizer(self):
        """Replace the optimizer with ``hparams.SGD`` once training has
        passed ``stage_one_epochs``; do this at most once per instance."""
        current_epoch = self.hparams.epoch_counter.current

        if not hasattr(self, "switched"):
            # First call: the switch already happened if the optimizer in
            # place is an SGD instance.
            self.switched = isinstance(self.optimizer, torch.optim.SGD)

        if self.switched is True:
            return

        if current_epoch > self.hparams.stage_one_epochs:
            self.optimizer = self.hparams.SGD(self.modules.parameters())

            # Register the new optimizer so it is the one checkpointed.
            if self.checkpointer is not None:
                self.checkpointer.add_recoverable("optimizer", self.optimizer)

            self.switched = True

    def on_fit_start(self):
        """Initialize the right optimizer on the training start"""
        super().on_fit_start()

        # When starting past stage one, rebuild the optimizer as SGD.
        current_epoch = self.hparams.epoch_counter.current
        current_optimizer = self.optimizer
        if current_epoch > self.hparams.stage_one_epochs:
            del self.optimizer
            self.optimizer = self.hparams.SGD(self.modules.parameters())

            if self.checkpointer is not None:
                # Skip recovery when the previous optimizer's first param
                # group has no "momentum" entry.
                # NOTE(review): presumably this means the checkpoint holds
                # stage-one optimizer state that SGD cannot load -- confirm.
                group = current_optimizer.param_groups[0]
                if "momentum" not in group:
                    return

                self.checkpointer.recover_if_possible(
                    device=torch.device(self.device)
                )

    def on_evaluate_start(self, max_key=None, min_key=None):
        """Load the average of the selected checkpoints into the model.

        Arguments
        ---------
        max_key, min_key
            Passed to ``checkpointer.find_checkpoints`` to select the
            checkpoints that are averaged.
        """
        # NOTE(review): max_key/min_key are not forwarded to the base
        # implementation; only the checkpoint search below uses them.
        super().on_evaluate_start()

        found_checkpoints = self.checkpointer.find_checkpoints(
            max_key=max_key, min_key=min_key
        )
        averaged_state = sb.utils.checkpoints.average_checkpoints(
            found_checkpoints, recoverable_name="model", device=self.device
        )

        self.hparams.model.load_state_dict(averaged_state, strict=True)
        self.hparams.model.eval()
|
|
|
|
|
def dataio_prepare(hparams):
    """Create the "train", "dev" and "test" datasets from their JSON files.

    Each dataset is read from ``{data_folder}/{name}.json`` and exposes the
    keys ``transcription``, ``tokens_bos``, ``tokens_eos``, ``tokens``,
    ``sig`` and ``id``. The training set passes its audio through
    ``hparams["speed_perturb"]``; the other two use the audio as read.

    Arguments
    ---------
    hparams : dict
        Must provide ``tokenizer``, ``bos_index``, ``eos_index``,
        ``speed_perturb`` and ``data_folder``.

    Returns
    -------
    dict
        Maps the dataset name to its ``DynamicItemDataset``.
    """

    @sb.utils.data_pipeline.takes("transcription")
    @sb.utils.data_pipeline.provides(
        "transcription", "tokens_bos", "tokens_eos", "tokens"
    )
    def text_pipeline(transcription):
        # Yield order must match the order declared in ``provides``.
        yield transcription
        token_ids = hparams["tokenizer"].encode_as_ids(transcription)
        yield torch.LongTensor([hparams["bos_index"]] + token_ids)
        yield torch.LongTensor(token_ids + [hparams["eos_index"]])
        yield torch.LongTensor(token_ids)

    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def clean_audio_pipeline(wav):
        return sb.dataio.dataio.read_audio(wav)

    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def perturbed_audio_pipeline(wav):
        # The perturbation is applied on a tensor with an extra leading
        # dimension, which is removed again afterwards.
        batched = sb.dataio.dataio.read_audio(wav).unsqueeze(0)
        return hparams["speed_perturb"](batched).squeeze(0)

    root = hparams["data_folder"]
    wanted_keys = [
        "transcription",
        "tokens_bos",
        "tokens_eos",
        "tokens",
        "sig",
        "id",
    ]

    # dev and test share one pipeline list; only train is perturbed.
    eval_pipelines = [text_pipeline, clean_audio_pipeline]
    pipelines_by_split = {
        "train": [text_pipeline, perturbed_audio_pipeline],
        "dev": eval_pipelines,
        "test": eval_pipelines,
    }

    return {
        split: sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=f"{root}/{split}.json",
            replacements={"data_root": root},
            dynamic_items=pipelines,
            output_keys=wanted_keys,
        )
        for split, pipelines in pipelines_by_split.items()
    }
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, load the hyperparameters,
    # build the datasets and the ASR brain, then run training.

    # Split the command line into the hparams file path, the run options
    # and the hyperparameter overrides.
    hparams_file_path, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    # Set up the distributed process group from the run options.
    sb.utils.distributed.ddp_init_group(run_opts)

    # Load the hyperparameter file with the command-line overrides applied.
    with open(hparams_file_path) as hparams_file:
        hparams = load_hyperpyyaml(hparams_file, overrides)

    # Create the output folder and store the hyperparameters used.
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file_path,
        overrides=overrides,
    )

    # Collect the pretrainer's files via run_on_main, then load them on the
    # device given in the run options.
    run_on_main(hparams["pretrainer"].collect_files)
    hparams["pretrainer"].load_collected(device=run_opts["device"])

    # Build the train/dev/test datasets.
    datasets = dataio_prepare(hparams)

    # The brain starts with hparams["Adam"] as its optimizer class.
    asr_brain = ASR(
        modules=hparams["modules"],
        opt_class=hparams["Adam"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    # Train on "train" and validate on "dev", iterating the epoch counter.
    asr_brain.fit(
        asr_brain.hparams.epoch_counter,
        datasets["train"],
        datasets["dev"],
        train_loader_kwargs=hparams["train_dataloader_opts"],
        valid_loader_kwargs=hparams["valid_dataloader_opts"],
    )
|
|
|
|
|
|
|
|
|
|