# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
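"""Utilities for packing tokenized examples into fixed-length bins to reduce padding."""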
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
import numpy as np
import torch
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase
class BinPackWrapper:
"""Utility collator for packing to reduce padding."""
def __init__(self,
collator: Callable,
target_batch_size: int,
max_seq_len: int,
pad_token_id: int,
padding_side: Literal['left', 'right'],
max_leftover_bins_to_keep: Optional[int] = None):
self.base_collator = collator
self.out_size = int(target_batch_size)
self.max_seq_len = int(max_seq_len)
self.pad_token_id = int(pad_token_id)
self.padding_side = padding_side
if self.out_size <= 0:
raise ValueError(f'{target_batch_size=} must be >0.')
if self.max_seq_len <= 0:
raise ValueError(f'{max_seq_len=} must be >0.')
if self.pad_token_id < 0:
raise ValueError(f'{pad_token_id=} must be >=0.')
if max_leftover_bins_to_keep is None:
self.max_leftover_bins_to_keep = int(10 * self.out_size)
elif max_leftover_bins_to_keep < 0:
raise ValueError(
f'{max_leftover_bins_to_keep=} must be >=0 or None.')
else:
self.max_leftover_bins_to_keep = int(max_leftover_bins_to_keep)
self.n_packed_tokens = 0
self.n_total_tokens = 0
self.n_packed_examples = 0
self._leftover_bins: List[Tuple[int, Dict[str, torch.Tensor]]] = []
@property
def waste(self) -> float:
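        """Fraction of tokens seen so far that have not been packed into an output batch."""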
return 1 - (self.n_packed_tokens / self.n_total_tokens)
@property
def efficiency(self) -> float:
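        """Fraction of the output token budget (``max_seq_len`` per packed example) filled with real tokens."""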
return self.n_packed_tokens / (self.max_seq_len *
self.n_packed_examples)
def __call__(
self,
examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
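        """Collate, trim padding, pack the trimmed examples into ``out_size`` bins, and re-pad to ``max_seq_len``."""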
batch = self.base_collator(examples)
assert 'attention_mask' in batch
assert 'input_ids' in batch
for key in batch.keys():
assert key in [
'input_ids',
'labels',
'attention_mask',
'bidirectional_mask',
]
# Cut everything down to size
sizes, trimmed_examples = [], []
for idx in range(batch['attention_mask'].shape[0]):
size, trimmed_example = extract_trim_batch_idx(batch, idx)
sizes.append(size)
trimmed_examples.append(trimmed_example)
        # Apply a simple first-fit-decreasing bin packing algorithm.
packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = first_fit_bin_packing(
sizes=sizes,
examples=trimmed_examples,
num_bins=self.out_size,
max_bin_size=self.max_seq_len,
existing_bins=self._leftover_bins,
)
self.n_packed_tokens += n_packed_tokens
self.n_total_tokens += n_total_tokens
self.n_packed_examples += self.out_size
self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep]
# Re-pad to max_seq_len and batch
batch = repad(packed_examples,
max_seq_len=self.max_seq_len,
pad_token_id=self.pad_token_id,
padding_side=self.padding_side)
return batch
def extract_trim_batch_idx(batch: Dict[str, torch.Tensor],
idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
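    """Extract example ``idx`` from the batch and strip its padding.

    Returns the number of attended tokens and the trimmed example, with a
    zero-initialized ``sequence_id`` tensor added.
    """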
example = {k: v[idx] for k, v in batch.items()}
keep = example['attention_mask'] == 1
size = int(keep.sum())
trim_example = {k: v[keep] for k, v in example.items()}
trim_example['sequence_id'] = torch.zeros_like(trim_example['input_ids'])
return size, trim_example
def combine_in_place(
example: Dict[str, torch.Tensor],
add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
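    """Concatenate ``add_on`` onto ``example`` in place.

    ``sequence_id`` values are offset so each packed sequence keeps a
    distinct id within the combined example.
    """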
if 'labels' in add_on:
# Prevents the last token in example from being trained to
# predict the first token in add_on, which would make no sense.
add_on['labels'][0] = -100
for k in example.keys():
if k == 'sequence_id':
example[k] = torch.cat(
[example[k], add_on[k] + 1 + torch.max(example[k])])
else:
example[k] = torch.cat([example[k], add_on[k]])
return example
def first_fit_bin_packing(
sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int,
max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]]
) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[
str, torch.Tensor]]]]:
    # Will contain tuples (bin_size, packed_example)
bins: List[Tuple[int, Dict[str, torch.Tensor]]] = existing_bins
starting_total_bin_sizes = sum([bin_size for bin_size, _ in bins])
sizes_and_examples = [
(size, example) for size, example in zip(sizes, examples)
]
sorted_sizes_and_examples = sorted(sizes_and_examples,
key=lambda x: x[0],
reverse=True)
required_num_examples = max(0, num_bins - len(bins))
num_examples = len(sizes)
if num_examples < required_num_examples:
for size, example in sorted_sizes_and_examples:
# Can't keep packing. All remaining items get their own bin.
bins.append((size, example))
total_bin_sizes = sum([bin_size for bin_size, _ in bins])
total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
total_example_sizes = sum(sizes)
if total_new_bin_sizes != total_example_sizes:
raise AssertionError(
f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.'
)
sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
bin_sizes, packed_examples = [], []
for bin_size, packed_example in sorted_bins:
bin_sizes.append(bin_size)
packed_examples.append(packed_example)
# Return:
# - the num_bins largest packed examples
# - the total tokens in those examples
# - the total size of all new examples
# - leftover bins
return packed_examples[:num_bins], sum(
bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:]
# Go through each item from longest to shortest.
# Note: all items will either go into an existing or new bin.
for i, (size, example) in enumerate(sorted_sizes_and_examples):
# If we can't keep packing, all remaining items get their own bin.
required_num_examples = max(0, num_bins - len(bins))
n_remaining = num_examples - i
assert n_remaining >= required_num_examples
if n_remaining == required_num_examples:
# Can't keep packing. All remaining items get their own bin.
bins.append((size, example))
continue
# Add it to the first bin it fits in
added = False
for bidx in range(len(bins)):
if bins[bidx][0] + size <= max_bin_size:
bin_size, packed_example = bins.pop(bidx)
bin_size = bin_size + size
packed_example = combine_in_place(packed_example, example)
bins.append((bin_size, packed_example))
added = True
break
# If it didn't fit anywhere, open a new bin
if not added:
bins.append((size, example))
total_bin_sizes = sum([bin_size for bin_size, _ in bins])
total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
total_example_sizes = sum(sizes)
if total_new_bin_sizes != total_example_sizes:
raise AssertionError(
f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.'
)
sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
bin_sizes, packed_examples = [], []
for bin_size, packed_example in sorted_bins:
bin_sizes.append(bin_size)
packed_examples.append(packed_example)
# Return:
# - the num_bins largest packed examples
# - the total tokens in those examples
# - the total size of all new examples
# - leftover bins
return packed_examples[:num_bins], sum(
bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:]
def repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int,
pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:
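    """Pad every field of each packed example to ``max_seq_len`` and stack into a batch."""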
def pad_tensor(tensor: torch.Tensor, pad_value: int):
if len(tensor) == max_seq_len:
return tensor
t = torch.full((max_seq_len,),
pad_value,
dtype=tensor.dtype,
device=tensor.device)
if padding_side == 'left':
t[-len(tensor):] = tensor
elif padding_side == 'right':
t[:len(tensor)] = tensor
else:
raise ValueError(f'Unknown {padding_side=}')
return t
pad_vals = {
'input_ids': pad_token_id,
'labels': -100,
'attention_mask': 0,
'bidirectional_mask': 0,
'sequence_id': -1,
}
keys = packed_examples[0].keys()
batch = {}
for key in keys:
batch[key] = torch.stack([
pad_tensor(example[key], pad_vals[key])
for example in packed_examples
])
return batch
if __name__ == '__main__':
from argparse import ArgumentParser, Namespace
from omegaconf import OmegaConf as om
from llmfoundry import (build_finetuning_dataloader,
build_text_denoising_dataloader)
from llmfoundry.data import build_text_dataloader
from llmfoundry.utils import build_tokenizer
def parse_args() -> Namespace:
"""Parse commandline arguments."""
parser = ArgumentParser(
description=
'Profile packing_ratio choices for a particular workload.')
parser.add_argument(
'--yaml-path',
type=str,
required=True,
help='Path to the YAML that defines the workload to profile.')
parser.add_argument('--num-devices',
type=int,
default=None,
help='How many devices your run will use.')
parser.add_argument('--min',
type=float,
required=True,
help='Smallest packing_ratio to test. Must be >=1.')
parser.add_argument(
'--max',
type=float,
required=True,
help='Largest packing_ratio to test. Must be larger than `min`.')
parser.add_argument(
'--num-packing-ratios',
type=int,
default=10,
help=
            'Number of packing_ratio values (spaced between `min` and `max`) to try.'
)
args = parser.parse_args()
if not os.path.isfile(args.yaml_path):
raise FileNotFoundError(
'`yaml_path` does not correspond to any existing file.')
        if args.num_devices is None or args.num_devices < 1:
            raise ValueError('`num_devices` must be a positive integer.')
if args.min < 1.0:
raise ValueError('`min` must be >=1.0.')
if args.max < args.min:
raise ValueError('`max` cannot be less than `min`.')
if args.num_packing_ratios < 1:
raise ValueError('`num_packing_ratios` must be a positive integer.')
return args
def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int):
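        """Build the text, text_denoising, or finetuning dataloader named by ``cfg.name``."""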
if cfg.name == 'text':
return build_text_dataloader(cfg, tokenizer, device_batch_size)
elif cfg.name == 'text_denoising':
return build_text_denoising_dataloader(cfg, tokenizer,
device_batch_size)
elif cfg.name == 'finetuning':
return build_finetuning_dataloader(cfg, tokenizer,
device_batch_size)
else:
raise ValueError(
f'Not sure how to build dataloader with config: {cfg}')
args = parse_args()
with open(args.yaml_path) as f:
cfg = om.load(f)
if 'parameters' in cfg:
cfg = om.to_container(cfg.parameters)
cfg = om.create(cfg)
device_batch_size = cfg.global_train_batch_size // args.num_devices
# Determine the packing_ratio values we'll try
packing_ratios, raw_batch_sizes = [], []
for packing_ratio in np.linspace(args.min,
args.max,
args.num_packing_ratios,
endpoint=True):
packing_ratio = np.round(10 * packing_ratio) / 10
raw_batch_size = int(packing_ratio * device_batch_size)
if raw_batch_size not in raw_batch_sizes:
packing_ratios.append(packing_ratio)
raw_batch_sizes.append(raw_batch_size)
# Fetch a bunch of raw examples once, which we'll re-use
if 'train_loader' not in cfg:
raise ValueError('config must define train_loader')
dataloader_cfg = cfg.train_loader
max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep',
None)
# build tokenizer
if 'tokenizer' not in cfg:
raise ValueError('config must define tokenizer')
resolved_tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True)
if not isinstance(resolved_tokenizer_cfg, Dict):
raise ValueError(
'tokenizer config needs to be resolved by omegaconf into a Dict.')
tokenizer_cfg: Dict[Any, Any] = resolved_tokenizer_cfg
tokenizer_name = tokenizer_cfg['name']
tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
    # Turn off packing for the dataloader (we want raw, not-yet-packed examples)
dataloader_cfg.dataset.packing_ratio = None
dataloader_cfg.dataset.max_leftovers_to_keep = None
train_dataloader = build_dataloader(dataloader_cfg, tokenizer,
max(raw_batch_sizes) * 100)
# Get a bunch of raw examples
big_batch = next(iter(train_dataloader))
def split_big_batch(raw_batch_size: int) -> List:
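        """Split the pre-fetched ``big_batch`` into raw batches of ``raw_batch_size`` examples."""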
input_ids = big_batch['input_ids'].split(raw_batch_size)
batches = [{'input_ids': x} for x in input_ids]
for key in big_batch.keys():
if key == 'input_ids':
continue
for idx, split in enumerate(big_batch[key].split(raw_batch_size)):
batches[idx].update({key: split})
return batches
def profile_packing(raw_batch_size: int) -> Tuple[float, float]:
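        """Return (padding %, waste %) from packing raw batches of this size with ``BinPackWrapper``."""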
packer = BinPackWrapper(
collator=lambda x: x,
target_batch_size=device_batch_size,
max_seq_len=dataloader_cfg.dataset.max_seq_len,
pad_token_id=0, # <-- Doesn't need to be correct for profiling
padding_side='left', # <-- Doesn't need to be correct for profiling
max_leftover_bins_to_keep=max_leftovers_to_keep)
# Simulate feeding the packing collator a bunch of data
for batch in split_big_batch(raw_batch_size):
if batch['input_ids'].shape[0] < device_batch_size:
continue
_ = packer(batch)
# Return the padding / waste stats over that bunch of data
padding_percent = 100 * (1 - packer.efficiency)
waste_percent = 100 * packer.waste
return padding_percent, waste_percent
header = '\n\n\n packing_ratio | % PADDING | % WASTE'
fstr = ' {:5.1f} | {:5.2f}% | {:6.2f}%'
print(header)
print('-' * len(header))
for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
padding, waste = profile_packing(raw_batch_size)
print(fstr.format(packing_ratio, padding, waste))