"""Module for testing streaming dataset sequence packing""" | |
import pytest | |
from datasets import concatenate_datasets, load_dataset | |
from torch.utils.data import DataLoader, RandomSampler | |
from transformers import AutoTokenizer | |
from axolotl.datasets import TokenizedPromptDataset | |
from axolotl.prompt_strategies.completion import load | |
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq | |
from axolotl.utils.dict import DictDefault | |
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths | |


@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    tokenizer.pad_token = "</s>"
    return tokenizer


@pytest.fixture(name="max_seq_length")
def fixture_max_seq_length():
    return 4096


class TestBatchedSamplerPacking:
    """
    Test class for packing streaming dataset sequences
    """

    # representative parameter grid: single- and multi-row batches, with and
    # without dataloader worker processes
    @pytest.mark.parametrize("batch_size", [1, 2])
    @pytest.mark.parametrize("num_workers", [0, 2])
    def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
        import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
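        # The monkeypatch above swaps in a batch-aware dataset fetcher so the
        # DataLoader can consume the nested index lists the multipack sampler
        # yields.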

        dataset = load_dataset(
            "Trelis/tiny-shakespeare",
            split="train",
        )

        cfg = DictDefault(
            {
                "train_on_inputs": True,
                "sequence_len": max_seq_length,
            }
        )
        ds_cfg = DictDefault(
            {
                "field": "Text",
            }
        )
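        # Build a completion-style prompt strategy that tokenizes the raw
        # "Text" column named in the dataset config above.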
        completion_strategy = load(tokenizer, cfg, ds_cfg)
        dataset_wrapper = TokenizedPromptDataset(
            completion_strategy,
            dataset,
        )
        train_dataset = concatenate_datasets([dataset_wrapper])
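
        # MultipackBatchSampler groups several short tokenized samples into
        # each batch slot so that no packed group exceeds batch_max_len tokens.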
        batch_sampler = MultipackBatchSampler(
            sampler=RandomSampler(train_dataset),
            batch_size=batch_size,
            drop_last=True,
            batch_max_len=max_seq_length,
            lengths=get_dataset_lengths(train_dataset),
        )
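
        # The V2 collator flattens each packed group into a single row and pads
        # to a multiple of max_seq_length; its attention mask numbers the
        # packed sequences (1, 2, ...) rather than using a plain 0/1 mask.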
        loader = DataLoader(
            train_dataset,
            batch_sampler=batch_sampler,
            collate_fn=V2BatchSamplerDataCollatorForSeq2Seq(  # pylint: disable=unexpected-keyword-arg
                tokenizer=tokenizer,
                padding=True,
                pad_to_multiple_of=max_seq_length,
                return_tensors="pt",
            ),
            num_workers=num_workers,
        )

        inputs = next(iter(loader))
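
        # Every row is padded out to exactly max_seq_length. The first position
        # is expected to be padding (pad_token "</s>" is id 2 for this
        # tokenizer, label -100, mask 0), while the last position should belong
        # to a later packed sequence, hence an attention-mask index above 1.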
assert inputs["input_ids"].shape == (batch_size, max_seq_length) | |
assert inputs["labels"].shape == (batch_size, max_seq_length) | |
assert inputs["attention_mask"].shape == (batch_size, max_seq_length) | |
assert inputs["input_ids"].tolist()[0][0] == 2 | |
assert inputs["labels"].tolist()[0][0] == -100 | |
assert inputs["attention_mask"].tolist()[0][0] == 0 | |
assert inputs["attention_mask"].tolist()[0][-1] > 1 | |
        if batch_size >= 2:
            assert inputs["input_ids"].tolist()[1][0] == 2
            assert inputs["labels"].tolist()[1][0] == -100
            assert inputs["attention_mask"].tolist()[1][0] == 0
            assert inputs["attention_mask"].tolist()[1][-1] > 1