from dataclasses import dataclass
import os
import pathlib

import pytest

from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.transformers_utils.tokenizer import get_tokenizer

# Resolve the ChatML example template shipped in the repository's
# examples/ directory, two levels up from this test file.
chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
    __file__))).parent.parent / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()

# Each tuple: (model, chat template path, add_generation_prompt,
# expected rendered prompt)
MODEL_TEMPLATE_GENERATION_OUTPUT = [
    ("facebook/opt-125m", None, True,
     "Hello</s>Hi there!</s>What is the capital of</s>"),
    ("facebook/opt-125m", None, False,
     "Hello</s>Hi there!</s>What is the capital of</s>"),
    ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
"""),
    ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of""")
]

TEST_MESSAGES = [
    {
        'role': 'user',
        'content': 'Hello'
    },
    {
        'role': 'assistant',
        'content': 'Hi there!'
    },
    {
        'role': 'user',
        'content': 'What is the capital of'
    },
]


# Minimal stand-ins for the real tokenizer and serving objects:
# _load_chat_template only reads and writes tokenizer.chat_template,
# so bare dataclasses are sufficient here.
@dataclass
class MockTokenizer:
    chat_template = None


@dataclass
class MockServingChat:
    tokenizer: MockTokenizer


def test_load_chat_template():
    # Testing chatml template
    tokenizer = MockTokenizer()
    mock_serving_chat = MockServingChat(tokenizer)
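    # _load_chat_template reads the template file and stores its text on
    # tokenizer.chat_template.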
    OpenAIServingChat._load_chat_template(mock_serving_chat,
                                          chat_template=chatml_jinja_path)

    template_content = tokenizer.chat_template

    # Test assertions
    assert template_content is not None
    # Hard-coded contents of examples/template_chatml.jinja
    assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""


def test_no_load_chat_template():
    # Testing a nonexistent template path
    template = "../../examples/does_not_exist"
    tokenizer = MockTokenizer()

    mock_serving_chat = MockServingChat(tokenizer)
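    # The path does not exist, so _load_chat_template is expected to fall
    # back to treating the string itself as the template.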
    OpenAIServingChat._load_chat_template(mock_serving_chat,
                                          chat_template=template)
    template_content = tokenizer.chat_template

    # Test assertions
    assert template_content is not None
    # The missing path was stored verbatim as the template string
    assert template_content == """../../examples/does_not_exist"""


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model,template,add_generation_prompt,expected_output",
    MODEL_TEMPLATE_GENERATION_OUTPUT)
async def test_get_gen_prompt(model, template, add_generation_prompt,
                              expected_output):
    # Initialize the tokenizer
    tokenizer = get_tokenizer(tokenizer_name=model)
    mock_serving_chat = MockServingChat(tokenizer)
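    # A template of None keeps the tokenizer's built-in chat template;
    # a path loads that file instead.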
    OpenAIServingChat._load_chat_template(mock_serving_chat,
                                          chat_template=template)

    # Create a mock request object using keyword arguments
    mock_request = ChatCompletionRequest(
        model=model,
        messages=TEST_MESSAGES,
        add_generation_prompt=add_generation_prompt)

    # Render the conversation through the loaded template; tokenize=False
    # returns the prompt string rather than token IDs
    result = tokenizer.apply_chat_template(
        conversation=mock_request.messages,
        tokenize=False,
        add_generation_prompt=mock_request.add_generation_prompt)

    # Test assertion
    assert result == expected_output, (
        f"The generated prompt does not match the expected output for "
        f"model {model} and template {template}")