import pytest

from bpetokenizer import BPETokenizer, Tokenizer


@pytest.fixture
def tokenizer():
    return Tokenizer()


@pytest.fixture
def bpe_tokenizer():
    return BPETokenizer()


def test_train(tokenizer):
    """Test the training of the tokenizer."""
    text = "aaabdaaabac"
    tokenizer.train(text, 259, verbose=False)  # 259 = 256 byte tokens + 3 merges
    assert len(tokenizer.vocab) == 259
    assert len(tokenizer.merges) == 3
    assert tokenizer.decode(tokenizer.encode(text)) == text


def test_encode(tokenizer):
    """Test the encoding of the tokenizer."""
    text = "aaabdaaabac"
    tokenizer.train(text, 259, verbose=False)
    # 258 is the merged token for "aaab"; 100, 97, 99 are the raw bytes d, a, c.
    assert tokenizer.encode(text) == [258, 100, 258, 97, 99]


def test_decode(tokenizer):
    """Test the decoding of the tokenizer."""
    text = "aaabdaaabac"
    tokenizer.train(text, 259, verbose=False)
    assert tokenizer.decode([258, 100, 258, 97, 99]) == text


def test_train_bpe(bpe_tokenizer):
    """Test the training of the BPE tokenizer."""
    text = "aaabdaaabac"
    bpe_tokenizer.train(text, 256 + 3, verbose=False)
    assert len(bpe_tokenizer.vocab) == 259
    assert len(bpe_tokenizer.merges) == 3
    assert bpe_tokenizer.decode(bpe_tokenizer.encode(text)) == text


def test_train_bpe_w_special_tokens():
    """Test the BPE tokenizer with special tokens and a custom split pattern."""
    special_tokens = {
        "<|endoftext|>": 1001,
        "<|startoftext|>": 1002,
        "[SPECIAL1]": 1003,
        "[SPECIAL2]": 1004,
    }
    PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
    tokenizer = BPETokenizer(special_tokens=special_tokens, pattern=PATTERN)
    texts = "<|startoftext|> Hello, World! This is a sample text with the special tokens [SPECIAL1] and [SPECIAL2] to test the tokenizer.<|endoftext|>"
    tokenizer.train(texts, vocab_size=310, verbose=False)
    assert len(tokenizer.vocab) == 310
    assert len(tokenizer.merges) == 310 - 256  # special tokens are not counted in vocab
    assert tokenizer.decode(tokenizer.encode(texts)) == texts
    assert tokenizer.inverse_special_tokens == {v: k for k, v in special_tokens.items()}
    assert tokenizer.special_tokens == special_tokens
    assert tokenizer.pattern == PATTERN
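

# Supplementary sketch, not part of the original suite: a round-trip check that
# mixes special and ordinary tokens. It reuses only the API exercised above
# (BPETokenizer, train, encode, decode) and assumes the default split pattern
# applies when none is given; the test name, sample strings, and vocab size are
# illustrative choices.
def test_bpe_special_tokens_roundtrip():
    special_tokens = {"<|endoftext|>": 1001, "<|startoftext|>": 1002}
    tokenizer = BPETokenizer(special_tokens=special_tokens)
    corpus = "<|startoftext|>repetition repetition repetition<|endoftext|>"
    tokenizer.train(corpus, vocab_size=260, verbose=False)  # 260 = 256 + 4 merges
    # Trained input, a plain substring, and a lone special token should all
    # survive an encode/decode round trip unchanged.
    for sample in [corpus, "repetition", "<|endoftext|>"]:
        assert tokenizer.decode(tokenizer.encode(sample)) == sample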