import os

import pytest
from jinja2.exceptions import TemplateError

from llm_studio.app_utils.hugging_face_utils import get_chat_template
from llm_studio.src.datasets.text_utils import get_tokenizer
from llm_studio.src.utils.config_utils import load_config_yaml


def _load_test_cfg():
    """Load the shared test config from ../test_data/cfg.yaml (relative to this file)."""
    test_directory = os.path.abspath(os.path.dirname(__file__))
    cfg_path = os.path.join(test_directory, "../test_data/cfg.yaml")
    return load_config_yaml(cfg_path)


def test_chat_template_no_system_prompt():
    """Template without a system column renders alternating prompt/answer turns
    and rejects any "system" role message, wherever it appears."""
    cfg = _load_test_cfg()
    # "None" is the sentinel the config uses for "no system column".
    cfg.dataset.system_column = "None"
    cfg.dataset.text_prompt_start = "<|prompt|>"
    cfg.dataset.text_answer_separator = "<|answer|>"
    cfg.dataset.add_eos_token_to_prompt = True
    cfg.dataset.add_eos_token_to_answer = True
    tokenizer = get_tokenizer(cfg)
    # Blank out EOS so the expected strings below stay free of EOS markers.
    tokenizer.eos_token = ""
    tokenizer.chat_template = get_chat_template(cfg)

    chat = [
        {"role": "user", "content": "[user prompt]"},
        {"role": "assistant", "content": "[assistant response]"},
        {"role": "user", "content": "[user prompt2]"},
    ]
    rendered = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    expected = "<|prompt|>[user prompt]<|answer|>[assistant response]<|prompt|>[user prompt2]<|answer|>"  # noqa
    assert rendered == expected

    # A "system" message must raise, regardless of its position in the chat.
    for chat in [
        [
            {"role": "system", "content": "[system prompt]"},
            {"role": "user", "content": "[user prompt]"},
            {"role": "assistant", "content": "[assistant response]"},
            {"role": "user", "content": "[user prompt2]"},
        ],
        [
            {"role": "user", "content": "[user prompt]"},
            {"role": "assistant", "content": "[assistant response]"},
            {"role": "user", "content": "[user prompt2]"},
            {"role": "system", "content": "[system prompt]"},
        ],
    ]:
        with pytest.raises(TemplateError) as e:
            tokenizer.apply_chat_template(
                chat,
                tokenize=False,
                add_generation_prompt=True,
            )
        assert "System role not supported" in str(e.value)

    # Non-alternating user/assistant sequences must raise.
    for chat in [
        [
            {"role": "user", "content": "[user prompt]"},
            {"role": "user", "content": "[user prompt2]"},
            {"role": "assistant", "content": "[assistant response]"},
        ],
        [
            {"role": "assistant", "content": "[assistant response]"},
            {"role": "user", "content": "[user prompt]"},
            {"role": "assistant", "content": "[assistant response]"},
        ],
        [
            {"role": "assistant", "content": "[assistant response]"},
            {"role": "assistant", "content": "[user prompt]"},
            {"role": "user", "content": "[assistant response]"},
        ],
    ]:
        with pytest.raises(TemplateError) as e:
            tokenizer.apply_chat_template(
                chat,
                tokenize=False,
                add_generation_prompt=True,
            )
        assert (
            "Conversation roles must alternate user/assistant/user/assistant/..."
            in str(e.value)
        )


