File size: 3,590 Bytes
41d24d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import re
from pathlib import Path

import PIL
import PIL.Image
import pytest

from llmlib.base_llm import LLM, Message


def assert_model_knows_capital_of_france(model: LLM) -> None:
    """Assert that ``model`` answers a plain-text question about France's capital.

    Args:
        model: The LLM under test. Its ``complete_msgs2`` is called with a
            single user message.

    Raises:
        AssertionError: If the reply does not mention "paris"
            (case-insensitive). The reply is attached to the failure message
            so a failing run shows what the model actually said.
    """
    response: str = model.complete_msgs2(
        msgs=[Message(role="user", msg="What is the capital of France?")]
    )
    assert "paris" in response.lower(), response


def assert_model_can_answer_batch_of_text_prompts(model: LLM) -> None:
    """Assert that ``model.complete_batch`` answers several text-only prompts.

    Each prompt is sent as its own single-message conversation. The answers
    must come back in the same order as the prompts.

    Args:
        model: The LLM under test.

    Raises:
        AssertionError: If the number of responses differs from the number of
            prompts, or if a response does not contain its expected keyword
            (case-insensitive). Failure messages include the offending output.
    """
    # Keep each prompt next to the keyword expected in its answer, so adding
    # a case cannot desynchronize the prompts from the checks.
    cases = [
        ("What is the capital of France?", "paris"),
        ("What continent is south of Europe?", "africa"),
        ("What are the two tallest mountains in the world?", "everest"),
    ]
    batch = [[Message.from_prompt(prompt)] for prompt, _ in cases]
    responses = model.complete_batch(batch=batch)
    assert len(responses) == len(cases), responses
    for response, (_, expected) in zip(responses, cases):
        assert expected in response.lower(), response


def assert_model_can_answer_batch_of_img_prompts(model: LLM) -> None:
    """Assert that ``model.complete_batch`` answers several image prompts.

    Each image message is sent as its own single-message conversation. The
    answers must come back in the same order as the messages.

    Args:
        model: The LLM under test.

    Raises:
        AssertionError: If the number of responses differs from the number of
            messages, or if a response does not contain its expected keyword
            (case-insensitive). Failure messages include the offending output.
    """
    # Pair each image message with the keyword expected in its answer.
    cases = [
        (pyramid_message(), "pyramid"),
        (forest_message(), "forest"),
        (fish_message(), "fish"),
    ]
    batch = [[msg] for msg, _ in cases]
    responses = model.complete_batch(batch=batch)
    assert len(responses) == len(cases), responses
    for response, (_, expected) in zip(responses, cases):
        assert expected in response.lower(), response


def assert_model_rejects_unsupported_batches(model: LLM) -> None:
    """Assert that ``model.complete_batch`` refuses a batch mixing text-only and image entries.

    Args:
        model: The LLM under test.

    The current test fails (via ``pytest.raises``) unless ``complete_batch``
    raises a ``ValueError`` whose message contains the expected text.
    """
    mixed_textonly_and_img_batch = [
        [Message.from_prompt("What is the capital of France?")],
        [pyramid_message()],
    ]
    err_msg = "Batch must contain an image in every entry or none at all."
    # ``match`` is a regular expression searched against str(exception).
    # Escape it so punctuation such as the trailing "." is matched literally.
    with pytest.raises(ValueError, match=re.escape(err_msg)):
        model.complete_batch(mixed_textonly_and_img_batch)


def assert_model_recognizes_pyramid_in_image(model: LLM):
    """Assert that ``model`` mentions a pyramid when shown the pyramid test image.

    Args:
        model: The LLM under test.

    Raises:
        AssertionError: If the answer does not contain "pyramid"
            (case-insensitive). The answer is attached to the failure message.
    """
    msg = pyramid_message()
    answer: str = model.complete_msgs2(msgs=[msg])
    assert "pyramid" in answer.lower(), answer


