# Copyright (c) 2021-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import math
from typing import Any, List, NamedTuple, Optional

import numpy as np
import torch

from fairseq.data import (
    ConcatDataset,
    FairseqDataset,
    FileAudioDataset,
    LanguagePairDataset,
    data_utils,
)
from fairseq.data.resampling_dataset import ResamplingDataset

logger = logging.getLogger(__name__)


class ModalityDatasetItem(NamedTuple):
    datasetname: str
    dataset: Any
    max_positions: List[int]
    max_tokens: Optional[int] = None
    max_sentences: Optional[int] = None


def resampling_dataset_present(ds):
    """Return True if `ds` is, or wraps, a ResamplingDataset."""
    if isinstance(ds, ResamplingDataset):
        return True
    if isinstance(ds, ConcatDataset):
        return any(resampling_dataset_present(d) for d in ds.datasets)
    if hasattr(ds, "dataset"):
        return resampling_dataset_present(ds.dataset)
    return False


# MultiModalityDataset: concatenates multiple datasets with different modalities.
# Compared with ConcatDataset it can 1) sample data given per-dataset ratios and
# 2) add a "mode" field indicating which type of dataset each sample comes from.
# It is used together with GroupedEpochBatchIterator to generate mini-batches
# whose samples all come from the same type of dataset.
# If only one dataset is used, it behaves like that dataset with the mode added.
class MultiModalityDataset(ConcatDataset):
    def __init__(self, datasets: List[ModalityDatasetItem]):
        id_to_mode = []
        dsets = []
        max_tokens = []
        max_sentences = []
        max_positions = []
        for dset in datasets:
            id_to_mode.append(dset.datasetname)
            dsets.append(dset.dataset)
            max_tokens.append(dset.max_tokens)
            max_positions.append(dset.max_positions)
            max_sentences.append(dset.max_sentences)
        weights = [1.0 for _ in dsets]
        super().__init__(dsets, weights)
        self.max_tokens = max_tokens
        self.max_positions = max_positions
        self.max_sentences = max_sentences
        self.id_to_mode = id_to_mode
        self.raw_sub_batch_samplers = []
        self._cur_epoch = 0

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        self._cur_epoch = epoch

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        sample = self.datasets[dataset_idx][sample_idx]
        return (dataset_idx, sample)

    def collater(self, samples):
        if len(samples) == 0:
            return {}
        dataset_idx = samples[0][0]
        # make sure all samples in the batch come from the same dataset
        assert all(s[0] == dataset_idx for s in samples)
        samples = self.datasets[dataset_idx].collater([x[1] for x in samples])
        # add mode so downstream code knows which modality the batch belongs to
        samples["net_input"]["mode"] = self.id_to_mode[dataset_idx]

        return samples

    def size(self, index: int):
        if len(self.datasets) == 1:
            return self.datasets[0].size(index)
        return super().size(index)

    @property
    def sizes(self):
        if len(self.datasets) == 1:
            return self.datasets[0].sizes
        return super().sizes

    def ordered_indices(self):
        """
        Returns indices sorted by length, so that less padding is needed.
        """
        if len(self.datasets) == 1:
            return self.datasets[0].ordered_indices()
        indices_group = []
        for d_idx, ds in enumerate(self.datasets):
            sample_num = self.cumulative_sizes[d_idx]
            if d_idx > 0:
                sample_num = sample_num - self.cumulative_sizes[d_idx - 1]
            assert sample_num == len(ds)
            indices_group.append(ds.ordered_indices())
        return indices_group

    def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
        with data_utils.numpy_seed(seed):
            indices = self.ordered_indices()
        for i, ds in enumerate(self.datasets):
            # If we have a ResamplingDataset, the same id can correspond to a
            # different sample in the next epoch, so the sampler has to be
            # rebuilt at every epoch.
            if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present(
                ds
            ):
                logger.info(
                    f"batch sampler for dataset {i} is still valid and the dataset is not re-sampled"
                )
                continue
            indices[i] = ds.filter_indices_by_size(
                indices[i],
                self.max_positions[i],
            )[0]
            sub_batch_sampler = ds.batch_by_size(
                indices[i],
                max_tokens=self.max_tokens[i],
                max_sentences=self.max_sentences[i],
                required_batch_size_multiple=required_batch_size_multiple,
            )
            if i < len(self.raw_sub_batch_samplers):
                self.raw_sub_batch_samplers[i] = sub_batch_sampler
            else:
                self.raw_sub_batch_samplers.append(sub_batch_sampler)

    def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
        self.get_raw_batch_samplers(required_batch_size_multiple, seed)
        batch_samplers = []
        for i, _ in enumerate(self.datasets):
            if i > 0:
                # offset sample ids into the concatenated index space
                sub_batch_sampler = [
                    [y + self.cumulative_sizes[i - 1] for y in x]
                    for x in self.raw_sub_batch_samplers[i]
                ]
            else:
                sub_batch_sampler = list(self.raw_sub_batch_samplers[i])
            smp_r = mult_ratios[i]
            if smp_r != 1:
                is_increase = "increased" if smp_r > 1 else "decreased"
                logger.info(
                    "number of batches for dataset {} is {} from {} to {}".format(
                        self.id_to_mode[i],
                        is_increase,
                        len(sub_batch_sampler),
                        int(len(sub_batch_sampler) * smp_r),
                    )
                )
                # repeat the sampler floor(smp_r) times, then draw the fractional
                # remainder from a shuffled copy
                mul_samplers = []
                for _ in range(math.floor(smp_r)):
                    mul_samplers = mul_samplers + sub_batch_sampler
                if math.floor(smp_r) != smp_r:
                    with data_utils.numpy_seed(seed + self._cur_epoch):
                        np.random.shuffle(sub_batch_sampler)
                        smp_num = int(
                            (smp_r - math.floor(smp_r)) * len(sub_batch_sampler)
                        )
                        mul_samplers = mul_samplers + sub_batch_sampler[:smp_num]
                sub_batch_sampler = mul_samplers
            else:
                logger.info(
                    "dataset {} batch number is {}".format(
                        self.id_to_mode[i], len(sub_batch_sampler)
                    )
                )
            batch_samplers.append(sub_batch_sampler)

        return batch_samplers


class LangPairMaskDataset(FairseqDataset):
    def __init__(
        self,
        dataset: LanguagePairDataset,
        src_eos: int,
        src_bos: Optional[int] = None,
        noise_id: Optional[int] = -1,
        mask_ratio: Optional[float] = 0,
        mask_type: Optional[str] = "random",
    ):
        self.dataset = dataset
        self.src_eos = src_eos
        self.src_bos = src_bos
        self.noise_id = noise_id
        self.mask_ratio = mask_ratio
        self.mask_type = mask_type
        assert mask_type in ("random", "tail")

    @property
    def src_sizes(self):
        return self.dataset.src_sizes

    @property
    def tgt_sizes(self):
        return self.dataset.tgt_sizes

    @property
    def sizes(self):
        # dataset.sizes can be dynamically computed
        return self.dataset.sizes

    def get_batch_shapes(self):
        if hasattr(self.dataset, "get_batch_shapes"):
            return self.dataset.get_batch_shapes()
        return self.dataset.buckets

    def num_tokens_vec(self, indices):
        return self.dataset.num_tokens_vec(indices)

    def __len__(self):
        return len(self.dataset)

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)

    def mask_src_tokens(self, sample):
        src_item = sample["source"]
        mask = None
        if self.mask_type == "random":
            # mask each source token independently with probability mask_ratio
            mask = torch.rand(len(src_item)).le(self.mask_ratio)
        else:
            # "tail": mask the last mask_ratio fraction of the source tokens
            mask = torch.ones(len(src_item))
            mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0
            mask = mask.eq(1)
        # never mask the BOS/EOS markers
        if src_item[0] == self.src_bos:
            mask[0] = False
        if src_item[-1] == self.src_eos:
            mask[-1] = False
        mask_src_item = src_item.masked_fill(mask, self.noise_id)
        smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]}
        return smp

    def __getitem__(self, index):
        sample = self.dataset[index]
        if self.mask_ratio > 0:
            sample = self.mask_src_tokens(sample)
        return sample

    def collater(self, samples, pad_to_length=None):
        return self.dataset.collater(samples, pad_to_length)


class FileAudioDatasetWrapper(FileAudioDataset):
    def collater(self, samples):
        samples = super().collater(samples)
        if len(samples) == 0:
            return {}
        # rename the raw-audio "source" field to the "src_tokens" key expected in
        # net_input, and null out the fields that do not apply to raw audio
        samples["net_input"]["src_tokens"] = samples["net_input"]["source"]
        samples["net_input"]["prev_output_tokens"] = None
        del samples["net_input"]["source"]
        samples["net_input"]["src_lengths"] = None
        samples["net_input"]["alignment"] = None
        return samples
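

# A minimal usage sketch (illustrative only, never executed): `speech_ds`,
# `lang_pair_ds`, and `src_dict` are hypothetical objects, and the numeric
# settings below are placeholders, not values prescribed by this module.
# `speech_ds` could be, e.g., a FileAudioDatasetWrapper, whose collater renames
# "source" to "src_tokens".
#
#   # Wrap a text LanguagePairDataset so a fraction of source tokens is masked;
#   # using a "<mask>" symbol as the noise token is an assumption, not required.
#   masked_text_ds = LangPairMaskDataset(
#       lang_pair_ds,
#       src_bos=src_dict.bos(),
#       src_eos=src_dict.eos(),
#       noise_id=src_dict.index("<mask>"),
#       mask_ratio=0.3,
#       mask_type="tail",  # "tail" masks the trailing 30%; "random" masks uniformly
#   )
#
#   # Combine the speech dataset and the masked text dataset into one dataset.
#   dataset = MultiModalityDataset(
#       [
#           ModalityDatasetItem(
#               "speech", speech_ds, max_positions=[600000, 1024], max_tokens=800000
#           ),
#           ModalityDatasetItem(
#               "text", masked_text_ds, max_positions=[1024, 1024], max_tokens=8192
#           ),
#       ]
#   )
#
#   # Build per-modality batch samplers; these are meant to be consumed by
#   # GroupedEpochBatchIterator so that every mini-batch holds a single modality.
#   batch_samplers = dataset.get_batch_samplers(
#       mult_ratios=[1.0, 0.5], required_batch_size_multiple=8, seed=1
#   )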