"""
Multipack Batch Sampler
"""
import logging
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
import numba
import numpy as np
from torch.utils.data import BatchSampler
LOG = logging.getLogger("axolotl.utils.samplers.multipack")


# First-fit-decreasing bin packing.
@numba.njit
def pack_group(items, group_offset, bin_capacity, max_items_per_bin):
    """Pack one group of item sizes into bins of `bin_capacity` using
    first-fit decreasing. Returns a (num_bins, max_items_per_bin) array of
    global item indices, padded with -1 for empty slots."""
    idxs = np.argsort(items)[::-1]
    sorted_items = items[idxs]
    num_bins = len(items)
    bins = np.full(num_bins, bin_capacity, dtype=np.int32)
    bin_counts = np.zeros(num_bins, dtype=np.int32)
    group_packing = np.full((num_bins, max_items_per_bin), -1, dtype=np.int32)

    for idx, item in enumerate(sorted_items):
        global_idx = idxs[idx] + group_offset

        placed = False
        for i in range(num_bins):
            # First bin with enough remaining capacity and a free slot wins.
            if bins[i] >= item and bin_counts[i] < max_items_per_bin:
                bins[i] -= item
                group_packing[i, bin_counts[i]] = global_idx
                bin_counts[i] += 1
                placed = True
                break

        if not placed:
            raise ValueError(
                f"Item could not be packed. Try increasing cfg.sample_packing_bin_size ({max_items_per_bin})."
            )

    return group_packing
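
# Illustrative trace (assumed toy inputs, not values used by axolotl): packing
# items [6, 3, 5, 2] with bin_capacity=8, group_offset=0, max_items_per_bin=2
# sorts to [6, 5, 3, 2]; 6 then 2 land in bin 0 (6 + 2 <= 8), and 5 then 3
# fill bin 1 exactly, so the returned rows start [0, 3], [2, 1] (original item
# indices) and all remaining rows stay at the -1 padding.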


def pack(items, bin_capacity, group_size, max_items_per_bin):
    """Pack `items` into bins, fanning groups of `group_size` items out to
    worker processes and concatenating the resulting bins."""
    num_items = len(items)
    num_processes = max(1, min(num_items // group_size, cpu_count()))
    tasks = [
        (items[i : i + group_size], i, bin_capacity, max_items_per_bin)
        for i in range(0, num_items, group_size)
    ]

    packed_bins = []
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        for group_packing in executor.map(pack_group, *zip(*tasks)):
            for bin_pack in group_packing:
                # Drop the -1 padding; keep only bins that hold something.
                filtered_pack = bin_pack[bin_pack != -1]
                if filtered_pack.size > 0:
                    packed_bins.append(filtered_pack.tolist())

    return packed_bins
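
# Note on the grouping: each `group_size` slice is packed independently, so
# items never share a bin across group boundaries. This trades a small amount
# of packing efficiency for the ability to spread groups across processes.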


class MultipackBatchSampler(BatchSampler):
    """
    Batch Sampler class for multipack
    """

    def __init__(
        self,
        sampler,
        lengths,
        batch_max_len,
        batch_size,
        group_size=100_000,
        bin_size=200,
        drop_last=False,
    ):
        self.sampler = sampler
        self.lengths = np.array(lengths, dtype=np.int32)
        self.batch_max_len = batch_max_len
        self.batch_size = batch_size
        self.group_size = group_size
        self.bin_size = bin_size
        self.drop_last = drop_last

        self._efficiency = None
        self._batches = None

    def efficiency(self):
        """Fraction of available token slots actually filled by the packing."""
        if self._efficiency is None:
            self._batches = self._pack_batches()
        return self._efficiency

    def _pack_batches(self):
        # Get possibly shuffled indices from the sampler instead of assuming
        # the identity order, so shuffling actually affects the packing.
        sample_idxs = np.array(list(self.sampler))
        lengths = self.lengths[sample_idxs]

        # `pack` returns positions into `lengths`; map each bin back to the
        # dataset indices the sampler produced.
        pack_idxs = [
            sample_idxs[positions].tolist()
            for positions in pack(
                lengths,
                self.batch_max_len,
                self.group_size,
                self.bin_size,
            )
        ]

        used_tokens = lengths.sum()
        available_tokens = len(pack_idxs) * self.batch_max_len
        self._efficiency = used_tokens / available_tokens

        # Wrap packs into batches.
        batch_idxs = [
            pack_idxs[i : i + self.batch_size]
            for i in range(0, len(pack_idxs), self.batch_size)
        ]

        # Drop the last batch if it is incomplete.
        if self.drop_last and batch_idxs and len(batch_idxs[-1]) < self.batch_size:
            batch_idxs = batch_idxs[:-1]

        return batch_idxs

    def __iter__(self):
        # Re-pack on every epoch so a shuffling sampler yields fresh packs.
        self._batches = self._pack_batches()
        return iter(self._batches)

    def __len__(self):
        if self._batches is None:
            self._batches = self._pack_batches()
        return len(self._batches)
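

# --- Usage sketch ----------------------------------------------------------
# A minimal, illustrative driver; the random lengths and every value below
# are placeholder assumptions, not defaults taken from axolotl.
if __name__ == "__main__":
    from torch.utils.data import SequentialSampler

    rng = np.random.default_rng(0)
    lengths = rng.integers(low=32, high=2048, size=1_000)

    sampler = MultipackBatchSampler(
        sampler=SequentialSampler(range(len(lengths))),
        lengths=lengths,
        batch_max_len=4096,
        batch_size=4,
        group_size=1_000,
        bin_size=200,
        drop_last=True,
    )

    batches = list(sampler)
    print(f"{len(batches)} batches, packing efficiency {sampler.efficiency():.3f}")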