File size: 2,649 Bytes
acc4ffe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Test few shot prompt template."""
import pytest

from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate

# Shared per-example template: each example dict is rendered as
# "<question>: <answer>".
EXAMPLE_PROMPT = PromptTemplate(
    template="{question}: {answer}",
    input_variables=["question", "answer"],
)


def test_suffix_only() -> None:
    """With no prefix and no examples, formatting renders just the suffix."""
    few_shot = FewShotPromptTemplate(
        examples=[],
        example_prompt=EXAMPLE_PROMPT,
        suffix="This is a {foo} test.",
        input_variables=["foo"],
    )
    assert few_shot.format(foo="bar") == "This is a bar test."


def test_prompt_missing_input_variables() -> None:
    """Test error is raised when input variables are not provided.

    A template variable (``{foo}``) that is not listed in ``input_variables``
    must cause construction to raise ``ValueError``. This holds whether the
    variable appears in the suffix or in the prefix.
    """
    template = "This is a {foo} test."

    # Test when missing in suffix
    with pytest.raises(ValueError):
        FewShotPromptTemplate(
            input_variables=[],
            suffix=template,
            examples=[],
            example_prompt=EXAMPLE_PROMPT,
        )

    # Test when missing in prefix. The suffix "foo" contains no template
    # variables, so the only undeclared variable comes from the prefix.
    with pytest.raises(ValueError):
        FewShotPromptTemplate(
            input_variables=[],
            suffix="foo",
            examples=[],
            prefix=template,
            example_prompt=EXAMPLE_PROMPT,
        )


def test_prompt_extra_input_variables() -> None:
    """Declaring an input variable that no template uses must be rejected."""
    # "bar" is declared below but never referenced by the suffix.
    with pytest.raises(ValueError):
        FewShotPromptTemplate(
            example_prompt=EXAMPLE_PROMPT,
            examples=[],
            input_variables=["foo", "bar"],
            suffix="This is a {foo} test.",
        )


def test_few_shot_functionality() -> None:
    """Examples render between prefix and suffix, joined by the separator."""
    prompt = FewShotPromptTemplate(
        prefix="This is a test about {content}.",
        suffix="Now you try to talk about {new_content}.",
        input_variables=["content", "new_content"],
        examples=[
            {"question": "foo", "answer": "bar"},
            {"question": "baz", "answer": "foo"},
        ],
        example_prompt=EXAMPLE_PROMPT,
        example_separator="\n",
    )
    rendered = prompt.format(content="animals", new_content="party")
    # Expected: prefix, each example, then suffix, separated by "\n".
    expected_lines = [
        "This is a test about animals.",
        "foo: bar",
        "baz: foo",
        "Now you try to talk about party.",
    ]
    assert rendered == "\n".join(expected_lines)