File size: 3,831 Bytes
641e6f7
 
 
 
367b2e8
 
641e6f7
 
 
367b2e8
641e6f7
 
 
 
367b2e8
641e6f7
367b2e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
641e6f7
 
367b2e8
 
 
 
641e6f7
367b2e8
641e6f7
 
367b2e8
 
 
 
 
 
 
641e6f7
367b2e8
 
 
 
 
 
 
641e6f7
367b2e8
641e6f7
 
 
 
 
 
 
 
 
367b2e8
 
 
 
 
 
 
641e6f7
367b2e8
 
641e6f7
367b2e8
 
 
 
641e6f7
367b2e8
 
641e6f7
367b2e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
641e6f7
 
367b2e8
 
 
 
 
 
 
 
00568c1
641e6f7
367b2e8
 
 
641e6f7
367b2e8
641e6f7
 
367b2e8
 
641e6f7
 
367b2e8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Multipack Batch Sampler
"""
import logging
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count

import numba
import numpy as np
from torch.utils.data import BatchSampler

LOG = logging.getLogger("axolotl.utils.samplers.multipack")


# First-fit-decreasing bin packing.
@numba.njit
def pack_group(items, group_offset, bin_capacity, max_items_per_bin):
    """Pack one group of item sizes into bins via first-fit-decreasing.

    Compiled with numba nopython mode, so only numba-supported constructs
    are used in the body.

    Args:
        items: 1-D array of item sizes (token lengths) for this group.
        group_offset: index of this group's first item in the full item
            array; added to local indices so the returned indices are global.
        bin_capacity: maximum total size a single bin may hold.
        max_items_per_bin: maximum number of items allowed per bin.

    Returns:
        int32 array of shape (len(items), max_items_per_bin) where row i
        lists the global item indices packed into bin i, padded with -1.
        Unused trailing bins are rows of all -1.

    Raises:
        ValueError: if an item cannot be placed in any bin (every bin is
            already holding max_items_per_bin items).
    """
    # Sort item indices by size, largest first (the "decreasing" part).
    idxs = np.argsort(items)[::-1]
    sorted_items = items[idxs]
    # Worst case every item gets its own bin, so allocate one bin per item.
    num_bins = len(items)
    # bins[i] tracks remaining capacity of bin i.
    bins = np.full(num_bins, bin_capacity, dtype=np.int32)
    # bin_counts[i] tracks how many items bin i currently holds.
    bin_counts = np.zeros(num_bins, dtype=np.int32)
    group_packing = np.full((num_bins, max_items_per_bin), -1, dtype=np.int32)

    for idx, item in enumerate(sorted_items):
        # Translate the within-group index back to a global item index.
        global_idx = idxs[idx] + group_offset

        placed = False
        # First fit: scan bins in order and take the first one with room.
        for i in range(num_bins):
            if bins[i] >= item and bin_counts[i] < max_items_per_bin:
                bins[i] -= item
                group_packing[i, bin_counts[i]] = global_idx
                bin_counts[i] += 1
                placed = True
                break

        if not placed:
            # Only reachable when the per-bin item cap blocks placement:
            # with one bin per item, a size-fitting bin always exists.
            raise ValueError(
                f"Item could not be packed. Try increasing cfg.sample_packing_bin_size ({max_items_per_bin})."
            )

    return group_packing


def pack(items, bin_capacity, group_size, max_items_per_bin):
    """Pack all item sizes into bins, parallelizing across groups of items.

    Splits ``items`` into chunks of ``group_size``, packs each chunk
    independently in a worker process via :func:`pack_group`, and collects
    the non-empty bins from every chunk.

    Args:
        items: 1-D array of item sizes (token lengths).
        bin_capacity: maximum total size a single bin may hold.
        group_size: number of items handed to each worker.
        max_items_per_bin: maximum number of items allowed per bin.

    Returns:
        List of lists; each inner list holds the global item indices packed
        into one bin.
    """
    total = len(items)
    # One worker per full group, capped by CPU count, but never fewer than one.
    workers = max(1, min(total // group_size, cpu_count()))

    group_args = []
    for start in range(0, total, group_size):
        group_args.append(
            (items[start : start + group_size], start, bin_capacity, max_items_per_bin)
        )

    results = []
    with ProcessPoolExecutor(max_workers=workers) as pool:
        # zip(*group_args) transposes the per-task tuples into the
        # positional-argument iterables that executor.map expects.
        for packing in pool.map(pack_group, *zip(*group_args)):
            for row in packing:
                # Strip the -1 padding; keep only bins that received items.
                kept = row[row != -1]
                if kept.size:
                    results.append(kept.tolist())

    return results


class MultipackBatchSampler(BatchSampler):
    """
    Batch Sampler class for multipack.

    Packs variable-length samples into bins ("packs") holding at most
    ``batch_max_len`` tokens each, then groups ``batch_size`` packs into a
    batch. Iterating yields lists of packs, each pack a list of dataset
    indices.
    """

    def __init__(
        self,
        sampler,
        lengths,
        batch_max_len,
        batch_size,
        group_size=100_000,
        bin_size=200,
        drop_last=False,
    ):
        # Sampler yielding dataset indices (possibly shuffled).
        self.sampler = sampler
        # Per-sample token lengths, indexed by dataset index.
        self.lengths = np.array(lengths, dtype=np.int32)
        # Token capacity of a single pack (bin).
        self.batch_max_len = batch_max_len
        # Number of packs per batch.
        self.batch_size = batch_size
        # Number of samples handed to each parallel packing worker.
        self.group_size = group_size
        # Maximum number of samples allowed in a single pack.
        self.bin_size = bin_size
        self.drop_last = drop_last

        # Lazily computed by _pack_batches().
        self._efficiency = None
        self._batches = None

    def efficiency(self):
        """Return the fraction of pack capacity filled by real tokens."""
        if self._efficiency is None:
            self._batches = self._pack_batches()
        return self._efficiency

    def _pack_batches(self):
        # NOTE(review): np.arange ignores any shuffling performed by the
        # sampler; the indices returned by pack() are positions 0..N-1 and
        # only equal dataset indices because of this identity mapping.
        # Confirm whether sampler order was meant to be honored here.
        sample_idxs = np.arange(len(self.sampler))
        lengths = self.lengths[sample_idxs]

        pack_idxs = pack(
            lengths,
            self.batch_max_len,
            self.group_size,
            self.bin_size,
        )

        used_tokens = self.lengths.sum()
        available_tokens = len(pack_idxs) * self.batch_max_len
        # Guard the empty-input case: no packs means no capacity, and the
        # unguarded division would raise ZeroDivisionError.
        self._efficiency = (
            used_tokens / available_tokens if available_tokens else 0.0
        )

        # Wrap packs into batches of `batch_size` packs each.
        batch_idxs = [
            pack_idxs[i : i + self.batch_size]
            for i in range(0, len(pack_idxs), self.batch_size)
        ]

        # Drop a trailing incomplete batch if requested. The extra truthiness
        # check prevents an IndexError on batch_idxs[-1] when no batches
        # were produced at all.
        if self.drop_last and batch_idxs and len(batch_idxs[-1]) < self.batch_size:
            batch_idxs = batch_idxs[:-1]

        return batch_idxs

    def __iter__(self):
        # Re-pack on every epoch so a shuffling sampler could yield fresh
        # batches (see the note in _pack_batches about sampler order).
        self._batches = self._pack_batches()
        return iter(self._batches)

    def __len__(self):
        if self._batches is None:
            self._batches = self._pack_batches()
        return len(self._batches)