# File size: 2,117 Bytes
# 0b32ad6
"""
Limit the maximum timestamps in a batch to realize dynamic batching.

Authors:
  * Leo 2022
"""

from typing import Callable, List, Optional

import torch

__all__ = [
    "MaxTimestampBatchSampler",
]


class MaxTimestampBatchSampler:
    """
    Batch sampler that caps the reduced timestamp cost of each batch.

    Sample indices are greedily aggregated into batches so that the value of
    ``reduce_func`` applied to the batch's lengths never exceeds ``max_length``.
    If ``shuffle`` is enabled, the indices are shuffled deterministically
    (seeded by ``seed + epoch``) before being aggregated into batches.

    Args:
        lengths: Per-sample timestamp lengths, indexed by sample id.
        max_length: Upper bound on the reduced timestamps of a single batch.
        shuffle: Whether to shuffle indices before batching.
        seed: Base random seed; combined with the epoch for shuffling.
        reduce_func: Maps a batch's list of lengths to its cost. Defaults to
            ``max(lengths) * len(lengths)`` (the padded-batch footprint).
    """

    def __init__(
        self,
        lengths: List[int],
        max_length: int,
        shuffle: bool = False,
        seed: int = 12345678,
        reduce_func: Optional[Callable[[List[int]], int]] = None,
    ) -> None:
        self.lengths = lengths
        self.max_length = max_length
        self.shuffle = shuffle
        self.seed = seed
        # Updated via set_epoch() so each epoch gets a different shuffle.
        self.epoch = 0
        self.reduce_func = reduce_func or self._default_reduce_func

    @staticmethod
    def _default_reduce_func(timestamps: List[int]) -> int:
        # Cost of a padded batch: every sample is padded to the longest one.
        return max(timestamps) * len(timestamps)

    def set_epoch(self, epoch: int) -> None:
        """Set the current epoch so shuffling differs across epochs."""
        self.epoch = epoch

    def _evaluate_reduced_timestamps(self, batch_indices: List[int]) -> int:
        """Return the reduced cost of the batch given by ``batch_indices``."""
        return self.reduce_func([self.lengths[index] for index in batch_indices])

    def __iter__(self):
        """Yield lists of sample indices, each within the ``max_length`` budget.

        Raises:
            ValueError: If any single sample alone exceeds ``max_length``.
        """
        if self.shuffle:
            generator = torch.Generator()
            generator.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.lengths), generator=generator).tolist()
        else:
            indices = list(range(len(self.lengths)))

        batch: List[int] = []
        for index in indices:
            candidate = batch + [index]
            if self._evaluate_reduced_timestamps(candidate) <= self.max_length:
                batch = candidate
            elif self._evaluate_reduced_timestamps([index]) > self.max_length:
                # BUG FIX: previously an oversized sample only raised when it
                # happened to arrive while `batch` was empty; mid-stream it
                # silently produced a batch exceeding max_length. Now the
                # single-sample budget is enforced wherever the sample occurs.
                raise ValueError(
                    f"There is a single length {self.lengths[index]} larger than "
                    f"max_length {self.max_length}. Please increase "
                    "the max_length."
                )
            else:
                # Current batch is full; flush it and start a new one.
                yield batch
                batch = [index]

        if batch:
            yield batch

    def __len__(self):
        # Number of batches; O(n) because it re-runs the batching pass.
        return len(list(iter(self)))