def assert_model_recognizes_afd_in_video(model: LLM):
    """Assert that ``model`` names the AfD party when describing the test video.

    The model's full answer is attached to the assertion, so a failure shows
    what it actually said.
    """
    answer: str = model.video_prompt(
        file_for_test("video.mp4"), "Describe the video in english"
    )
    expected_phrase = "alternative für deutschland"
    assert expected_phrase in answer.lower(), answer


def get_mona_lisa_completion(model: LLM) -> str:
    """Ask ``model`` what the Mona Lisa test image shows and return its reply."""
    return model.complete_msgs2(msgs=[mona_lisa_message()])


def mona_lisa_message() -> Message:
    """Build a user message asking what is shown in the Mona Lisa test image."""
    mona_lisa_img = mona_lisa_filename_and_img()[1]
    # NOTE(review): img_name is left empty even though the file name is
    # available from mona_lisa_filename_and_img() — confirm this is intended.
    return Message(
        role="user",
        msg="What is in the image?",
        img=mona_lisa_img,
        img_name="",
    )


def pyramid_message() -> Message:
    """Build a user message asking what is shown in the pyramid test image."""
    pyramid_img = get_test_img("pyramid.jpg")
    return Message(
        role="user",
        msg="What is in the image?",
        img=pyramid_img,
        img_name="",
    )


def forest_message() -> Message:
    """Build a user message asking the model to describe the forest test image."""
    forest_img = get_test_img("forest.jpg")
    prompt = "Describe what you see in the picture."
    return Message(role="user", msg=prompt, img=forest_img, img_name="")


def fish_message() -> Message:
    """Build a user message asking which animal the fish test image shows and where it lives."""
    fish_img = get_test_img("fish.jpg")
    prompt = "What animal is depicted and where does it live?"
    return Message(role="user", msg=prompt, img=fish_img, img_name="")


def mona_lisa_filename_and_img() -> tuple[str, PIL.Image.Image]:
    """Return the Mona Lisa test image together with its file name.

    Returns:
        A ``(file_name, image)`` pair. The image is loaded with ``get_test_img``.
    """
    mona_lisa_file = "mona-lisa.png"
    return mona_lisa_file, get_test_img(mona_lisa_file)


def get_test_img(name: str) -> PIL.Image.Image:
    """Load the image ``name`` from the ``test-files`` directory.

    The pixel data is read eagerly and the underlying file is closed before
    returning. Callers therefore get a self-contained image, and no file
    descriptor stays open for the rest of the test session.

    Args:
        name: File name inside the ``test-files`` directory.

    Returns:
        The fully loaded PIL image.
    """
    path = file_for_test(name)
    # ``PIL.Image.open`` is lazy and keeps the file open until the pixel data
    # is read. Calling ``load()`` inside the ``with`` block pulls the data into
    # memory, so the image stays usable after the context manager closes the file.
    with PIL.Image.open(path) as img:
        img.load()
    return img


def file_for_test(name: str) -> Path:
    """Return the path of ``name`` inside the ``test-files`` directory.

    ``test-files`` is resolved relative to this module, as a sibling of the
    directory that contains this file. The returned path is not checked for
    existence.
    """
    base_dir = Path(__file__).parent.parent
    return base_dir.joinpath("test-files", name)


def is_ci() -> bool:
    """Return whether the tests appear to be running in a CI environment.

    Reads the ``CI`` environment variable. The result is ``False`` when the
    variable is unset or holds a false-like value: ``""``, ``"0"``,
    ``"false"``, ``"no"`` or ``"off"``, compared case-insensitively after
    stripping surrounding whitespace. Any other value (e.g. ``"true"`` or
    ``"1"``) yields ``True``.
    """
    is_ci_str: str = os.environ.get("CI", "false").strip().lower()
    # Previously only the exact string "false" meant "not CI", so CI="0" or an
    # empty CI variable were misreported as running in CI.
    return is_ci_str not in {"", "0", "false", "no", "off"}