import os |
|
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple |
|
|
|
import numpy as np |
|
import torch |
|
from omegaconf import DictConfig |
|
from transformers import PreTrainedTokenizerBase |
|
|
|
|
|
class BinPackWrapper: |
|
"""Utility collator for packing to reduce padding.""" |
|
|
|
def __init__(self, |
|
collator: Callable, |
|
target_batch_size: int, |
|
max_seq_len: int, |
|
pad_token_id: int, |
|
padding_side: Literal['left', 'right'], |
|
max_leftover_bins_to_keep: Optional[int] = None): |
|
self.base_collator = collator |
|
self.out_size = int(target_batch_size) |
|
self.max_seq_len = int(max_seq_len) |
|
self.pad_token_id = int(pad_token_id) |
|
self.padding_side = padding_side |
|
|
|
if self.out_size <= 0: |
|
raise ValueError(f'{target_batch_size=} must be >0.') |
|
if self.max_seq_len <= 0: |
|
raise ValueError(f'{max_seq_len=} must be >0.') |
|
if self.pad_token_id < 0: |
|
raise ValueError(f'{pad_token_id=} must be >=0.') |
|
|
|
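        # By default, keep up to 10 batches' worth of leftover bins.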
if max_leftover_bins_to_keep is None: |
|
self.max_leftover_bins_to_keep = int(10 * self.out_size) |
|
elif max_leftover_bins_to_keep < 0: |
|
raise ValueError( |
|
f'{max_leftover_bins_to_keep=} must be >=0 or None.') |
|
else: |
|
self.max_leftover_bins_to_keep = int(max_leftover_bins_to_keep) |
|
|
|
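        # Running totals used by the waste and efficiency properties.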
self.n_packed_tokens = 0 |
|
self.n_total_tokens = 0 |
|
self.n_packed_examples = 0 |
|
|
|
self._leftover_bins: List[Tuple[int, Dict[str, torch.Tensor]]] = [] |
|
|
|
    @property
    def waste(self) -> float:
        """Fraction of tokens seen so far that were not packed into a batch."""
        if self.n_total_tokens == 0:
            return 0.0
        return 1 - (self.n_packed_tokens / self.n_total_tokens)
|
|
|
    @property
    def efficiency(self) -> float:
        """Packed tokens as a fraction of the emitted batches' token capacity."""
        if self.n_packed_examples == 0:
            return 0.0
        return self.n_packed_tokens / (self.max_seq_len *
                                       self.n_packed_examples)
|
|
|
def __call__( |
|
self, |
|
examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: |
|
batch = self.base_collator(examples) |
|
|
|
assert 'attention_mask' in batch |
|
assert 'input_ids' in batch |
|
|
|
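        # The packing logic below only supports these keys.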
for key in batch.keys(): |
|
assert key in [ |
|
'input_ids', |
|
'labels', |
|
'attention_mask', |
|
'bidirectional_mask', |
|
] |
|
|
|
|
|
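        # Cut each example down to its un-padded length.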
sizes, trimmed_examples = [], [] |
|
for idx in range(batch['attention_mask'].shape[0]): |
|
size, trimmed_example = extract_trim_batch_idx(batch, idx) |
|
sizes.append(size) |
|
trimmed_examples.append(trimmed_example) |
|
|
|
|
|
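        # Combine the trimmed examples into self.out_size packed bins,
        # reusing any leftover bins from previous calls.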
packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = first_fit_bin_packing( |
|
sizes=sizes, |
|
examples=trimmed_examples, |
|
num_bins=self.out_size, |
|
max_bin_size=self.max_seq_len, |
|
existing_bins=self._leftover_bins, |
|
) |
|
self.n_packed_tokens += n_packed_tokens |
|
self.n_total_tokens += n_total_tokens |
|
self.n_packed_examples += self.out_size |
|
self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep] |
|
|
|
|
|
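        # Re-pad the packed examples to max_seq_len and stack into a batch.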
batch = repad(packed_examples, |
|
max_seq_len=self.max_seq_len, |
|
pad_token_id=self.pad_token_id, |
|
padding_side=self.padding_side) |
|
return batch |
|
|
|
|
|
def extract_trim_batch_idx(batch: Dict[str, torch.Tensor], |
|
idx: int) -> Tuple[int, Dict[str, torch.Tensor]]: |
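    """Extract example ``idx`` from ``batch``, trim padding, add sequence ids."""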
|
example = {k: v[idx] for k, v in batch.items()} |
|
|
|
keep = example['attention_mask'] == 1 |
|
size = int(keep.sum()) |
|
trim_example = {k: v[keep] for k, v in example.items()} |
|
trim_example['sequence_id'] = torch.zeros_like(trim_example['input_ids']) |
|
|
|
return size, trim_example |
|
|
|
|
|
def combine_in_place( |
|
example: Dict[str, torch.Tensor], |
|
add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
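    """Append ``add_on`` to ``example`` in place, offsetting its sequence ids."""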
|
if 'labels' in add_on: |
|
|
|
|
|
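            # Mask the first label of the appended example so the model is
            # never trained to predict tokens across the sequence boundary.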
add_on['labels'][0] = -100 |
|
|
|
for k in example.keys(): |
|
if k == 'sequence_id': |
|
example[k] = torch.cat( |
|
[example[k], add_on[k] + 1 + torch.max(example[k])]) |
|
else: |
|
example[k] = torch.cat([example[k], add_on[k]]) |
|
return example |
|
|
|
|
|
def first_fit_bin_packing( |
|
sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int, |
|
max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]] |
|
) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[ |
|
str, torch.Tensor]]]]: |
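    """Pack ``examples`` into bins of size ``max_bin_size`` using first-fit.

    Returns the ``num_bins`` fullest packed examples, the number of tokens
    in them, the total tokens seen, and any remaining bins as leftovers.
    """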
|
|
|
|
|
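    # Each bin is a (bin_size, packed_example) tuple; start from leftovers.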
bins: List[Tuple[int, Dict[str, torch.Tensor]]] = existing_bins |
|
|
|
starting_total_bin_sizes = sum([bin_size for bin_size, _ in bins]) |
|
|
|
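    # Sort longest-first: first-fit-decreasing packs more tightly.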
    sorted_sizes_and_examples = sorted(zip(sizes, examples),
                                       key=lambda x: x[0],
                                       reverse=True)
|
|
|
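    # If there are too few new examples to even fill num_bins, packing is
    # pointless: give every example its own bin and return immediately.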
required_num_examples = max(0, num_bins - len(bins)) |
|
num_examples = len(sizes) |
|
if num_examples < required_num_examples: |
|
for size, example in sorted_sizes_and_examples: |
|
|
|
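            # Each example gets its own bin.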
bins.append((size, example)) |
|
|
|
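        # Sanity check: every new token must land in exactly one bin.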
total_bin_sizes = sum([bin_size for bin_size, _ in bins]) |
|
total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes |
|
total_example_sizes = sum(sizes) |
|
if total_new_bin_sizes != total_example_sizes: |
|
raise AssertionError( |
|
f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.' |
|
) |
|
|
|
sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True) |
|
bin_sizes, packed_examples = [], [] |
|
for bin_size, packed_example in sorted_bins: |
|
bin_sizes.append(bin_size) |
|
packed_examples.append(packed_example) |
|
|
|
|
|
|
|
|
|
|
|
|
|
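        # Return the fullest num_bins bins, the tokens packed into them,
        # the total tokens seen, and the remaining bins as leftovers.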
return packed_examples[:num_bins], sum( |
|
bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:] |
|
|
|
|
|
|
|
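    # First-fit: place each example, longest first, into the first bin
    # that has room; if none does, it starts a new bin.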
for i, (size, example) in enumerate(sorted_sizes_and_examples): |
|
|
|
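        # If the remaining examples are only just enough to fill the
        # requested number of bins, stop packing and give each its own bin.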
required_num_examples = max(0, num_bins - len(bins)) |
|
n_remaining = num_examples - i |
|
assert n_remaining >= required_num_examples |
|
if n_remaining == required_num_examples: |
|
|
|
bins.append((size, example)) |
|
continue |
|
|
|
|
|
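        # Try to add the example to the first existing bin with room.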
added = False |
|
for bidx in range(len(bins)): |
|
if bins[bidx][0] + size <= max_bin_size: |
|
bin_size, packed_example = bins.pop(bidx) |
|
bin_size = bin_size + size |
|
packed_example = combine_in_place(packed_example, example) |
|
bins.append((bin_size, packed_example)) |
|
added = True |
|
break |
|
|
|
if not added: |
|
bins.append((size, example)) |
|
|
|
total_bin_sizes = sum([bin_size for bin_size, _ in bins]) |
|
total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes |
|
total_example_sizes = sum(sizes) |
|
if total_new_bin_sizes != total_example_sizes: |
|
raise AssertionError( |
|
f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.' |
|
) |
|
|
|
sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True) |
|
bin_sizes, packed_examples = [], [] |
|
for bin_size, packed_example in sorted_bins: |
|
bin_sizes.append(bin_size) |
|
packed_examples.append(packed_example) |
|
|
|
|
|
|
|
|
|
|
|
|
|
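    # Return the fullest num_bins bins, the tokens packed into them, the
    # total tokens seen, and the remaining bins as leftovers for next time.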
return packed_examples[:num_bins], sum( |
|
bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:] |
|
|
|
|
|
def repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int, |
|
pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]: |
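    """Pad each packed example out to ``max_seq_len`` and stack into a batch."""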
|
|
|
def pad_tensor(tensor: torch.Tensor, pad_value: int): |
|
if len(tensor) == max_seq_len: |
|
return tensor |
|
t = torch.full((max_seq_len,), |
|
pad_value, |
|
dtype=tensor.dtype, |
|
device=tensor.device) |
|
if padding_side == 'left': |
|
t[-len(tensor):] = tensor |
|
elif padding_side == 'right': |
|
t[:len(tensor)] = tensor |
|
else: |
|
raise ValueError(f'Unknown {padding_side=}') |
|
return t |
|
|
|
pad_vals = { |
|
'input_ids': pad_token_id, |
|
'labels': -100, |
|
'attention_mask': 0, |
|
'bidirectional_mask': 0, |
|
'sequence_id': -1, |
|
} |
|
keys = packed_examples[0].keys() |
|
batch = {} |
|
for key in keys: |
|
batch[key] = torch.stack([ |
|
pad_tensor(example[key], pad_vals[key]) |
|
for example in packed_examples |
|
]) |
|
return batch |
|
|
|
|
|
if __name__ == '__main__': |
|
from argparse import ArgumentParser, Namespace |
|
|
|
from omegaconf import OmegaConf as om |
|
|
|
from llmfoundry import (build_finetuning_dataloader, |
|
build_text_denoising_dataloader) |
|
from llmfoundry.data import build_text_dataloader |
|
from llmfoundry.utils import build_tokenizer |
|
|
|
def parse_args() -> Namespace: |
|
"""Parse commandline arguments.""" |
|
parser = ArgumentParser( |
|
description= |
|
'Profile packing_ratio choices for a particular workload.') |
|
parser.add_argument( |
|
'--yaml-path', |
|
type=str, |
|
required=True, |
|
help='Path to the YAML that defines the workload to profile.') |
|
parser.add_argument('--num-devices', |
|
type=int, |
|
default=None, |
|
help='How many devices your run will use.') |
|
parser.add_argument('--min', |
|
type=float, |
|
required=True, |
|
help='Smallest packing_ratio to test. Must be >=1.') |
|
parser.add_argument( |
|
'--max', |
|
type=float, |
|
required=True, |
|
help='Largest packing_ratio to test. Must be larger than `min`.') |
|
parser.add_argument( |
|
'--num-packing-ratios', |
|
type=int, |
|
default=10, |
|
help= |
|
            'Number of packing_ratio values (spaced between `min` and `max`) to try.'
|
) |
|
|
|
args = parser.parse_args() |
|
|
|
if not os.path.isfile(args.yaml_path): |
|
raise FileNotFoundError( |
|
'`yaml_path` does not correspond to any existing file.') |
|
        if args.num_devices is None or args.num_devices < 1:
|
raise ValueError('`num_devices` must be a positive integer.') |
|
if args.min < 1.0: |
|
raise ValueError('`min` must be >=1.0.') |
|
if args.max < args.min: |
|
raise ValueError('`max` cannot be less than `min`.') |
|
if args.num_packing_ratios < 1: |
|
raise ValueError('`num_packing_ratios` must be a positive integer.') |
|
return args |
|
|
|
def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, |
|
device_batch_size: int): |
|
if cfg.name == 'text': |
|
return build_text_dataloader(cfg, tokenizer, device_batch_size) |
|
elif cfg.name == 'text_denoising': |
|
return build_text_denoising_dataloader(cfg, tokenizer, |
|
device_batch_size) |
|
elif cfg.name == 'finetuning': |
|
return build_finetuning_dataloader(cfg, tokenizer, |
|
device_batch_size) |
|
else: |
|
raise ValueError( |
|
f'Not sure how to build dataloader with config: {cfg}') |
|
|
|
args = parse_args() |
|
|
|
with open(args.yaml_path) as f: |
|
cfg = om.load(f) |
|
if 'parameters' in cfg: |
|
cfg = om.to_container(cfg.parameters) |
|
cfg = om.create(cfg) |
|
device_batch_size = cfg.global_train_batch_size // args.num_devices |
|
|
|
|
|
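    # Enumerate the packing ratios to profile, skipping any that produce a
    # duplicate raw (pre-packing) batch size.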
packing_ratios, raw_batch_sizes = [], [] |
|
for packing_ratio in np.linspace(args.min, |
|
args.max, |
|
args.num_packing_ratios, |
|
endpoint=True): |
|
packing_ratio = np.round(10 * packing_ratio) / 10 |
|
raw_batch_size = int(packing_ratio * device_batch_size) |
|
if raw_batch_size not in raw_batch_sizes: |
|
packing_ratios.append(packing_ratio) |
|
raw_batch_sizes.append(raw_batch_size) |
|
|
|
|
|
if 'train_loader' not in cfg: |
|
raise ValueError('config must define train_loader') |
|
dataloader_cfg = cfg.train_loader |
|
|
|
max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep', |
|
None) |
|
|
|
|
|
if 'tokenizer' not in cfg: |
|
raise ValueError('config must define tokenizer') |
|
|
|
resolved_tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True) |
|
    if not isinstance(resolved_tokenizer_cfg, dict):
|
raise ValueError( |
|
'tokenizer config needs to be resolved by omegaconf into a Dict.') |
|
tokenizer_cfg: Dict[Any, Any] = resolved_tokenizer_cfg |
|
|
|
tokenizer_name = tokenizer_cfg['name'] |
|
tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) |
|
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) |
|
|
|
|
|
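    # Turn packing off in the dataloader itself; the profiler drives the
    # packing collator by hand below.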
dataloader_cfg.dataset.packing_ratio = None |
|
dataloader_cfg.dataset.max_leftovers_to_keep = None |
|
train_dataloader = build_dataloader(dataloader_cfg, tokenizer, |
|
max(raw_batch_sizes) * 100) |
|
|
|
|
|
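    # Fetch one oversized batch of raw examples and re-split it for each
    # candidate raw batch size.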
big_batch = next(iter(train_dataloader)) |
|
|
|
def split_big_batch(raw_batch_size: int) -> List: |
|
input_ids = big_batch['input_ids'].split(raw_batch_size) |
|
batches = [{'input_ids': x} for x in input_ids] |
|
|
|
for key in big_batch.keys(): |
|
if key == 'input_ids': |
|
continue |
|
for idx, split in enumerate(big_batch[key].split(raw_batch_size)): |
|
batches[idx].update({key: split}) |
|
return batches |
|
|
|
def profile_packing(raw_batch_size: int) -> Tuple[float, float]: |
|
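        # The splits are already-collated tensor dicts, so an identity
        # collator is all the packer needs here.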
packer = BinPackWrapper( |
|
collator=lambda x: x, |
|
target_batch_size=device_batch_size, |
|
max_seq_len=dataloader_cfg.dataset.max_seq_len, |
|
pad_token_id=0, |
|
padding_side='left', |
|
max_leftover_bins_to_keep=max_leftovers_to_keep) |
|
|
|
|
|
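        # Feed every full-size split through the packer to accumulate stats.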
for batch in split_big_batch(raw_batch_size): |
|
if batch['input_ids'].shape[0] < device_batch_size: |
|
continue |
|
_ = packer(batch) |
|
|
|
|
|
padding_percent = 100 * (1 - packer.efficiency) |
|
waste_percent = 100 * packer.waste |
|
return padding_percent, waste_percent |
|
|
|
    header = ' packing_ratio | % PADDING | % WASTE'
|
fstr = ' {:5.1f} | {:5.2f}% | {:6.2f}%' |
|
|
|
    print('\n\n\n' + header)
    print('-' * len(header))
|
for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes): |
|
padding, waste = profile_packing(raw_batch_size) |
|
print(fstr.format(packing_ratio, padding, waste)) |
|
|