Spaces:
Runtime error
Runtime error
# ---------------------------------------------------------------------------- | |
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730) | |
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT | |
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4 | |
# | |
# Copyright (c) 2022 Microsoft | |
# Licensed under The MIT License [see LICENSE for details] | |
# ---------------------------------------------------------------------------- | |
import logging | |
import os | |
import sys | |
from typing import Dict, List, Optional, Tuple | |
from pathlib import Path | |
import numpy as np | |
from argparse import Namespace | |
from collections import OrderedDict | |
import torch | |
from dataclasses import dataclass, field | |
from fairseq.data import ( | |
Dictionary, | |
encoders, | |
data_utils, | |
StripTokenDataset, | |
PrependTokenDataset, | |
AppendTokenDataset, | |
DenoisingDataset, | |
ConcatDataset, | |
FairseqDataset, | |
iterators, | |
ResamplingDataset, | |
MaskTokensDataset, | |
LanguagePairDataset, | |
) | |
from fairseq.data.audio.speech_to_text_joint_dataset import S2TJointDataConfig | |
from fairseq.data.shorten_dataset import maybe_shorten_dataset | |
# from fairseq.data.encoders.utils import get_whole_word_mask | |
from fairseq.dataclass.configs import FairseqDataclass | |
from fairseq.tasks import register_task | |
from fairseq.tasks.fairseq_task import FairseqTask | |
from fairseq.dataclass.constants import ChoiceEnum | |
from omegaconf import MISSING | |
from speechut.data.multimodal_corpus_dataset import MultiCorpusDataset | |
from speechut.data.load_langpair_dataset import load_langpair_dataset | |
from speechut.data.language_trible_dataset import LanguageTripleDataset, load_langtriple_dataset | |
from speechut.data.hubert_dataset import HubertDataset | |
logger = logging.getLogger(__name__) | |
TOKENIZER_CHOICES = ChoiceEnum(["sentencepiece", "hubert_letters", "none"]) | |
def _lang_token(lang: str): | |
return "<lang:{}>".format(lang) | |
def _lang_token_index(dic: Dictionary, lang: str): | |
"""Return language token index.""" | |
idx = dic.index(_lang_token(lang)) | |
assert idx != dic.unk_index, "cannot find language token for lang {}".format(lang) | |
return idx | |
class LabelEncoder(object): | |
def __init__(self, dictionary: Dictionary) -> None: | |
self.dictionary = dictionary | |
def __call__(self, label: str) -> List[str]: | |
return self.dictionary.encode_line( | |
label, append_eos=False, add_if_not_exist=False, | |
) | |
### wrap the initial get_whole_word_mask which needs bpe_tokenizer, | |
### here we just assume words are splited by "|" or "<SIL>" | |
def get_whole_word_mask(args, dictionary): | |
def is_beginning_of_word(i): | |
if i < dictionary.nspecial: | |
# special elements are always considered beginnings | |
return True | |
tok = dictionary[i] | |
if tok.startswith("madeupword"): | |
return True | |
elif tok in ["<unk>", "<s>", "</s>", "<pad>", "|", "<eps>"]: | |
return True | |
else: | |
return False | |
mask_whole_words = torch.ByteTensor( | |
list(map(is_beginning_of_word, range(len(dictionary)))) | |
) | |
return mask_whole_words | |
def get_repeative_start(tokens): | |
""" | |
tokens: torch.Tensor with repeative tokens | |
""" | |
length = len(tokens) | |
rep_start_id = tokens[:-1] != tokens[1:] | |
return torch.cat([torch.tensor([True]), rep_start_id]) | |
class TextPretrainingConfig(FairseqDataclass): | |
### added for joint pretraining | |
text_data: Optional[str] = field( | |
default=None, | |
metadata={ | |
"help": "if set, path to text data directory", | |
}, | |
) | |
seed: Optional[int] = field( | |
default=1, | |
metadata={ | |
"help": "for ordered_indices in MulticorpusDataset", | |
}, | |
) | |
tokens_per_sample: Optional[int] = field( | |
default=512, | |
metadata={ | |
"help": "max number of total tokens over all segments per sample for dataset", | |
}, | |
) | |
tokens_per_sample_tgt: Optional[int] = field( | |
default=512, | |
metadata={ | |
"help": "max number of total tokens over all segments per target sample for dataset", | |
}, | |
) | |
sample_break_mode: Optional[str] = field( | |
default="eos", | |
metadata={ | |
"help": "mode for breaking sentence", | |
}, | |
) | |
mask: Optional[float] = field( | |
default=0.3, | |
metadata={ | |
"help": "fraction of words/subwords that will be masked", | |
}, | |
) | |
leave_unmasked_prob: float = field( | |
default=0.1, | |
metadata={"help": "probability that a masked token is unmasked"}, | |
) | |
mask_random: Optional[float] = field( | |
default=0.1, | |
metadata={ | |
"help": "instead of using [MASK], use random token this often", | |
}, | |
) | |
freq_weighted_replacement: bool = field( | |
default=False, | |
metadata={"help": "sample random replacement words based on word frequencies"}, | |
) | |
mask_whole_words: bool = field( | |
default=True, | |
metadata={"help": "mask whole words; you may also want to set --bpe"}, | |
) | |
mask_repeative_tokens: bool = field( | |
default=True, | |
metadata={"help": "mask repeative_tokens; if mask_whole_words=False"}, | |
) | |
mask_multiple_length: int = field( | |
default=1, | |
metadata={"help": "repeat the mask indices multiple times"}, | |
) | |
mask_stdev: float = field( | |
default=0.0, | |
metadata={"help": "stdev of the mask length"}, | |
) | |
shorten_method: Optional[str] = field( | |
default="none", | |
metadata={ | |
"help": "if not none, shorten sequences that exceed tokens_per_sample", | |
"choices": "none/truncate/random_crop" | |
}, | |
) | |
shorten_data_split_list: Optional[str] = field( | |
default="", | |
metadata={ | |
"help": "comma_separated list of dataset splits to apply shortening to, e.g., train,valid (default: all dataset splits)", | |
}, | |
) | |
### below hypra-parameters is used in bart | |
insert: Optional[float] = field( | |
default=0.0, | |
metadata={ | |
"help": "insert this percentage of additional random tokens", | |
}, | |
) | |
permute: Optional[float] = field( | |
default=0.0, | |
metadata={ | |
"help": "take this proportion of subwords and permute them", | |
}, | |
) | |
rotate: Optional[float] = field( | |
default=0.0, | |
metadata={ | |
"help": "rotate this proportion of inputs", | |
}, | |
) | |
poisson_lambda: Optional[float] = field( | |
default=3.5, | |
metadata={ | |
"help": "randomly shuffle sentences for this proportion of inputs", | |
}, | |
) | |
permute_sentences: Optional[float] = field( | |
default=0.0, | |
metadata={ | |
"help": "shuffle this proportion of sentences in all inputs", | |
}, | |
) | |
mask_length: Optional[str] = field( | |
default="span-poisson", | |
metadata={ | |
"help": "mask length to choose", | |
"choice": "subword/word/span-poisson" | |
}, | |
) | |
replace_length: Optional[int] = field( | |
default=1, | |
metadata={ | |
"help": "when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)", | |
}, | |
) | |
shuffle_instance: Optional[bool] = field( | |
default=False, | |
metadata={"help": "shuffle instance"}, | |
) | |
max_source_positions: Optional[int] = field( | |
default=1024, | |
metadata={"help": "max number of tokens in the source sequence"}, | |
) | |
max_target_positions: Optional[int] = field( | |
default=1024, | |
metadata={"help": "max number of tokens in the target sequence"}, | |
) | |
bpe: Optional[str] = field( | |
default="", | |
metadata={ | |
"help": "will wrapped by the text_data_config yaml", | |
}, | |
) | |
data_config: Optional[str] = field( | |
default=None, | |
metadata={ | |
"help": "a config yaml specify the bpe model of text data", | |
}, | |
) | |
text_maxtokens_ratio: Optional[float] = field( | |
default=1.0, | |
metadata={ | |
"help": "for text, max_tokens = max_tokens * text_maxtokens_ratio / 320 ", | |
}, | |
) | |
prepend_tgt_lang_tag: bool = field( | |
default=False, | |
metadata={"help": "prepend tgt_lang_tag to replace <eos>"}, | |
) | |
mask_text_ratio: Optional[float] = field( | |
default=0.0, | |
metadata={ | |
"help": "mask_text_ratio, for paired data", | |
}, | |
) | |
truncate_mono_source: bool = field( | |
default=True, | |
metadata={"help": "truncate mono source-side examples that exceed max-positions"}, | |
) | |
class JointPretrainingConfig(FairseqDataclass): | |
data: str = field( | |
default=MISSING, metadata={"help": "path to speech data directory"} | |
) | |
fine_tuning: bool = field( | |
default=False, metadata={"help": "set to true if fine-tuning Hubert"} | |
) | |
labels: List[str] = field( | |
default_factory=lambda: ["ltr"], | |
metadata={ | |
"help": ( | |
"extension of the label files to load, frame-level labels for" | |
" pre-training, and sequence-level label for fine-tuning" | |
) | |
}, | |
) | |
label_dir: Optional[str] = field( | |
default=None, | |
metadata={ | |
"help": "if set, looks for labels in this directory instead", | |
}, | |
) | |
label_rate: int = field( | |
default=-1, | |
metadata={"help": "label frame rate. -1 for sequence label"}, | |
) | |
sample_rate: int = field( | |
default=16_000, | |
metadata={ | |
"help": "target sample rate. audio files will be up/down " | |
"sampled to this rate" | |
}, | |
) | |
normalize: bool = field( | |
default=False, | |
metadata={ | |
"help": "if set, normalizes input to have 0 mean and unit variance" | |
}, | |
) | |
enable_padding: bool = field( | |
default=False, | |
metadata={"help": "pad shorter samples instead of cropping"}, | |
) | |
max_keep_size: Optional[int] = field( | |
default=None, | |
metadata={"help": "exclude sample longer than this"}, | |
) | |
max_sample_size: Optional[int] = field( | |
default=None, | |
metadata={"help": "max sample size to crop to for batching"}, | |
) | |
min_sample_size: Optional[int] = field( | |
default=None, | |
metadata={"help": "min sample size to crop to for batching"}, | |
) | |
single_target: Optional[bool] = field( | |
default=False, | |
metadata={ | |
"help": "if set, AddTargetDatasets outputs same keys " | |
"as AddTargetDataset" | |
}, | |
) | |
random_crop: Optional[bool] = field( | |
default=True, | |
metadata={"help": "always crop from the beginning if false"}, | |
) | |
pad_audio: Optional[bool] = field( | |
default=False, | |
metadata={"help": "pad audio to the longest one in the batch if true"}, | |
) | |
store_labels: Optional[bool] = field( | |
default=True, | |
metadata={"help": "store spm labels in memory, should be true when fine-tune with bpe"}, | |
) | |
add_decoder_target: bool = field( | |
default=False, | |
metadata={"help": "contral the model architecture, if set True, load reduced unit as target"}, | |
) | |
split_modality_batch: bool = field( | |
default=False, | |
metadata={"help": "whether create all samples of different modalities in a batch"}, | |
) | |
speech_tgt_lang: str = field( | |
default="", | |
metadata={"help": "prepend <tgt-id> to prev_output_tokens to replace <eos>, only used for decoder"}, | |
) | |
speech_sampling_alpha: float = field( | |
default=0.2, | |
metadata={ | |
"help": "Hyper-parameter alpha = 1/T for temperature-based speech resampling." | |
"(alpha = 1 for no resampling)" | |
}, | |
) | |
text_sampling_alpha: float = field( | |
default=0.2, | |
metadata={ | |
"help": "Hyper-parameter alpha = 1/T for temperature-based text resampling." | |
"(alpha = 1 for no resampling)" | |
}, | |
) | |
hubert_tokenizer: Optional[TOKENIZER_CHOICES] = field( | |
default="none", | |
metadata={"help": "which tokenizer for processing text"}, | |
) | |
sp_path: Optional[str] = field( | |
default=None, | |
metadata={"help": "sentencepiece model path if using bpe tokenizer"}, | |
) | |
text_cfg: TextPretrainingConfig = TextPretrainingConfig() | |
# For inference | |
ctc_weight: float = field( | |
default=0.0, | |
metadata={"help": "ctc weight during inference"}, | |
) | |
lm_dict: Optional[str] = field( | |
default="dict.txt", | |
metadata={"help": "dict used for decoding with language model, should be in cfg.data/"}, | |
) | |
class Jsc2tPretrainingTask(FairseqTask): | |
cfg: JointPretrainingConfig | |
def __init__( | |
self, | |
cfg: JointPretrainingConfig, | |
load_local_states: True, | |
) -> None: | |
super().__init__(cfg) | |
logger.info(f"current directory is {os.getcwd()}") | |
logger.info(f"JSTPretrainingTask Config {cfg}") | |
self.cfg = cfg | |
self.fine_tuning = cfg.fine_tuning | |
self.blank_symbol = "<s>" | |
if load_local_states: | |
self.state.add_factory("hubert_tokenizer", self.build_tokenizer) | |
if self.cfg.text_cfg.text_data is not None and os.path.exists(self.cfg.text_cfg.text_data): | |
self.state.add_factory("text_dictionary", self.load_text_dictionary) | |
self.state.add_factory("text_src_dictionary", self.load_text_src_dictionary) | |
if cfg.fine_tuning: | |
self.state.add_factory("target_dictionary", self.load_dictionaries) | |
else: | |
self.state.add_factory("dictionaries", self.load_dictionaries) | |
if cfg.text_cfg.data_config is not None: | |
self.text_data_cfg = S2TJointDataConfig(Path(f"{cfg.text_cfg.text_data}/{cfg.text_cfg.data_config}")) | |
self.cfg.text_cfg.bpe = self.text_data_cfg.bpe_tokenizer["bpe"] | |
else: | |
self.text_data_cfg = None | |
def source_dictionary(self) -> Optional[Dictionary]: | |
return None | |
def target_dictionary(self) -> Optional[Dictionary]: | |
return self.state.target_dictionary | |
def dictionaries(self) -> List[Dictionary]: | |
return self.state.dictionaries | |
def text_dictionary(self) -> Optional[Dictionary]: | |
return self.state.text_dictionary | |
def text_src_dictionary(self) -> Optional[Dictionary]: | |
return self.state.text_src_dictionary | |
def hubert_tokenizer(self): | |
return self.state.hubert_tokenizer | |
def load_dictionaries(self): | |
label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir | |
dictionaries = [Dictionary.load(f"{label_dir}/dict.{label}.txt") for label in self.cfg.labels] | |
if not self.cfg.fine_tuning: | |
for dictionary in dictionaries: | |
dictionary.add_symbol("<mask>") | |
return dictionaries[0] if self.cfg.fine_tuning else dictionaries | |
def load_text_dictionary(self): | |
tgt_dict_path = f"{self.cfg.text_cfg.text_data}/{self.text_data_cfg.vocab_filename if self.text_data_cfg is not None else 'dict.txt'}" | |
if not os.path.isfile(tgt_dict_path): | |
raise FileNotFoundError(f"Dict not found: {tgt_dict_path}") | |
text_dictionary = Dictionary.load(tgt_dict_path) | |
self.mask_idx = text_dictionary.add_symbol("<mask>") | |
return text_dictionary | |
def load_text_src_dictionary(self): | |
src_dict_path = f"{self.cfg.text_cfg.text_data}/{self.text_data_cfg.src_vocab_filename if self.text_data_cfg is not None else 'dict.txt'}" | |
if not os.path.isfile(src_dict_path): | |
raise FileNotFoundError(f"Dict not found: {src_dict_path}") | |
src_text_dictionary = Dictionary.load(src_dict_path) | |
self.mask_idx = src_text_dictionary.add_symbol("<mask>") | |
return src_text_dictionary | |
def setup_task( | |
cls, cfg: JointPretrainingConfig, **kwargs | |
) -> "Jsc2tPretrainingTask": | |
load_local_states = kwargs.get("load_local_states", True) | |
return cls(cfg, load_local_states) | |
def get_label_dir(self) -> str: | |
if self.cfg.label_dir is None: | |
return self.cfg.data | |
return self.cfg.label_dir | |
def load_paired_dataset(self, text_split, truncate_source=False): | |
text_split, lp = text_split.rsplit('.', 1) # e.g. "libritext.ltr-ltr" | |
if len(lp.split("-")) == 2: | |
src, tgt = lp.split("-") | |
if src == tgt: | |
logger.warn(f"| trying to load monolingual dataset {text_split}.{lp}, please check your task is right.") | |
paired_dataset = self.load_char_bart_dataset(f"{text_split}.{lp}.{tgt}") | |
return paired_dataset | |
paired_dataset = load_langpair_dataset( | |
self.cfg.text_cfg.text_data, | |
text_split, | |
src, | |
self.text_src_dictionary, | |
tgt, | |
self.text_dictionary, | |
combine=True, | |
dataset_impl=None, | |
upsample_primary=1, | |
left_pad_source=False, | |
left_pad_target=False, | |
max_source_positions=self.cfg.text_cfg.tokens_per_sample, | |
max_target_positions=self.cfg.text_cfg.tokens_per_sample, | |
truncate_source=truncate_source, | |
prepend_bos=False, | |
load_alignments=False, | |
append_source_id=True if self.cfg.text_cfg.prepend_tgt_lang_tag else False, | |
lang_format="<lang:{}>" if self.cfg.text_cfg.prepend_tgt_lang_tag else "[{}]", | |
input_feeding=self.cfg.add_decoder_target, | |
) | |
if self.cfg.text_cfg.mask_text_ratio > 0: | |
# add mask | |
self.mask_idx = self.text_src_dictionary.index("<mask>") | |
mask_whole_words = None | |
if self.cfg.text_cfg.mask_whole_words: | |
mask_whole_words = get_whole_word_mask(self.cfg.text_cfg, self.text_src_dictionary) | |
elif self.cfg.text_cfg.mask_repeative_tokens: | |
mask_whole_words = get_repeative_start | |
src_dataset, src_unmasked_dataset = MaskTokensDataset.apply_mask( | |
paired_dataset.src, | |
self.text_src_dictionary, | |
pad_idx=self.text_src_dictionary.pad(), | |
mask_idx=self.mask_idx, | |
seed=self.cfg.text_cfg.seed, | |
mask_prob=self.cfg.text_cfg.mask_text_ratio, | |
leave_unmasked_prob=self.cfg.text_cfg.leave_unmasked_prob, | |
random_token_prob=self.cfg.text_cfg.mask_random, | |
freq_weighted_replacement=self.cfg.text_cfg.freq_weighted_replacement, | |
mask_whole_words=mask_whole_words, | |
mask_multiple_length=self.cfg.text_cfg.mask_multiple_length, | |
mask_stdev=self.cfg.text_cfg.mask_stdev, | |
) | |
tgt_dataset = paired_dataset.tgt if paired_dataset.tgt is not None else src_unmasked_dataset | |
paired_dataset = LanguageTripleDataset( | |
src_dataset, | |
src_dataset.sizes, | |
self.text_src_dictionary, | |
src_unmasked_dataset, | |
src_unmasked_dataset.sizes, | |
self.text_src_dictionary, | |
tgt_dataset, | |
tgt_dataset.sizes, | |
self.text_dictionary, | |
left_pad_source=False, | |
left_pad_target=False, | |
align_dataset=None, | |
eos=None, | |
num_buckets=0, | |
shuffle=True, | |
pad_to_multiple=1, | |
) | |
else: | |
src, ref, tgt = lp.split("-") | |
paired_dataset = load_langtriple_dataset( | |
self.cfg.text_cfg.text_data, | |
text_split, | |
src, | |
self.text_src_dictionary, | |
ref, | |
self.dictionaries[-1], | |
tgt, | |
self.text_dictionary, | |
combine=True, | |
dataset_impl=None, | |
upsample_primary=1, | |
left_pad_source=False, | |
left_pad_target=False, | |
max_source_positions=self.cfg.text_cfg.tokens_per_sample, | |
max_target_positions=self.cfg.text_cfg.tokens_per_sample, | |
truncate_source=truncate_source, | |
prepend_bos=False, | |
load_alignments=False, | |
append_source_id=True if self.cfg.text_cfg.prepend_tgt_lang_tag else False, | |
lang_format="<lang:{}>" if self.cfg.text_cfg.prepend_tgt_lang_tag else "[{}]", | |
) | |
return paired_dataset | |
def load_dataset(self, split: str, epoch=1, **kwargs) -> None: | |
""" | |
Create Wav dataset for audio, and Index dataset for phonemized text, | |
then concatenate them to by fairseq.data.multi_corpus_dataset.MultiCorpusDataset. | |
""" | |
speech_splits = split.split('+')[0].split(',') | |
### 1st, create a speech dataset using STSpeechDataset (modified from HubertDataset) | |
dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries | |
pad_list = [dict.pad() for dict in dicts] | |
eos_list = [dict.eos() for dict in dicts] | |
procs = [LabelEncoder(dict) for dict in dicts] | |
if self.cfg.speech_tgt_lang != "": | |
tgt_lang_idx = _lang_token_index(dicts[0], self.cfg.speech_tgt_lang) | |
logger.info(f"Will prepend <{tgt_lang_idx}> at the beginning of prev_output_tokens to replace <eos>") | |
else: | |
tgt_lang_idx = None | |
# hubert v1: pad_audio=True, random_crop=False; | |
speech_datasets = [] | |
for speech_split in speech_splits: | |
paths = [ | |
f"{self.get_label_dir()}/{speech_split}.{l}" for l in self.cfg.labels | |
] | |
speech_datasets.append( | |
HubertDataset( | |
f"{self.cfg.data}/{speech_split}.tsv", | |
sample_rate=self.cfg.sample_rate, | |
label_paths=paths, | |
label_rates=self.cfg.label_rate, | |
pad_list=pad_list, | |
eos_list=eos_list, | |
label_processors=procs, | |
max_keep_sample_size=self.cfg.max_keep_size, | |
min_keep_sample_size=self.cfg.min_sample_size, | |
max_sample_size=self.cfg.max_sample_size, | |
pad_audio=self.cfg.pad_audio, | |
normalize=self.cfg.normalize, | |
store_labels=self.cfg.store_labels, | |
random_crop=self.cfg.random_crop, | |
single_target=self.cfg.single_target, | |
tgt_dict=dicts[0], | |
add_decoder_target=self.cfg.add_decoder_target, | |
fine_tuning=self.cfg.fine_tuning, | |
tgt_lang_idx=tgt_lang_idx, | |
tokenizer=self.hubert_tokenizer, | |
) | |
) | |
if len(speech_datasets) > 1: | |
speech_dataset = ConcatDataset(speech_datasets) | |
else: | |
speech_dataset = speech_datasets[0] | |
has_text = len(split.split('+')) > 1 | |
if not has_text: | |
assert speech_dataset is not None | |
self.datasets[split] = speech_dataset | |
return | |
### 2nd, create paired/mono text datasets using Langpairdataset | |
if split.split('+')[1] != '': | |
paired_splits = [paired_split for paired_split in split.split('+')[1].split(',') if paired_split != ''] | |
paired_datasets = [self.load_paired_dataset(paired_split) for paired_split in paired_splits] | |
else: | |
paired_splits, paired_datasets = [], [] | |
if len(split.split('+')) > 2 and split.split('+')[2] != '': | |
mono_splits = [mono_split for mono_split in split.split('+')[2].split(',') if mono_split != ''] | |
mono_datasets = [self.load_paired_dataset(mono_split, truncate_source=self.cfg.text_cfg.truncate_mono_source) for mono_split in mono_splits] | |
else: | |
mono_splits, mono_datasets = [], [] | |
assert len(mono_datasets + paired_datasets) > 0, f"split {split} has no text! you should check out for that" | |
### 3rd, if provided, create a supervised dataset with labeled data | |
if len(split.split('+')) > 3 and split.split('+')[3] != '': | |
assert len(paired_splits) > 0, f"supervised dataset can not be loaded without text paired dataset!" | |
tgt = paired_splits[0].rsplit('.', 1)[1].split("-")[1] | |
sup_split = split.split('+')[3] | |
sup_dataset = HubertDataset( | |
f"{self.cfg.data}/{sup_split}.tsv", | |
sample_rate=self.cfg.sample_rate, | |
label_paths=[f"{self.get_label_dir()}/{sup_split}.{tgt}"], | |
label_rates=[-1], | |
pad_list=[self.text_dictionary.pad()], | |
eos_list=[self.text_dictionary.eos()], | |
label_processors=[LabelEncoder(self.text_dictionary)], | |
max_keep_sample_size=self.cfg.max_keep_size, | |
min_keep_sample_size=None, | |
max_sample_size=None, | |
pad_audio=True, | |
normalize=self.cfg.normalize, | |
store_labels=self.cfg.store_labels, | |
random_crop=False, | |
single_target=True, | |
tgt_dict=self.text_dictionary, | |
add_decoder_target=self.cfg.add_decoder_target, | |
fine_tuning=True, | |
tgt_lang_idx=None, | |
tokenizer=None, | |
) | |
else: | |
sup_dataset = None | |
### 4th, compose a MultiCorpusDataset | |
dataset_dict, max_positions_dict, distributions, max_tokens_ratios = self.resample_multi_modality_dataset( | |
speech_dataset, sup_dataset, mono_datasets, paired_datasets, mono_splits, paired_splits, epoch=epoch, | |
) | |
self.datasets[split] = MultiCorpusDataset( | |
dataset_dict, | |
max_positions=max_positions_dict, | |
distribution=distributions, | |
max_tokens_ratio=max_tokens_ratios, | |
seed=self.cfg.text_cfg.seed, | |
sort_indices=True, | |
) | |
def max_positions(self) -> Tuple[int, int]: | |
return (sys.maxsize, sys.maxsize) | |
def filter_indices_by_size( | |
self, indices: np.array, *args, **kwargs | |
) -> np.array: | |
return indices | |
def get_batch_iterator( | |
self, | |
dataset, | |
max_tokens=None, | |
max_sentences=None, | |
max_positions=None, | |
ignore_invalid_inputs=False, | |
required_batch_size_multiple=1, | |
seed=1, | |
num_shards=1, | |
shard_id=0, | |
num_workers=0, | |
epoch=1, | |
data_buffer_size=0, | |
disable_iterator_cache=False, | |
skip_remainder_batch=False, | |
grouped_shuffling=False, | |
update_epoch_batch_itr=False, | |
): | |
""" | |
Get an iterator that yields batches of data from the given dataset. | |
Args: | |
dataset (~fairseq.data.FairseqDataset): dataset to batch | |
max_tokens (int, optional): max number of tokens in each batch | |
(default: None). | |
max_sentences (int, optional): max number of sentences in each | |
batch (default: None). | |
max_positions (optional): max sentence length supported by the | |
model (default: None). | |
ignore_invalid_inputs (bool, optional): don't raise Exception for | |
sentences that are too long (default: False). | |
required_batch_size_multiple (int, optional): require batch size to | |
be a multiple of N (default: 1). | |
seed (int, optional): seed for random number generator for | |
reproducibility (default: 1). | |
num_shards (int, optional): shard the data iterator into N | |
shards (default: 1). | |
shard_id (int, optional): which shard of the data iterator to | |
return (default: 0). | |
num_workers (int, optional): how many subprocesses to use for data | |
loading. 0 means the data will be loaded in the main process | |
(default: 0). | |
epoch (int, optional): the epoch to start the iterator from | |
(default: 1). | |
data_buffer_size (int, optional): number of batches to | |
preload (default: 0). | |
disable_iterator_cache (bool, optional): don't cache the | |
EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) | |
(default: False). | |
skip_remainder_batch (bool, optional): if set, discard the last | |
batch in each training epoch, as the last batch is often smaller than | |
local_batch_size * distributed_word_size (default: ``True``). | |
grouped_shuffling (bool, optional): group batches with each groups | |
containing num_shards batches and shuffle groups. Reduces difference | |
between sequence lengths among workers for batches sorted by length. | |
update_epoch_batch_itr (bool optional): if true then donot use the cached | |
batch iterator for the epoch | |
Returns: | |
~fairseq.iterators.EpochBatchIterator: a batched iterator over the | |
given dataset split | |
""" | |
if self.fine_tuning or not isinstance(dataset, MultiCorpusDataset): | |
return super().get_batch_iterator( | |
dataset, | |
max_tokens=max_tokens, | |
max_sentences=max_sentences, | |
max_positions=max_positions, | |
ignore_invalid_inputs=ignore_invalid_inputs, | |
required_batch_size_multiple=required_batch_size_multiple, | |
seed=seed, | |
num_shards=num_shards, | |
shard_id=shard_id, | |
num_workers=num_workers, | |
epoch=epoch, | |
data_buffer_size=data_buffer_size, | |
disable_iterator_cache=disable_iterator_cache, | |
skip_remainder_batch=skip_remainder_batch, | |
grouped_shuffling=grouped_shuffling, | |
update_epoch_batch_itr=update_epoch_batch_itr, | |
) | |
can_reuse_epoch_itr = ( | |
not disable_iterator_cache | |
and not update_epoch_batch_itr | |
and self.can_reuse_epoch_itr(dataset) | |
) | |
if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter: | |
logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch)) | |
return self.dataset_to_epoch_iter[dataset] | |
assert isinstance(dataset, FairseqDataset) | |
# initialize the dataset with the correct starting epoch | |
dataset.set_epoch(epoch) | |
# get indices ordered by example size | |
with data_utils.numpy_seed(seed): | |
indices = dataset.ordered_indices() | |
# filter examples that are too large | |
if max_positions is not None: | |
indices = self.filter_indices_by_size( | |
indices, dataset, max_positions, ignore_invalid_inputs | |
) | |
# create mini-batches with given size constraints | |
batch_sampler = dataset.get_batch_sampler( | |
indices, | |
num_shards, | |
seed, | |
max_tokens=max_tokens, | |
max_sentences=max_sentences, | |
required_batch_size_multiple=required_batch_size_multiple, | |
split_modality_batch=self.cfg.split_modality_batch, | |
) | |
# return a reusable, sharded iterator | |
epoch_iter = iterators.EpochBatchIterator( | |
dataset=dataset, | |
collate_fn=dataset.collater, | |
batch_sampler=batch_sampler, | |
seed=seed, | |
num_shards=num_shards, | |
shard_id=shard_id, | |
num_workers=num_workers, | |
epoch=epoch, | |
buffer_size=data_buffer_size, | |
skip_remainder_batch=skip_remainder_batch, | |
disable_shuffling=True, | |
grouped_shuffling=grouped_shuffling, | |
) | |
if can_reuse_epoch_itr: | |
self.dataset_to_epoch_iter[dataset] = epoch_iter | |
return epoch_iter | |
def build_generator( | |
self, | |
models, | |
args, | |
seq_gen_cls=None, | |
extra_gen_cls_kwargs=None, | |
): | |
"""Build ED-CTC generator for finet-tuned ASR model""" | |
from speechut.squence_generator import SequenceGenerator | |
extra_gen_cls_kwargs = { | |
"ctc_weight": self.cfg.ctc_weight, | |
"lm_dict": Dictionary.load(os.path.join(self.cfg.data, self.cfg.lm_dict)), | |
**extra_gen_cls_kwargs | |
} | |
return super().build_generator( | |
models, args, seq_gen_cls=SequenceGenerator, extra_gen_cls_kwargs=extra_gen_cls_kwargs | |
) | |
def _get_size_ratios(cls, ids: List[str], sizes: List[int], alpha: float = 1.0): | |
"""Size ratios for temperature-based sampling | |
(https://arxiv.org/abs/1907.05019)""" | |
_sizes = np.array(sizes) | |
prob = _sizes / _sizes.sum() | |
smoothed_prob = prob ** alpha | |
smoothed_prob = smoothed_prob / smoothed_prob.sum() | |
size_ratio = (smoothed_prob * _sizes.sum()) / _sizes | |
o_str = str({_i: f"{prob[i]:.3f}" for i, _i in enumerate(ids)}) | |
logger.info(f"original sampling probability: {o_str}") | |
p_str = str({_i: f"{smoothed_prob[i]:.3f}" for i, _i in enumerate(ids)}) | |
logger.info(f"balanced sampling probability: {p_str}") | |
sr_str = str({_id: f"{size_ratio[i]:.3f}" for i, _id in enumerate(ids)}) | |
logger.info(f"balanced sampling size ratio: {sr_str}") | |
return size_ratio.tolist() | |
def resample_multi_modality_dataset(self, speech_dataset, sup_dataset, mono_datasets, paired_datasets, mono_splits, paired_splits, epoch=1, train=True): | |
assert len(mono_datasets+paired_datasets) > 0, f"No text data loaded!" | |
if len(mono_datasets) > 1 and self.cfg.text_sampling_alpha != 1.0: | |
size_ratios = self._get_size_ratios( | |
mono_splits, [len(s) for s in mono_datasets], alpha=self.cfg.text_sampling_alpha | |
) | |
mono_datasets = [ | |
ResamplingDataset( | |
d, size_ratio=r, seed=0, epoch=epoch, replace=(r >= 1.0) | |
) for d, r in zip(mono_datasets, size_ratios) | |
] | |
if len(paired_datasets) > 1 and self.cfg.text_sampling_alpha != 1.0: | |
size_ratios = self._get_size_ratios( | |
paired_splits, [len(s) for s in paired_datasets], alpha=self.cfg.text_sampling_alpha | |
) | |
paired_datasets = [ | |
ResamplingDataset( | |
d, size_ratio=r, seed=0, epoch=epoch, replace=(r >= 1.0) | |
) for d, r in zip(paired_datasets, size_ratios) | |
] | |
dataset_list = [speech_dataset, sup_dataset] | |
for datasets in [mono_datasets, paired_datasets]: | |
if len(datasets) > 1: | |
dataset_list.append(ConcatDataset(datasets)) | |
elif len(datasets) == 1: | |
dataset_list.append(datasets[0]) | |
else: | |
dataset_list.append(None) | |
### match speech/text datasets according to modality | |
dataset_dict = OrderedDict((name, d) for name, d in zip(["speech", "speech_sup", "text_mono", "text_paired"], dataset_list) if d is not None) | |
max_positions_dict = { | |
"speech": None, | |
"speech_sup": None, | |
"text_mono": (self.cfg.text_cfg.tokens_per_sample, self.cfg.text_cfg.tokens_per_sample), | |
"text_paired": (self.cfg.text_cfg.tokens_per_sample, self.cfg.text_cfg.tokens_per_sample), | |
} | |
max_positions_dict = OrderedDict((name, max_positions_dict[name]) for name in dataset_dict.keys()) | |
max_tokens_ratios_dict = { | |
"speech": 1.0, | |
"speech_sup": 1.0, | |
"text_mono": 1.0 / 320 / self.cfg.text_cfg.text_maxtokens_ratio, | |
"text_paired": 1.0 / 320 / self.cfg.text_cfg.text_maxtokens_ratio, | |
} | |
max_tokens_ratios = [max_tokens_ratios_dict[name] for name in dataset_dict.keys()] | |
dataset_lens = np.array([len(dataset) for dataset in dataset_dict.values()]) | |
dataset_avg_sample_lens = np.array([ | |
sum([dataset.num_tokens(i) for i in np.random.randint(low=0, high=len(dataset), size=10000)]) / 10000.0 | |
for dataset in dataset_dict.values() | |
]) | |
if not "speech" in dataset_dict: | |
distributions = [l / sum(dataset_lens) for l in dataset_lens] | |
else: | |
## we just keep the batches of speech and non-speech the same, expand_coef is to ensure speech batches is less than others | |
first_ratio = dataset_lens[0] / sum(dataset_lens) | |
expand_coef = 1.2 if sup_dataset is None else 1.1 * sum(dataset_lens[0:2]) / dataset_lens[0] | |
distributions = [expand_coef * max_tokens_ratios[i] * dataset_avg_sample_lens[0] / l for (i, l) in enumerate(dataset_avg_sample_lens)] | |
distributions[0] = 1.0 | |
if sup_dataset is not None: | |
distributions[1] = dataset_lens[1] / dataset_lens[0] | |
distributions = [first_ratio * d for d in distributions] | |
logging.info(f"Number samples of datasets is {dataset_lens}") | |
logging.info(f"Avg sample length of datasets is {dataset_avg_sample_lens}") | |
logging.info(f"Sampling distributions is {distributions}") | |
logging.info(f"Maxtokens ratio is {max_tokens_ratios}") | |
return dataset_dict, max_positions_dict, distributions, max_tokens_ratios | |
def build_tokenizer(self, cfg=None): | |
logger.info(f"tokenizer: {self.cfg.hubert_tokenizer}") | |
if self.cfg.hubert_tokenizer != "none": | |
return encoders.build_bpe(Namespace(**{"bpe": self.cfg.hubert_tokenizer, "sentencepiece_model": self.cfg.sp_path})) | |
else: | |
return None | |
def load_char_bart_dataset(self, split): | |
mono_dataset = data_utils.load_indexed_dataset( | |
f"{self.cfg.text_cfg.text_data}/{split}", | |
self.text_dictionary, | |
) | |
mono_dataset = StripTokenDataset(mono_dataset, self.text_dictionary.eos()) | |
mono_dataset = maybe_shorten_dataset( | |
mono_dataset, | |
split, | |
self.cfg.text_cfg.shorten_data_split_list, | |
self.cfg.text_cfg.shorten_method, | |
self.cfg.text_cfg.tokens_per_sample - 2, | |
self.cfg.text_cfg.seed, | |
) | |
logger.info("loaded {} samples from: {}".format(len(mono_dataset), mono_dataset)) | |
### prepend bos and eos to dataset | |
mono_dataset = PrependTokenDataset(mono_dataset, self.text_dictionary.bos()) | |
mono_dataset = AppendTokenDataset(mono_dataset, self.text_dictionary.eos()) | |
mask_whole_words = ( | |
get_whole_word_mask(None, self.text_dictionary) | |
if self.cfg.text_cfg.mask_whole_words | |
else None | |
) | |
lang=self.cfg.speech_tgt_lang | |
mono_dataset = DenoisingDataset( | |
mono_dataset, | |
mono_dataset.sizes, | |
self.text_dictionary, | |
self.mask_idx, | |
mask_whole_words, | |
shuffle=self.cfg.text_cfg.shuffle_instance, | |
seed=self.cfg.text_cfg.seed, | |
args=self.cfg.text_cfg, | |
tgt_lang_idx=_lang_token_index(self.text_dictionary, lang) if self.cfg.text_cfg.prepend_tgt_lang_tag else None, | |
) | |
return mono_dataset | |