def test_chat_template_with_system_prompt():
    """Template with a system column renders an optional leading system message
    and enforces system(optional)/user/assistant alternation."""
    cfg = _load_test_cfg()
    cfg.dataset.text_prompt_start = "<|prompt|>"
    cfg.dataset.text_answer_separator = "<|answer|>"
    cfg.dataset.text_system_start = "<|system|>"
    cfg.dataset.add_eos_token_to_prompt = True
    cfg.dataset.add_eos_token_to_answer = True
    cfg.dataset.add_eos_token_to_system = True
    cfg.dataset.system_column = "system"
    tokenizer = get_tokenizer(cfg)
    # Blank out EOS so the expected strings below stay free of EOS markers.
    tokenizer.eos_token = ""
    tokenizer.chat_template = get_chat_template(cfg)

    chat = [
        {"role": "system", "content": "[system prompt]"},
        {"role": "user", "content": "[user prompt]"},
        {"role": "assistant", "content": "[assistant response]"},
        {"role": "user", "content": "[user prompt2]"},
    ]
    rendered = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    expected = "<|system|>[system prompt]<|prompt|>[user prompt]<|answer|>[assistant response]<|prompt|>[user prompt2]<|answer|>"  # noqa
    assert rendered == expected

    # works w/o system prompt as well
    chat = [
        {"role": "user", "content": "[user prompt]"},
        {"role": "assistant", "content": "[assistant response]"},
        {"role": "user", "content": "[user prompt2]"},
    ]
    rendered = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    expected = "<|prompt|>[user prompt]<|answer|>[assistant response]<|prompt|>[user prompt2]<|answer|>"  # noqa
    assert rendered == expected

    # A system message anywhere but first, or a repeated role, must raise.
    for chat in [
        [
            {"role": "user", "content": "[user prompt]"},
            {"role": "system", "content": "[system prompt]"},
            {"role": "user", "content": "[user prompt2]"},
            {"role": "assistant", "content": "[assistant response]"},
        ],
        [
            {"role": "system", "content": "[system prompt]"},
            {"role": "user", "content": "[user prompt]"},
            {"role": "user", "content": "[user prompt2]"},
            {"role": "assistant", "content": "[assistant response]"},
        ],
        [
            {"role": "user", "content": "[user prompt]"},
            {"role": "user", "content": "[user prompt2]"},
            {"role": "assistant", "content": "[assistant response]"},
        ],
        [
            {"role": "user", "content": "[user prompt]"},
            {"role": "assistant", "content": "[assistant response]"},
            {"role": "assistant", "content": "[assistant response2]"},
        ],
    ]:
        with pytest.raises(TemplateError) as e:
            tokenizer.apply_chat_template(
                chat,
                tokenize=False,
                add_generation_prompt=True,
            )
        assert (
            "Conversation roles must alternate system(optional)/user/assistant/user/assistant/..."  # noqa
            in str(e.value)
        )


def test_chat_template_no_eos_token():
    """With all add_eos_token_* flags off, the rendered text contains only the
    configured start/separator markers and the message contents."""
    cfg = _load_test_cfg()
    cfg.dataset.text_prompt_start = "<|prompt|>"
    cfg.dataset.text_answer_separator = "<|answer|>"
    cfg.dataset.text_system_start = "<|system|>"
    cfg.dataset.add_eos_token_to_system = False
    cfg.dataset.add_eos_token_to_prompt = False
    cfg.dataset.add_eos_token_to_answer = False
    cfg.dataset.system_column = "system"
    tokenizer = get_tokenizer(cfg)
    tokenizer.chat_template = get_chat_template(cfg)

    chat = [
        {"role": "system", "content": "[system prompt]"},
        {"role": "user", "content": "[user prompt]"},
        {"role": "assistant", "content": "[assistant response]"},
        {"role": "user", "content": "[user prompt2]"},
    ]
    rendered = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    expected = "<|system|>[system prompt]<|prompt|>[user prompt]<|answer|>[assistant response]<|prompt|>[user prompt2]<|answer|>"  # noqa
    assert rendered == expected


def test_chat_template_no_special_token():
    """With empty markers and no EOS tokens, rendering is the plain
    concatenation of the message contents."""
    cfg = _load_test_cfg()
    cfg.dataset.system_column = "system"
    cfg.dataset.text_system_start = ""
    cfg.dataset.text_prompt_start = ""
    cfg.dataset.text_answer_separator = ""
    cfg.dataset.add_eos_token_to_system = False
    cfg.dataset.add_eos_token_to_prompt = False
    cfg.dataset.add_eos_token_to_answer = False
    tokenizer = get_tokenizer(cfg)
    tokenizer.chat_template = get_chat_template(cfg)

    chat = [
        {"role": "system", "content": "[system prompt]"},
        {"role": "user", "content": "[user prompt]"},
        {"role": "assistant", "content": "[assistant response]"},
        {"role": "user", "content": "[user prompt2]"},
    ]
    rendered = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    expected = "[system prompt][user prompt][assistant response][user prompt2]"
    assert rendered == expected