# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
"""
    Modified from https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/fairseq/tasks/translation.py
    1. Add custom lang_format in function load_langpair_dataset
    2. If truncate_source (default no), use RandomCropDataset instead of TruncateDataset
"""

import itertools
import logging
import os

from fairseq.data import (
    AppendTokenDataset,
    LanguagePairDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
    RandomCropDataset,
    data_utils,
    indexed_dataset,
)

from speechlm.data.concat_dataset import ConcatDataset


EVAL_BLEU_ORDER = 4


logger = logging.getLogger(__name__)


def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
    lang_format="[{}]",
    input_feeding=True,
):
    """Load one split of a language pair and wrap it in a LanguagePairDataset.

    Shards named ``split``, ``split1``, ``split2``, ... are looked up under
    ``data_path`` using either the ``src-tgt`` or the ``tgt-src`` file naming.
    Only the first shard is loaded unless ``combine`` is true.

    Args:
        data_path: directory that holds the indexed dataset files.
        split: base split name; shard ``k > 0`` is looked up as ``split + str(k)``.
        src, tgt: language codes used in file names and in ``lang_format``.
        src_dict, tgt_dict: dictionaries passed to the dataset loader and
            queried for the eos / bos / language-tag indices.
        combine: keep loading numbered shards until one is missing.
        dataset_impl: implementation name handed to ``indexed_dataset`` and
            ``data_utils.load_indexed_dataset``.
        upsample_primary: sample ratio given to the first shard when several
            shards are concatenated (all other shards get ratio 1).
        left_pad_source, left_pad_target, num_buckets, shuffle,
            pad_to_multiple, input_feeding: forwarded unchanged to
            ``LanguagePairDataset``.
        max_source_positions: when ``truncate_source`` is set, the source is
            cropped to ``max_source_positions - 1`` tokens before eos is
            re-appended.
        max_target_positions: accepted for interface compatibility; not used
            inside this function.
        prepend_bos: prepend each dictionary's bos to source (and target, if
            present).
        load_alignments: also load ``{split}.align.{src}-{tgt}`` if it exists.
        truncate_source: strip eos, crop with ``RandomCropDataset``, then
            append eos again.
        append_source_id: append the ``lang_format``-formatted language token
            to source and target; the target token is also passed as ``eos``.
        prepend_bos_src: value prepended to the source when ``prepend_bos`` is
            false (presumably a token index -- confirm against the caller).
        lang_format: template used to build the language token looked up in
            the dictionaries.

    Returns:
        A ``LanguagePairDataset``; its target side is ``None`` when no target
        data could be loaded.

    Raises:
        FileNotFoundError: if the first (un-numbered) shard does not exist in
            either naming direction.
    """

    def _shard_exists(shard_name, first, second, lang):
        # Closure over data_path / dataset_impl of the enclosing call.
        filename = os.path.join(
            data_path, "{}.{}-{}.{}".format(shard_name, first, second, lang)
        )
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    source_shards = []
    target_shards = []

    for shard_idx in itertools.count():
        shard_split = split if shard_idx == 0 else split + str(shard_idx)

        # infer langcode: files may be named "src-tgt" or "tgt-src"
        if _shard_exists(shard_split, src, tgt, src):
            pair = (src, tgt)
        elif _shard_exists(shard_split, tgt, src, src):
            pair = (tgt, src)
        elif shard_idx > 0:
            # Ran past the last numbered shard.
            break
        else:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, data_path)
            )
        prefix = os.path.join(
            data_path, "{}.{}-{}.".format(shard_split, pair[0], pair[1])
        )

        src_shard = data_utils.load_indexed_dataset(
            prefix + src, src_dict, dataset_impl
        )
        if truncate_source:
            # Drop eos, crop to leave one slot free, then put eos back.
            without_eos = StripTokenDataset(src_shard, src_dict.eos())
            cropped = RandomCropDataset(without_eos, max_source_positions - 1)
            src_shard = AppendTokenDataset(cropped, src_dict.eos())
        source_shards.append(src_shard)

        tgt_shard = data_utils.load_indexed_dataset(
            prefix + tgt, tgt_dict, dataset_impl
        )
        if tgt_shard is not None:
            target_shards.append(tgt_shard)

        logger.info(
            "{} {} {}-{} {} examples".format(
                data_path, shard_split, src, tgt, len(src_shard)
            )
        )

        if not combine:
            break

    # Either every shard has a target side, or none does.
    assert not target_shards or len(target_shards) == len(source_shards)

    if len(source_shards) == 1:
        src_dataset = source_shards[0]
        tgt_dataset = target_shards[0] if target_shards else None
    else:
        sample_ratios = [1] * len(source_shards)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(source_shards, sample_ratios)
        tgt_dataset = (
            ConcatDataset(target_shards, sample_ratios) if target_shards else None
        )

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_lang_token = src_dict.index(lang_format.format(src))
        src_dataset = AppendTokenDataset(src_dataset, src_lang_token)
        if tgt_dataset is not None:
            # The target language token doubles as the eos handed downstream.
            tgt_lang_token = tgt_dict.index(lang_format.format(tgt))
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_lang_token)
            eos = tgt_lang_token

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(
            data_path, "{}.align.{}-{}".format(split, src, tgt)
        )
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    if tgt_dataset is None:
        tgt_dataset_sizes = None
    else:
        tgt_dataset_sizes = tgt_dataset.sizes

    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
        input_feeding=input_feeding,
    )