from fairseq.data import BaseWrapperDataset, LanguagePairDataset, plasma_utils
import numpy as np

import logging

logger = logging.getLogger(__name__)


class SubsampleLanguagePairDataset(BaseWrapperDataset):
    """Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples

    Args:
        dataset (~torch.utils.data.Dataset): dataset to subsample
        size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive)
    """
    
    def __init__(self, dataset, size_ratio, weights=None, replace=False, seed=0, epoch=1):
        super().__init__(dataset)
        assert 0.0 < size_ratio <= 1.0
        self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
        logger.info(
            "subsampled dataset from {} to {} (ratio={})".format(
                len(self.dataset), self.actual_size, size_ratio
            )
        )
        self.src_dict = self.dataset.src_dict
        self.tgt_dict = self.dataset.tgt_dict
        self.left_pad_source = self.dataset.left_pad_source
        self.left_pad_target = self.dataset.left_pad_target
        # mirror the wrapped dataset's shuffle flag; it is read in ordered_indices()
        self.shuffle = getattr(self.dataset, "shuffle", True)
        self.seed = seed
        self._cur_epoch = None
        self._cur_indices = None
        self.replace = replace
        if weights is None:
            self.weights = None
        else:
            assert len(weights) == len(dataset)
            weights_arr = np.array(weights, dtype=np.float64)
            weights_arr /= weights_arr.sum()
            self.weights = plasma_utils.PlasmaArray(weights_arr)
        self.set_epoch(epoch)
    
    def __getitem__(self, index):
        index = self._cur_indices.array[index]
        return self.dataset.__getitem__(index)
    
    def __len__(self):
        return self.actual_size
    
    @property
    def sizes(self):
        return self.dataset.sizes[self._cur_indices.array]

    @property
    def src_sizes(self):
        return self.dataset.src_sizes[self._cur_indices.array]

    @property
    def tgt_sizes(self):
        return self.dataset.tgt_sizes[self._cur_indices.array]
    
    @property
    def name(self):
        return self.dataset.name
    
    def num_tokens(self, index):
        index = self._cur_indices.array[index]
        return self.dataset.num_tokens(index)
    
    def size(self, index):
        index = self._cur_indices.array[index]
        return self.dataset.size(index)
    
    def ordered_indices(self):
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        # sort by target length, then source length
        if self.tgt_sizes is not None:
            indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
        return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
    
    def prefetch(self, indices):
        indices = self._cur_indices.array[indices]
        self.dataset.prefetch(indices)
    
    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # indices are re-drawn in set_epoch, so the epoch iterator cannot be cached
        return False
    
    def set_epoch(self, epoch):
        logger.info("SubsampleLanguagePairDataset.set_epoch: {}".format(epoch))
        super().set_epoch(epoch)
        
        if epoch == self._cur_epoch:
            return
        
        self._cur_epoch = epoch
        
        # Generate a weighted sample of indices as a function of the
        # random seed and the current epoch.
        
        rng = np.random.RandomState(
            [
                42,  # magic number
                self.seed % (2 ** 32),  # global seed
                self._cur_epoch,  # epoch index
            ]
        )
        self._cur_indices = plasma_utils.PlasmaArray(
            rng.choice(
                len(self.dataset),
                self.actual_size,
                replace=self.replace,
                p=(None if self.weights is None else self.weights.array),
            )
        )
        
        logger.info(
            "Dataset is sub-sampled: {} -> {}, first 3 ids are: {}".format(
                len(self.dataset),
                self.actual_size,
                ",".join(str(_i) for _i in self._cur_indices.array[:3]),
            )
        )
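

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes the
    # standard fairseq ``Dictionary`` / ``LanguagePairDataset`` APIs; the tiny
    # synthetic data and variable names below are illustrative only.
    import torch

    from fairseq.data import Dictionary

    d = Dictionary()
    for sym in ("hello", "world"):
        d.add_symbol(sym)

    # ten short "sentences" of varying length, reused as both source and target
    sents = [
        torch.LongTensor([d.index("hello")] * (i % 3 + 1) + [d.eos()])
        for i in range(10)
    ]
    sizes = np.array([len(s) for s in sents])

    pair_ds = LanguagePairDataset(
        sents, sizes, d, tgt=sents, tgt_sizes=sizes, tgt_dict=d
    )

    # keep ~50% of the examples, re-drawn per epoch from seed + epoch number
    sub_ds = SubsampleLanguagePairDataset(pair_ds, size_ratio=0.5, seed=1, epoch=1)
    print(len(sub_ds), sub_ds.ordered_indices()[:5])

    sub_ds.set_epoch(2)  # draws a fresh set of indices for the new epoch
    print(sub_ds.ordered_indices()[:5])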