Last commit not found
"""Module for testing prompt tokenizers.""" | |
import json | |
import logging | |
import unittest | |
from pathlib import Path | |
from transformers import AutoTokenizer | |
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy | |
from axolotl.prompters import ShareGPTPrompter | |
# Configure the root logger at INFO level when this module is imported
# (a no-op if the root logger already has handlers attached).
logging.basicConfig(level=logging.INFO)
class TestPromptTokenizationStrategies(unittest.TestCase): | |
""" | |
Test class for prompt tokenization strategies. | |
""" | |
def setUp(self) -> None: | |
# pylint: disable=duplicate-code | |
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") | |
self.tokenizer.add_special_tokens( | |
{ | |
"bos_token": "<s>", | |
"eos_token": "</s>", | |
"unk_token": "<unk>", | |
} | |
) | |
def test_sharegpt_integration(self): | |
print(Path(__file__).parent) | |
with open( | |
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" | |
) as fin: | |
data = fin.read() | |
conversation = json.loads(data) | |
with open( | |
Path(__file__).parent / "fixtures/conversation.tokenized.json", | |
encoding="utf-8", | |
) as fin: | |
data = fin.read() | |
tokenized_conversation = json.loads(data) | |
prompter = ShareGPTPrompter("chat") | |
strat = ShareGPTPromptTokenizingStrategy( | |
prompter, | |
self.tokenizer, | |
False, | |
2048, | |
) | |
example = strat.tokenize_prompt(conversation) | |
for fields in ["input_ids", "attention_mask", "labels"]: | |
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields])) | |
self.assertEqual(example[fields], tokenized_conversation[fields]) | |
if __name__ == "__main__":
    # Only taken when this file is executed directly as a script; test runners
    # that import the module skip it and collect the TestCase on their own.
    unittest.main()