import logging
import time
from collections import OrderedDict
from typing import Dict, List

import numpy as np

from fairseq.data import data_utils

from . import FairseqDataset

logger = logging.getLogger(__name__)


class MultiCorpusDataset(FairseqDataset):
    """
    Stores multiple instances of FairseqDataset together. Requires each instance
    to be the same type of dataset, as the collate method needs to work on
    batches with samples from each dataset.

    Allows specifying a distribution over the datasets to use. Note that unlike
    MultiCorpusSampledDataset, this distribution allows sampling for each item,
    rather than on a batch level.

    Each time ordered_indices() is called, a new sample is generated with
    the specified distribution.

    Args:
        datasets: an OrderedDict of FairseqDataset instances.
        distribution: a List containing the probability of getting an utterance
            from the corresponding dataset
        seed: random seed for sampling the datasets
        sort_indices: if true, will sort the ordered indices by size
        batch_sample: if true, will ensure each batch is from a single dataset
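
    Example (illustrative sketch; ``dataset_a`` and ``dataset_b`` stand for any
    two FairseqDataset instances of the same type)::

        datasets = OrderedDict([("a", dataset_a), ("b", dataset_b)])
        multi_corpus = MultiCorpusDataset(
            datasets, distribution=[0.5, 0.5], seed=42, sort_indices=True
        )
        multi_corpus.set_epoch(1)
        indices = multi_corpus.ordered_indices()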
    """

    def __init__(
        self,
        datasets: Dict[str, FairseqDataset],
        distribution: List[float],
        seed: int,
        sort_indices: bool = False,
        batch_sample: bool = False,
        distributed_rank=None,
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        assert len(datasets) == len(distribution)
        assert sum(distribution) == 1
        self.datasets = datasets
        self.distribution = distribution
        self.seed = seed
        self.sort_indices = sort_indices
        self.batch_sample = batch_sample
        self.distributed_rank = distributed_rank
        # Default epoch; updated via set_epoch() and mixed into the sampling seed.
        self.epoch = 0

        # Cache the dataset list to avoid repeated conversion from the OrderedDict.
        self.dataset_list = list(datasets.values())
        self.total_num_instances = 0

        first_dataset = list(self.datasets.values())[0]
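        # dataset_offsets[i] is the number of instances in datasets 0..i-1, i.e.
        # the first global index belonging to dataset i. ordered_indices() adds
        # these offsets to per-dataset indices so the values it returns are
        # global indices, which _map_index() later translates back to a
        # (local index, dataset key) pair.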
        self.dataset_offsets = []
        for dataset in datasets.values():
            assert isinstance(dataset, FairseqDataset)
            assert type(dataset) is type(first_dataset)
            self.dataset_offsets.append(self.total_num_instances)
            self.total_num_instances += len(dataset)

    def ordered_indices(self):
        start = time.time()
        with data_utils.numpy_seed(self.seed, self.epoch):
            logger.info(
                f"sampling new dataset with seed {self.seed} epoch {self.epoch}"
            )
            sampled_indices = []
            num_selected_instances = 0
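            # For each dataset except the last, draw int(distribution[i] *
            # total_num_instances) items; the last dataset absorbs the rounding
            # remainder so the overall total is exactly total_num_instances.
            # Illustrative numbers (not from the code): with 100 total instances
            # and distribution [0.3, 0.7], dataset 0 contributes 30 indices and
            # dataset 1 contributes the remaining 70.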
            for i, key in enumerate(self.datasets):
                if i < len(self.datasets) - 1:
                    num_instances = int(
                        self.distribution[i] * self.total_num_instances
                    )
                    high = self.dataset_offsets[i + 1]
                else:
                    num_instances = self.total_num_instances - num_selected_instances
                    high = self.total_num_instances

                logger.info(f"sampling {num_instances} from {key} dataset")
                num_selected_instances += num_instances
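                # If num_instances exceeds the dataset size, include
                # num_copies = num_instances // dataset_size full copies of the
                # dataset, then fill the remainder with a random sample drawn
                # without replacement. Illustrative numbers: num_instances=25 and
                # dataset_size=10 gives num_copies=2 plus 5 randomly permuted
                # leftover indices.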
                dataset_size = len(self.datasets[key])
                num_copies = num_instances // dataset_size
                dataset_indices = (
                    np.random.permutation(high - self.dataset_offsets[i])
                    + self.dataset_offsets[i]
                )[: num_instances - num_copies * dataset_size]
                if num_copies > 0:
                    sampled_indices += list(
                        np.concatenate(
                            (
                                np.repeat(
                                    np.arange(self.dataset_offsets[i], high), num_copies
                                ),
                                dataset_indices,
                            )
                        )
                    )
                else:
                    sampled_indices += list(dataset_indices)

            assert (
                len(sampled_indices) == self.total_num_instances
            ), f"{len(sampled_indices)} vs {self.total_num_instances}"

            np.random.shuffle(sampled_indices)
            if self.sort_indices:
                sampled_indices.sort(key=lambda i: self.num_tokens(i))

            logger.info(
                "multi_corpus_dataset ordered_indices took {}s".format(
                    time.time() - start
                )
            )
            return np.array(sampled_indices, dtype=np.int64)

    def _map_index(self, index: int):
        """
        If dataset A has length N and dataset B has length M
        then index 1 maps to index 1 of dataset A, and index N + 1
        maps to index 1 of B.
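
        For example (illustrative numbers), with N = 3 and M = 2 the valid
        indices are 0..4, and index 4 maps to index 1 of dataset B.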
        """
        counter = 0
        for key, dataset in self.datasets.items():
            if index < counter + len(dataset):
                return index - counter, key
            counter += len(dataset)
        raise ValueError(
            "Invalid index: {}, max: {}".format(index, self.total_num_instances)
        )

    def __len__(self):
        """
        Length of this dataset is the sum of the lengths of the individual
        datasets.
        """
        return self.total_num_instances

    def __getitem__(self, index):
        new_index, key = self._map_index(index)
        try:
            item = self.datasets[key][new_index]
            # Record the global index so collater() can tell which dataset the
            # sample came from.
            item["full_id"] = index
            return item
        except Exception as e:
            e.args = (f"Error from {key} dataset", *e.args)
            raise

    def collater(self, samples):
        """
        If we are doing batch sampling, then pick the right collater to use.

        Otherwise we assume all collaters are the same.
        """
        if len(samples) == 0:
            return None
        if "full_id" in samples[0]:
            _, key = self._map_index(samples[0]["full_id"])
            return self.datasets[key].collater(samples)
        else:
            # Samples without a full_id (e.g. from a subclass that overrides
            # __getitem__) fall back to the first dataset's collater; all
            # datasets are required to be the same type, so any collater works.
            return list(self.datasets.values())[0].collater(samples)

    def num_tokens(self, index: int):
        index, key = self._map_index(index)
        return self.datasets[key].num_tokens(index)

    def size(self, index: int):
        index, key = self._map_index(index)
        return self.datasets[key].size(index)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        logger.info(f"setting epoch of multi_corpus_dataset to {epoch}")
        self.epoch = epoch

    @property
    def supports_prefetch(self):
        return False

    @property
    def supports_fetch_outside_dataloader(self):
        return all(
            self.datasets[key].supports_fetch_outside_dataloader
            for key in self.datasets
        )

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
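        # When batch_sample is disabled, fall back to the default FairseqDataset
        # batching over the mixed indices. Otherwise, group the indices by their
        # source dataset first so that every batch contains samples from a
        # single corpus.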
        if not self.batch_sample:
            return super().batch_by_size(
                indices, max_tokens, max_sentences, required_batch_size_multiple
            )

        dataset_indices = {key: [] for key in self.datasets}
        for i in indices:
            _, key = self._map_index(i)
            dataset_indices[key].append(i)

        batches = []
        for key in dataset_indices:
            cur_batches = super().batch_by_size(
                np.array(dataset_indices[key], dtype=np.int64),
                max_tokens,
                max_sentences,
                required_batch_size_multiple,
            )
            logger.info(f"Created {len(cur_batches)} batches for dataset {key}")
            batches += cur_batches
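
        # In the distributed case, shuffle the per-dataset batches with a seed
        # that also includes the distributed rank, so each worker interleaves
        # the corpora in a different order.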
        if self.distributed_rank is not None:
            with data_utils.numpy_seed(self.seed, self.epoch, self.distributed_rank):
                np.random.shuffle(batches)

        return batches