# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import logging
from dataclasses import dataclass, field

from fairseq.data import Dictionary
from fairseq.tasks import register_task
from fairseq.tasks.hubert_pretraining import (
    HubertPretrainingConfig,
    HubertPretrainingTask,
    LabelEncoder,
)

from speech2c.data.speech2c_dataset import Speech2cDataset

logger = logging.getLogger(__name__)


@dataclass
class Speech2cPretrainingConfig(HubertPretrainingConfig):
    add_decoder: bool = field(
        default=False,
        metadata={"help": "whether to add a decoder for CE loss on codes"},
    )

    # For inference
    ctc_weight: float = field(
        default=0.0,
        metadata={"help": "CTC weight during inference"},
    )
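
# A minimal sketch of how these options might be set with fairseq-hydra-train
# (the config dir/name and the weight value are hypothetical, not from this repo):
#
#   fairseq-hydra-train \
#       --config-dir /path/to/configs --config-name speech2c_base \
#       task._name=speech2c_pretraining \
#       task.add_decoder=true \
#       task.ctc_weight=0.3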


@register_task("speech2c_pretraining", dataclass=Speech2cPretrainingConfig)
class Speech2cPretrainingTask(HubertPretrainingTask):

    cfg: Speech2cPretrainingConfig

    def load_dictionaries(self):
        label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
        dictionaries = [
            Dictionary.load(f"{label_dir}/dict.{label}.txt") for label in self.cfg.labels
        ]
        # Fine-tuning uses a single target dictionary; pre-training may use one per label.
        return dictionaries[0] if self.cfg.fine_tuning else dictionaries
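
    # For reference, fairseq's Dictionary.load expects a plain-text file with one
    # "<symbol> <count>" entry per line. A hypothetical dict.km.txt for k-means
    # cluster labels might look like:
    #
    #   0 123456
    #   1 98765
    #   2 45678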

    def load_dataset(self, split: str, **kwargs) -> None:
        manifest = f"{self.cfg.data}/{split}.tsv"
        dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
        # Avoid shadowing the builtin `dict` in these comprehensions.
        pad_list = [d.pad() for d in dicts]
        eos_list = [d.eos() for d in dicts]
        procs = [LabelEncoder(d) for d in dicts]
        paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels]

        # hubert v1: pad_audio=True, random_crop=False;
        self.datasets[split] = Speech2cDataset(
            manifest,
            sample_rate=self.cfg.sample_rate,
            label_paths=paths,
            label_rates=self.cfg.label_rate,
            pad_list=pad_list,
            eos_list=eos_list,
            label_processors=procs,
            max_keep_sample_size=self.cfg.max_keep_size,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_sample_size,
            pad_audio=self.cfg.pad_audio,
            normalize=self.cfg.normalize,
            store_labels=False,
            random_crop=self.cfg.random_crop,
            single_target=self.cfg.single_target,
            tgt_dict=dicts[0],
            add_decoder=self.cfg.add_decoder,
            fine_tuning=self.cfg.fine_tuning,
        )
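
    # The manifest follows the usual fairseq/HuBERT layout: line one is the audio
    # root directory, and every other line is "<relative-path>\t<num-samples>"
    # (tab-separated). A hypothetical train.tsv:
    #
    #   /data/LibriSpeech/train-clean-100
    #   103/1240/103-1240-0000.flac	225360
    #   103/1240/103-1240-0001.flac	255120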

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        from speech2c.squence_generator import SequenceGenerator

        # Guard against extra_gen_cls_kwargs being None before unpacking it.
        extra_gen_cls_kwargs = {
            "ctc_weight": self.cfg.ctc_weight,
            **(extra_gen_cls_kwargs or {}),
        }
        return super().build_generator(
            models,
            args,
            seq_gen_cls=SequenceGenerator,
            extra_gen_cls_kwargs=extra_gen_cls_kwargs,
        )
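

# A minimal sketch of the joint CTC/attention score combination that a
# ctc_weight like the one above typically controls. This helper is a
# hypothetical illustration for readers, not part of speech2c's
# SequenceGenerator (which implements the real decoding logic).
def _joint_ctc_attention_score(att_lprob: float, ctc_lprob: float, ctc_weight: float) -> float:
    # Interpolate attention-decoder and CTC log-probabilities per hypothesis;
    # ctc_weight=0.0 reduces to pure attention decoding.
    return (1.0 - ctc_weight) * att_lprob + ctc_weight * ctc_lprob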