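"""Recipe for training a language model (LM) with SpeechBrain.

The model, tokenizer, and training settings are read from a HyperPyYAML
hyperparameters file passed on the command line.

Example run (file names and paths below are illustrative, not part of the
original recipe):
> python train.py hparams/RNNLM.yaml --data_folder=/path/to/dataset
"""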
import sys
import torch
import speechbrain as sb
from speechbrain.dataio import dataset
from speechbrain.utils.distributed import run_on_main
from hyperpyyaml import load_hyperpyyaml
class LM(sb.core.Brain):
    def compute_forward(self, batch, stage):
        """Forward pass: predicts next-token log-probabilities from the
        BOS-prefixed token sequence."""
        batch = batch.to(self.device)
        tokens_bos, _ = batch.tokens_bos
        logits = self.hparams.model(tokens_bos)
        pred = self.hparams.log_softmax(logits)
        return pred
    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss between the predictions and the EOS-appended targets."""
        batch = batch.to(self.device)
        tokens_eos, tokens_len = batch.tokens_eos
        loss = self.hparams.compute_cost(
            predictions, tokens_eos, length=tokens_len
        )
        return loss
    def fit_batch(self, batch):
        """Runs one training step with gradient accumulation."""
        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)

        # Scale the loss so gradients are averaged over the accumulated steps.
        (loss / self.hparams.accumulation_steps).backward()

        if self.step % self.hparams.accumulation_steps == 0:
            # Gradient clipping / non-finite loss check.
            self.check_gradients(loss)
            self.optimizer.step()
            self.optimizer.zero_grad()

            # Noam and cyclic-cosine schedules are stepped after every update.
            if isinstance(
                self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
            ) or isinstance(
                self.hparams.lr_annealing,
                sb.nnet.schedulers.CyclicCosineScheduler,
            ):
                self.hparams.lr_annealing(self.optimizer)

        return loss
    def on_stage_end(self, stage, stage_loss, epoch):
        """Called at the end of each stage for LR annealing, logging, and checkpointing."""
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats

        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
            # Schedulers stepped per batch (Noam, cyclic cosine) keep their
            # current LR here; the others are annealed on the validation loss.
            if not (
                isinstance(
                    self.hparams.lr_annealing, sb.nnet.schedulers.NoamScheduler
                )
                or isinstance(
                    self.hparams.lr_annealing,
                    sb.nnet.schedulers.CyclicCosineScheduler,
                )
            ):
                old_lr, new_lr = self.hparams.lr_annealing(stage_loss)
                sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
            else:
                old_lr = self.hparams.lr_annealing.current_lr

            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr": old_lr},
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            # Keep only the checkpoint with the lowest validation loss.
            self.checkpointer.save_and_keep_only(
                meta=stage_stats, min_keys=["loss"],
            )

        if stage == sb.Stage.TEST and sb.utils.distributed.if_main_process():
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
def dataio_prepare(hparams):
    """Creates the datasets and the text processing pipeline."""

    # From the raw transcription, build the BOS-prefixed input tokens and the
    # EOS-appended target tokens.
    @sb.utils.data_pipeline.takes("transcription")
    @sb.utils.data_pipeline.provides(
        "transcription", "tokens_bos", "tokens_eos"
    )
    def transcription_pipeline(transcription):
        yield transcription
        tokens_list = hparams["tokenizer"].encode_as_ids(transcription)
        tokens_bos = torch.LongTensor([hparams["bos_index"]] + tokens_list)
        yield tokens_bos
        tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
        yield tokens_eos
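    # Each JSON manifest is expected to map utterance ids to (at least) a
    # "transcription" entry, e.g. (illustrative layout, not from the recipe):
    # {
    #     "utt_0001": {"transcription": "HELLO WORLD"},
    #     "utt_0002": {"transcription": "GOOD MORNING"}
    # }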
    data_folder = hparams["data_folder"]
    datasets = {}
    for dataset_name in ["train", "dev", "test"]:
        json_path = f"{data_folder}/{dataset_name}.json"
        datasets[dataset_name] = dataset.DynamicItemDataset.from_json(
            json_path=json_path,
            replacements={"data_root": data_folder},
            dynamic_items=[transcription_pipeline],
            output_keys=["transcription", "tokens_bos", "tokens_eos"],
        )
    return datasets
if __name__ == "__main__":
    # Parse the command line: hyperparameters file, run options, and overrides.
    hparams_file_path, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    with open(hparams_file_path) as hparams_file:
        hparams = load_hyperpyyaml(hparams_file, overrides)

    # Initialize DDP (only relevant for multi-GPU training).
    sb.utils.distributed.ddp_init_group(run_opts)

    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file_path,
        overrides=overrides,
    )

    # Collect and load the pretrained components (typically the tokenizer).
    run_on_main(hparams["pretrainer"].collect_files)
    hparams["pretrainer"].load_collected(device=run_opts["device"])

    datasets = dataio_prepare(hparams)

    lm_brain = LM(
        modules=hparams["modules"],
        opt_class=hparams["optimizer"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    lm_brain.fit(
        lm_brain.hparams.epoch_counter,
        datasets["train"],
        datasets["dev"],
        train_loader_kwargs=hparams["train_dataloader_opts"],
        valid_loader_kwargs=hparams["valid_dataloader_opts"],
    )

    # Evaluate on the test set using the lowest-loss checkpoint.
    lm_brain.evaluate(
        datasets["test"],
        min_key="loss",
        test_loader_kwargs=hparams["test_dataloader_opts"],
    )