# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

import numpy as np
import torch
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase


class BinPackWrapper:
    """Utility collator for packing to reduce padding."""

    def __init__(self,
                 collator: Callable,
                 target_batch_size: int,
                 max_seq_len: int,
                 pad_token_id: int,
                 padding_side: Literal['left', 'right'],
                 max_leftover_bins_to_keep: Optional[int] = None):
        self.base_collator = collator
        self.out_size = int(target_batch_size)
        self.max_seq_len = int(max_seq_len)
        self.pad_token_id = int(pad_token_id)
        self.padding_side = padding_side

        if self.out_size <= 0:
            raise ValueError(f'{target_batch_size=} must be >0.')
        if self.max_seq_len <= 0:
            raise ValueError(f'{max_seq_len=} must be >0.')
        if self.pad_token_id < 0:
            raise ValueError(f'{pad_token_id=} must be >=0.')

        if max_leftover_bins_to_keep is None:
            self.max_leftover_bins_to_keep = int(10 * self.out_size)
        elif max_leftover_bins_to_keep < 0:
            raise ValueError(
                f'{max_leftover_bins_to_keep=} must be >=0 or None.')
        else:
            self.max_leftover_bins_to_keep = int(max_leftover_bins_to_keep)

        self.n_packed_tokens = 0
        self.n_total_tokens = 0
        self.n_packed_examples = 0

        self._leftover_bins: List[Tuple[int, Dict[str, torch.Tensor]]] = []

    @property
    def waste(self) -> float:
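        """Fraction of tokens seen so far that were not emitted in a packed batch."""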
        return 1 - (self.n_packed_tokens / self.n_total_tokens)

    @property
    def efficiency(self) -> float:
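        """Mean fraction of each packed output row filled with real (non-padding) tokens."""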
        return self.n_packed_tokens / (self.max_seq_len *
                                       self.n_packed_examples)

    def __call__(
            self,
            examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
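        """Collate ``examples``, strip padding, bin-pack, and re-pad to ``max_seq_len``."""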
        batch = self.base_collator(examples)

        assert 'attention_mask' in batch
        assert 'input_ids' in batch

        for key in batch.keys():
            assert key in [
                'input_ids',
                'labels',
                'attention_mask',
                'bidirectional_mask',
            ]

        # Cut everything down to size
        sizes, trimmed_examples = [], []
        for idx in range(batch['attention_mask'].shape[0]):
            size, trimmed_example = extract_trim_batch_idx(batch, idx)
            sizes.append(size)
            trimmed_examples.append(trimmed_example)

        # Apply our CS 101 bin packing algorithm.
        packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = first_fit_bin_packing(
            sizes=sizes,
            examples=trimmed_examples,
            num_bins=self.out_size,
            max_bin_size=self.max_seq_len,
            existing_bins=self._leftover_bins,
        )
        self.n_packed_tokens += n_packed_tokens
        self.n_total_tokens += n_total_tokens
        self.n_packed_examples += self.out_size
        self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep]

        # Re-pad to max_seq_len and batch
        batch = repad(packed_examples,
                      max_seq_len=self.max_seq_len,
                      pad_token_id=self.pad_token_id,
                      padding_side=self.padding_side)
        return batch


def extract_trim_batch_idx(batch: Dict[str, torch.Tensor],
                           idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
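    """Pull example ``idx`` out of a padded batch and strip its padding.

    Returns the example's true token count and its trimmed tensors, plus a
    fresh ``sequence_id`` of zeros marking the whole example as one sequence.
    """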
    example = {k: v[idx] for k, v in batch.items()}

    keep = example['attention_mask'] == 1
    size = int(keep.sum())
    trim_example = {k: v[keep] for k, v in example.items()}
    trim_example['sequence_id'] = torch.zeros_like(trim_example['input_ids'])

    return size, trim_example


def combine_in_place(
        example: Dict[str, torch.Tensor],
        add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
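    """Concatenate ``add_on`` onto ``example`` in place, offsetting its
    ``sequence_id`` so the appended tokens remain distinguishable."""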
    if 'labels' in add_on:
        # Prevents the last token in example from being trained to
        # predict the first token in add_on, which would make no sense.
        add_on['labels'][0] = -100

    for k in example.keys():
        if k == 'sequence_id':
            example[k] = torch.cat(
                [example[k], add_on[k] + 1 + torch.max(example[k])])
        else:
            example[k] = torch.cat([example[k], add_on[k]])
    return example


def first_fit_bin_packing(
    sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int,
    max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]]
) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[
        str, torch.Tensor]]]]:
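    """Pack ``examples`` into bins of at most ``max_bin_size`` tokens, first-fit.

    Items are visited longest-first (first-fit decreasing) and placed into the
    first bin with room, else a new bin. Returns the ``num_bins`` fullest bins
    as packed examples, the token count inside them, the total size of the new
    examples, and the leftover bins (reusable on a later call).

    Worked example (sizes only, ignoring the ``num_bins`` floor): with
    ``max_bin_size=10`` and sizes ``[8, 6, 4, 2]``, the packing is
    ``8+2`` and ``6+4``, i.e. two exactly-full bins.
    """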

    # Will contain tuples (bin_size, packed_example)
    bins: List[Tuple[int, Dict[str, torch.Tensor]]] = existing_bins

    starting_total_bin_sizes = sum([bin_size for bin_size, _ in bins])

    sizes_and_examples = [
        (size, example) for size, example in zip(sizes, examples)
    ]
    sorted_sizes_and_examples = sorted(sizes_and_examples,
                                       key=lambda x: x[0],
                                       reverse=True)

    required_num_examples = max(0, num_bins - len(bins))
    num_examples = len(sizes)
    if num_examples < required_num_examples:
        for size, example in sorted_sizes_and_examples:
            # Can't keep packing. All remaining items get their own bin.
            bins.append((size, example))

        total_bin_sizes = sum([bin_size for bin_size, _ in bins])
        total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
        total_example_sizes = sum(sizes)
        if total_new_bin_sizes != total_example_sizes:
            raise AssertionError(
                f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.'
            )

        sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
        bin_sizes, packed_examples = [], []
        for bin_size, packed_example in sorted_bins:
            bin_sizes.append(bin_size)
            packed_examples.append(packed_example)

        # Return:
        #  - the num_bins largest packed examples
        #  - the total tokens in those examples
        #  - the total size of all new examples
        #  - leftover bins
        return packed_examples[:num_bins], sum(
            bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:]

    # Go through each item from longest to shortest.
    # Note: all items will either go into an existing or new bin.
    for i, (size, example) in enumerate(sorted_sizes_and_examples):
        # If we can't keep packing, all remaining items get their own bin.
        required_num_examples = max(0, num_bins - len(bins))
        n_remaining = num_examples - i
        assert n_remaining >= required_num_examples
        if n_remaining == required_num_examples:
            # Can't keep packing. All remaining items get their own bin.
            bins.append((size, example))
            continue

        # Add it to the first bin it fits in
        added = False
        for bidx in range(len(bins)):
            if bins[bidx][0] + size <= max_bin_size:
                bin_size, packed_example = bins.pop(bidx)
                bin_size = bin_size + size
                packed_example = combine_in_place(packed_example, example)
                bins.append((bin_size, packed_example))
                added = True
                break
        # If it didn't fit anywhere, open a new bin
        if not added:
            bins.append((size, example))

    total_bin_sizes = sum([bin_size for bin_size, _ in bins])
    total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
    total_example_sizes = sum(sizes)
    if total_new_bin_sizes != total_example_sizes:
        raise AssertionError(
            f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.'
        )

    sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
    bin_sizes, packed_examples = [], []
    for bin_size, packed_example in sorted_bins:
        bin_sizes.append(bin_size)
        packed_examples.append(packed_example)

    # Return:
    #  - the num_bins largest packed examples
    #  - the total tokens in those examples
    #  - the total size of all new examples
    #  - leftover bins
    return packed_examples[:num_bins], sum(
        bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:]


def repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int,
          pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:
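    """Pad each packed example to ``max_seq_len`` and stack into batch tensors.

    Every key has its own pad value (e.g. ``-100`` for labels so padded
    positions are ignored by the loss).
    """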

    def pad_tensor(tensor: torch.Tensor, pad_value: int):
        if len(tensor) == max_seq_len:
            return tensor
        t = torch.full((max_seq_len,),
                       pad_value,
                       dtype=tensor.dtype,
                       device=tensor.device)
        if padding_side == 'left':
            t[-len(tensor):] = tensor
        elif padding_side == 'right':
            t[:len(tensor)] = tensor
        else:
            raise ValueError(f'Unknown {padding_side=}')
        return t

    pad_vals = {
        'input_ids': pad_token_id,
        'labels': -100,
        'attention_mask': 0,
        'bidirectional_mask': 0,
        'sequence_id': -1,
    }
    keys = packed_examples[0].keys()
    batch = {}
    for key in keys:
        batch[key] = torch.stack([
            pad_tensor(example[key], pad_vals[key])
            for example in packed_examples
        ])
    return batch


if __name__ == '__main__':
    from argparse import ArgumentParser, Namespace

    from omegaconf import OmegaConf as om

    from llmfoundry import (build_finetuning_dataloader,
                            build_text_denoising_dataloader)
    from llmfoundry.data import build_text_dataloader
    from llmfoundry.utils import build_tokenizer

    def parse_args() -> Namespace:
        """Parse commandline arguments."""
        parser = ArgumentParser(
            description=
            'Profile packing_ratio choices for a particular workload.')
        parser.add_argument(
            '--yaml-path',
            type=str,
            required=True,
            help='Path to the YAML that defines the workload to profile.')
        parser.add_argument('--num-devices',
                            type=int,
                            default=None,
                            help='How many devices your run will use.')
        parser.add_argument('--min',
                            type=float,
                            required=True,
                            help='Smallest packing_ratio to test. Must be >=1.')
        parser.add_argument(
            '--max',
            type=float,
            required=True,
            help='Largest packing_ratio to test. Must be larger than `min`.')
        parser.add_argument(
            '--num-packing-ratios',
            type=int,
            default=10,
            help=
            'Number of packing_ratio values (spaced between `min` and `max`) to try.'
        )

        args = parser.parse_args()

        if not os.path.isfile(args.yaml_path):
            raise FileNotFoundError(
                '`yaml_path` does not correspond to any existing file.')
        if args.num_devices is None or args.num_devices < 1:
            raise ValueError('`num_devices` must be a positive integer.')
        if args.min < 1.0:
            raise ValueError('`min` must be >=1.0.')
        if args.max < args.min:
            raise ValueError('`max` cannot be less than `min`.')
        if args.num_packing_ratios < 1:
            raise ValueError('`num_packing_ratios` must be a positive integer.')
        return args

    def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
                         device_batch_size: int):
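        """Dispatch to the matching llm-foundry dataloader builder via ``cfg.name``."""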
        if cfg.name == 'text':
            return build_text_dataloader(cfg, tokenizer, device_batch_size)
        elif cfg.name == 'text_denoising':
            return build_text_denoising_dataloader(cfg, tokenizer,
                                                   device_batch_size)
        elif cfg.name == 'finetuning':
            return build_finetuning_dataloader(cfg, tokenizer,
                                               device_batch_size)
        else:
            raise ValueError(
                f'Not sure how to build dataloader with config: {cfg}')

    args = parse_args()

    with open(args.yaml_path) as f:
        cfg = om.load(f)
    if 'parameters' in cfg:
        cfg = om.to_container(cfg.parameters)
        cfg = om.create(cfg)
    device_batch_size = cfg.global_train_batch_size // args.num_devices

    # Determine the packing_ratio values we'll try
    packing_ratios, raw_batch_sizes = [], []
    for packing_ratio in np.linspace(args.min,
                                     args.max,
                                     args.num_packing_ratios,
                                     endpoint=True):
        packing_ratio = np.round(10 * packing_ratio) / 10
        raw_batch_size = int(packing_ratio * device_batch_size)
        if raw_batch_size not in raw_batch_sizes:
            packing_ratios.append(packing_ratio)
            raw_batch_sizes.append(raw_batch_size)
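
    # Each packing_ratio R is profiled by feeding R * device_batch_size raw
    # examples to the packer for every packed batch of device_batch_size rows.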

    # Fetch a bunch of raw examples once, which we'll re-use
    if 'train_loader' not in cfg:
        raise ValueError('config must define train_loader')
    dataloader_cfg = cfg.train_loader

    max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep',
                                                       None)

    # build tokenizer
    if 'tokenizer' not in cfg:
        raise ValueError('config must define tokenizer')

    resolved_tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True)
    if not isinstance(resolved_tokenizer_cfg, Dict):
        raise ValueError(
            'tokenizer config needs to be resolved by omegaconf into a Dict.')
    tokenizer_cfg: Dict[Any, Any] = resolved_tokenizer_cfg

    tokenizer_name = tokenizer_cfg['name']
    tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
    tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)

    # Turn off packing for the dataloader (we want raw, not-yet-packed examples)
    dataloader_cfg.dataset.packing_ratio = None
    dataloader_cfg.dataset.max_leftovers_to_keep = None
    train_dataloader = build_dataloader(dataloader_cfg, tokenizer,
                                        max(raw_batch_sizes) * 100)

    # Get a bunch of raw examples
    big_batch = next(iter(train_dataloader))

    def split_big_batch(raw_batch_size: int) -> List:
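        """Split the cached ``big_batch`` into raw batches of ``raw_batch_size`` rows."""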
        input_ids = big_batch['input_ids'].split(raw_batch_size)
        batches = [{'input_ids': x} for x in input_ids]

        for key in big_batch.keys():
            if key == 'input_ids':
                continue
            for idx, split in enumerate(big_batch[key].split(raw_batch_size)):
                batches[idx].update({key: split})
        return batches

    def profile_packing(raw_batch_size: int) -> Tuple[float, float]:
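        """Pack the pre-fetched examples at this raw batch size; return (% padding, % waste)."""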
        packer = BinPackWrapper(
            collator=lambda x: x,
            target_batch_size=device_batch_size,
            max_seq_len=dataloader_cfg.dataset.max_seq_len,
            pad_token_id=0,  # <-- Doesn't need to be correct for profiling
            padding_side='left',  # <-- Doesn't need to be correct for profiling
            max_leftover_bins_to_keep=max_leftovers_to_keep)

        # Simulate feeding the packing collator a bunch of data
        for batch in split_big_batch(raw_batch_size):
            if batch['input_ids'].shape[0] < device_batch_size:
                continue
            _ = packer(batch)

        # Return the padding / waste stats over that bunch of data
        padding_percent = 100 * (1 - packer.efficiency)
        waste_percent = 100 * packer.waste
        return padding_percent, waste_percent

    header = '\n\n\n packing_ratio | % PADDING | % WASTE'
    fstr = '        {:5.1f}  |  {:5.2f}%   | {:6.2f}%'

    print(header)
    print('-' * len(header))
    for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
        padding, waste = profile_packing(raw_batch_size)
        print(fstr.format(packing_ratio, padding, waste))