File size: 4,207 Bytes
96ee597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""Original sampling logic of MQTTS.

Copyright PolyAI Limited.
"""
import math
import random

import numpy as np
from torch.utils import data


def StandardSampler(dataset, shuffle, distributed=False,
                    world_size=None, rank=None):
    """Build a plain (non-bucketed) torch sampler for ``dataset``.

    Returns a ``DistributedSampler`` when ``distributed`` is set;
    otherwise a ``RandomSampler`` if ``shuffle`` else a
    ``SequentialSampler``.
    """
    if not distributed:
        return (data.RandomSampler(dataset) if shuffle
                else data.SequentialSampler(dataset))
    return data.distributed.DistributedSampler(
        dataset, shuffle=shuffle, num_replicas=world_size, rank=rank)


def RandomBucketSampler(
        nbuckets, length, batch_size, drop_last, distributed=False,
        world_size=None, rank=None):
    """Factory for a length-bucketed batch sampler.

    Dispatches to :class:`DistributedRandomBucketSampler` when
    ``distributed`` is set, else :class:`SingleRandomBucketSampler`.
    """
    if not distributed:
        return SingleRandomBucketSampler(nbuckets, length, batch_size,
                                         drop_last)
    return DistributedRandomBucketSampler(
        nbuckets, length, batch_size, drop_last, world_size, rank)


class SingleRandomBucketSampler(data.Sampler):
    """Batch sampler grouping items of similar length into buckets.

    Items are sorted by length (descending) and sliced into ``nbuckets``
    buckets.  Batches are formed greedily under a *token budget*:
    ``max_length_in_batch * n_items <= batch_size`` (so ``batch_size``
    is not an item count).  Bucket order, items within each bucket, and
    the final batch list are all shuffled, so composition varies per
    epoch.

    Fix over the original: when a single item's length already exceeds
    ``batch_size``, the original appended an EMPTY batch (``batch[:-1]``
    with ``len(batch) == 1``); such an item is now emitted as a
    singleton batch instead.
    """

    def __init__(self, nbuckets, length, batch_size, drop_last):
        """
        Args:
            nbuckets: number of length-sorted buckets.
            length: per-item lengths, indexable by dataset index.
            batch_size: budget for ``max_len * n_items`` per batch.
            drop_last: if True, drop the trailing partial batch.
        """
        self.length = length
        self.batch_size = batch_size
        self.drop_last = drop_last
        # Sort indices by descending length, then slice into buckets.
        indices = np.argsort([-x for x in length])
        split = len(indices) // nbuckets
        self.indices = [indices[i * split:(i + 1) * split]
                        for i in range(nbuckets)]
        if nbuckets * split < len(length):  # leftover tail bucket
            self.indices.append(indices[nbuckets * split:])

    def __iter__(self):
        # Shuffle bucket order and the items inside each bucket.
        random.shuffle(self.indices)
        for bucket in self.indices:
            random.shuffle(bucket)
        idxs = [i for bucket in self.indices for i in bucket]

        batches, batch, max_len = [], [], 0
        for idx in idxs:
            batch.append(idx)
            max_len = max(self.length[idx], max_len)
            if max_len * len(batch) > self.batch_size:
                if len(batch) > 1:
                    # Budget exceeded: emit everything but the newest
                    # item, which seeds the next batch.
                    batches.append(batch[:-1])
                    batch = [batch[-1]]
                    max_len = self.length[idx]
                else:
                    # Single over-budget item: emit it alone rather
                    # than emitting an empty batch (original bug).
                    batches.append(batch)
                    batch, max_len = [], 0
        if batch and not self.drop_last:
            batches.append(batch)
        random.shuffle(batches)
        return iter(batches)


class DistributedRandomBucketSampler(data.Sampler):
    """Distributed variant of :class:`SingleRandomBucketSampler`.

    Every replica builds the *identical* batch list from the shared
    deterministic seed (``seed + epoch``), then takes its own
    contiguous, equally-sized slice — so replicas never overlap and
    always agree on batch count.

    Fixes over the original: a single item longer than ``batch_size``
    no longer produces an empty batch, and the trailing partial batch
    is kept when ``drop_last`` is False (the original dropped it
    unconditionally).  Both changes are deterministic functions of the
    seed, so cross-replica consistency is preserved.
    """

    def __init__(self, nbuckets, length, batch_size,
                 drop_last, num_replicas, rank, seed=1234):
        """
        Args:
            nbuckets: number of length-sorted buckets.
            length: per-item lengths, indexable by dataset index.
            batch_size: budget for ``max_len * n_items`` per batch.
            drop_last: if True, drop the trailing partial batch.
            num_replicas: world size.
            rank: this replica's rank in ``[0, num_replicas)``.
            seed: base seed for the shared deterministic shuffles.

        Raises:
            ValueError: if ``rank`` is outside ``[0, num_replicas)``.
        """
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        # NOTE(review): sorted ascending here, while the single-process
        # sampler sorts descending — bucketing by similar length is the
        # property that matters, so both work.
        indices = np.argsort(length)
        split = len(indices) // nbuckets
        self.length = length
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.indices = [indices[i * split:(i + 1) * split]
                        for i in range(nbuckets)]
        if nbuckets * split < len(length):  # leftover tail bucket
            self.indices.append(indices[nbuckets * split:])
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.seed = seed

    def __iter__(self):
        # Deterministic shuffling: seeded identically on every replica
        # for a given epoch, so all replicas agree on the batch list.
        random.Random(self.epoch + self.seed).shuffle(self.indices)
        for i, bucket in enumerate(self.indices):
            random.Random(self.epoch + self.seed + i * 5).shuffle(bucket)
        indices = [i for bucket in self.indices for i in bucket]

        # Greedy batching under the budget max_len * n_items <= batch_size.
        batches, batch, max_len = [], [], 0
        for idx in indices:
            batch.append(idx)
            max_len = max(self.length[idx], max_len)
            if max_len * len(batch) > self.batch_size:
                if len(batch) > 1:
                    batches.append(batch[:-1])
                    batch = [batch[-1]]
                    max_len = self.length[idx]
                else:
                    # Single over-budget item: emit it alone rather
                    # than emitting an empty batch (original bug).
                    batches.append(batch)
                    batch, max_len = [], 0
        # Honor drop_last for the trailing partial batch (the original
        # dropped it unconditionally).
        if batch and not self.drop_last:
            batches.append(batch)

        # Subsample: each replica takes an equal contiguous share; up
        # to num_replicas surplus batches are discarded so counts match
        # across replicas.
        num_samples = math.ceil(
            (len(batches) - self.num_replicas) / self.num_replicas)
        total_size = num_samples * self.num_replicas
        batches = batches[:total_size]
        batches = batches[self.rank * num_samples:
                          (self.rank + 1) * num_samples]
        assert len(batches) == num_samples

        # Per-rank stochastic shuffle of batch ORDER only (contents are
        # already fixed by the deterministic phase above).
        random.shuffle(batches)
        return iter(batches)

    def set_epoch(self, epoch):
        """Advance the epoch so the deterministic shuffle changes."""
        self.epoch = epoch