# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import hashlib
import logging
import math

import numpy as np
from fairseq.data import SampledMultiDataset

from .sampled_multi_dataset import CollateFormat, default_virtual_size_func


logger = logging.getLogger(__name__)


class SampledMultiEpochDataset(SampledMultiDataset):
    """Samples from multiple sub-datasets according to sampling ratios
       using virtual epoch sizes to speed up dataloading.
    Args:
        datasets (
            List[~torch.utils.data.Dataset]
            or OrderedDict[str, ~torch.utils.data.Dataset]
        ): datasets
        sampling_ratios (List[float]): list of probabilities with which each dataset is sampled
            (default: None, which corresponds to concatenating all datasets together).
        seed (int): RNG seed to use (default: 2).
        epoch (int): starting epoch number (default: 1).
        eval_key (str, optional): a key used at evaluation time that causes
            this instance to pass-through batches from *datasets[eval_key]*.
        collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict
            or CollateFormat.single (default: CollateFormat.single). CollateFormat.single configures
            the collater to output batches of data mixed from all sub-datasets, while
            CollateFormat.ordered_dict configures it to output a dictionary of batches indexed by
            the keys of the sub-datasets. Note that not all sub-datasets will be present in a
            single batch in either format.
        virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func).
        split (str): the split of the data, e.g. 'train', 'valid' or 'test'.
        virtual_epoch_size (int): virtual epoch size; the dataset iterates through the data one
            virtual epoch at a time to speed up data loading, e.g. indexing and filtering can be
            performed as each virtual epoch is loaded rather than for the whole dataset at once.
        shared_collater (bool): whether or not all sub-datasets share the same collater.
        shard_epoch (int): the real epoch number for shard selection.
        shuffle (bool): whether or not to shuffle data (default: True).
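
    Example:
        A minimal usage sketch (illustrative only; ``ds_a`` and ``ds_b`` stand in
        for already-constructed fairseq datasets, and the sampling ratios and
        virtual epoch size are arbitrary values chosen for the example)::

            from collections import OrderedDict

            dataset = SampledMultiEpochDataset(
                OrderedDict([("a", ds_a), ("b", ds_b)]),
                sampling_ratios=[0.75, 0.25],
                seed=2,
                epoch=1,
                virtual_epoch_size=1000,
                split="train",
            )
            dataset.set_epoch(1)  # fairseq epochs are 1-based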
    """

    def __init__(
        self,
        datasets,
        sampling_ratios=None,
        seed=2,
        epoch=1,
        eval_key=None,
        collate_format=CollateFormat.single,
        virtual_size=default_virtual_size_func,
        split="",
        virtual_epoch_size=None,
        shared_collater=False,
        shard_epoch=1,
        shuffle=True,
    ):
        self.virtual_epoch_size = virtual_epoch_size
        self._current_epoch_start_index = None
        self._random_global_indices = None
        self.shard_epoch = shard_epoch if shard_epoch is not None else 1
        self.load_next_shard = None
        self._epoch_sizes = None
        super().__init__(
            datasets=datasets,
            sampling_ratios=sampling_ratios,
            seed=seed,
            epoch=epoch,
            eval_key=eval_key,
            collate_format=collate_format,
            virtual_size=virtual_size,
            split=split,
            shared_collater=shared_collater,
            shuffle=shuffle,
        )

    def _setup(self, epoch):
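        # Resolve the virtual epoch size (defaulting to the full virtual size),
        # clamp it to the virtual dataset size, and derive the number of
        # virtual epochs needed to cover the whole virtual dataset.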
        self.virtual_epoch_size = (
            self.virtual_epoch_size
            if self.virtual_epoch_size is not None
            else self.virtual_size
        )
        if self.virtual_epoch_size > self.virtual_size:
            logger.warning(
                f"virtual epoch size {self.virtual_epoch_size} "
                f"is greater than virtual dataset size {self.virtual_size}"
            )
            self.virtual_epoch_size = self.virtual_size
        self.num_virtual_epochs = math.ceil(self.virtual_size / self.virtual_epoch_size)
        self._current_epoch_start_index = self._get_epoch_start_index(epoch)
        logger.info(
            f"virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}"
        )

    def _map_epoch_index_to_global(self, index):
        index = self._current_epoch_start_index + index
        # map through the shuffled global virtual indices to add randomness
        return self._random_global_indices[index]

    @property
    def sizes(self):
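        # Lazily compute and cache the sizes for just the slice of shuffled
        # global virtual indices that belongs to the current virtual epoch.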
        if self._epoch_sizes is not None:
            return self._epoch_sizes
        _sizes = super().sizes
        indices = self._random_global_indices[
            self._current_epoch_start_index : self._current_epoch_start_index
            + len(self)
        ]
        self._epoch_sizes = _sizes[indices]
        # free the full sizes array cached by the parent class to save memory
        del self._sizes
        self._sizes = None
        return self._epoch_sizes

    def _get_dataset_and_index(self, index):
        i = self._map_epoch_index_to_global(index)
        return super()._get_dataset_and_index(i)

    def __len__(self):
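        # Each virtual epoch normally has `virtual_epoch_size` items; the last
        # one may be shorter, covering whatever remains of the virtual dataset
        # past the current epoch's start index.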
        return (
            self.virtual_epoch_size
            if self._current_epoch_start_index + self.virtual_epoch_size
            < self.virtual_size
            else self.virtual_size - self._current_epoch_start_index
        )

    def set_epoch(self, epoch):
        if self._current_epoch_start_index is None:
            # initializing epoch indices of a virtual dataset
            self._setup(epoch)
            self._next_virtual_epoch(epoch)
        else:
            # working on already initialized epoch indices
            if epoch == self._cur_epoch:
                # re-entering the same epoch, so return
                return
            self._next_virtual_epoch(epoch)

    def _get_epoch_start_index(self, epoch):
        assert epoch >= 1  # fairseq is using 1-based epoch everywhere
        return ((epoch - 1) % self.num_virtual_epochs) * self.virtual_epoch_size

    def _next_global_indices(self, epoch):
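        # Seed the RNG from a hash of the class name, the global seed, and the
        # epoch number, so the shuffle of global virtual indices is
        # deterministic for a given epoch yet different across epochs.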
        rng = np.random.RandomState(
            [
                int(
                    hashlib.sha1(
                        str(self.__class__.__name__).encode("utf-8")
                    ).hexdigest(),
                    16,
                )
                % (2 ** 32),
                self.seed % (2 ** 32),  # global seed
                epoch,  # epoch index,
            ]
        )
        del self._random_global_indices
        self._random_global_indices = rng.choice(
            self.virtual_size, self.virtual_size, replace=False
        )
        if self.load_next_shard is None:
            self.load_next_shard = False
        else:
            # increase shard epoch for next loading
            self.shard_epoch += 1
            self.load_next_shard = True
            logger.info(
                "to load next epoch/shard in next load_dataset: "
                f"epoch={epoch}/shard_epoch={self.shard_epoch}"
            )

    def _next_virtual_epoch(self, epoch):
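        # Move to the virtual epoch starting at `index`. A fresh global shuffle
        # of virtual indices is generated only when wrapping back to index 0
        # (or when none exists yet); otherwise the current shuffle is reused.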
        index = self._get_epoch_start_index(epoch)
        if index == 0 or self._random_global_indices is None:
            # need to start from the beginning,
            # so call super().set_epoch(epoch) to establish the global virtual indices
            logger.info(
                "establishing a new set of global virtual indices for "
                f"epoch={epoch}/shard_epoch={self.shard_epoch}"
            )
            super().set_epoch(epoch)
            self._next_global_indices(epoch)
        else:
            self._cur_epoch = epoch

        # reset the cached sizes and ordered indices after moving to a new epoch
        self._clean_if_not_none(
            [
                self._epoch_sizes,
            ]
        )
        self._epoch_sizes = None
        self._current_epoch_start_index = index