File size: 2,246 Bytes
41d24d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from llmlib.base_llm import Message
from PIL import Image

from llmlib.phi3.phi3 import GenConf, Phi3Vision, extract_imgs_and_dicts, pad_left
import pytest
import torch

from .helpers import (
    assert_model_can_answer_batch_of_img_prompts,
    assert_model_can_answer_batch_of_text_prompts,
    assert_model_knows_capital_of_france,
    assert_model_rejects_unsupported_batches,
    get_mona_lisa_completion,
    is_ci,
)


def test_extract_imgs_and_dicts():
    """Check image extraction and placeholder numbering for repeated images.

    The conversation contains two rounds of (plain message, "img1" message,
    "img2" message). The test expects only two images to come back, one dict
    per input message, and each image-bearing message to carry the
    placeholder of its image.
    """
    first_img = Image.new(mode="RGB", size=(1, 1))
    second_img = Image.new(mode="RGB", size=(1, 1))

    # Build fresh Message objects for every round so no instance is shared.
    conversation = []
    for _ in range(2):
        conversation.append(a_msg())
        conversation.append(a_msg(img=first_img, img_name="img1"))
        conversation.append(a_msg(img=second_img, img_name="img2"))

    extracted_imgs, msg_dicts = extract_imgs_and_dicts(conversation)

    assert len(extracted_imgs) == 2
    assert len(msg_dicts) == 6

    # (message index, placeholder expected inside that message's content)
    expected_placeholders = (
        (1, "<|image_1|>"),
        (4, "<|image_1|>"),
        (5, "<|image_2|>"),
        (2, "<|image_2|>"),
    )
    for position, placeholder in expected_placeholders:
        assert placeholder in msg_dicts[position]["content"]


def a_msg(img: Image.Image | None = None, img_name: str | None = None) -> Message:
    """Create a user ``Message`` with empty text and an optional image.

    Args:
        img: Image to attach, or ``None`` for a message without an image.
        img_name: Name passed along with the image, or ``None``.

    Returns:
        A ``Message`` with ``role="user"`` and ``msg=""``.
    """
    message = Message(
        role="user",
        msg="",
        img=img,
        img_name=img_name,
    )
    return message


@pytest.mark.skipif(reason="No GPU in CI", condition=is_ci())
def test_phi3_vision(model: Phi3Vision):
    """Run the text and image smoke checks against the model.

    Skipped when ``is_ci()`` is true. Delegates the text check to
    ``assert_model_knows_capital_of_france`` and verifies that
    ``get_mona_lisa_completion`` returns a string.
    """
    assert_model_knows_capital_of_france(model)

    completion = get_mona_lisa_completion(model)
    assert isinstance(completion, str)


@pytest.mark.skipif(reason="No GPU in CI", condition=is_ci())
def test_phi3_batching(model: Phi3Vision):
    """Run the batched text-prompt check, then the batched image-prompt check.

    Skipped when ``is_ci()`` is true.
    """
    batch_checks = (
        assert_model_can_answer_batch_of_text_prompts,
        assert_model_can_answer_batch_of_img_prompts,
    )
    for run_check in batch_checks:
        run_check(model)


@pytest.mark.skipif(reason="No GPU in CI", condition=is_ci())
def test_phi3_invalid_input(model: Phi3Vision):
    """Delegate to the shared helper that checks unsupported batches are rejected.

    Skipped when ``is_ci()`` is true.
    """
    assert_model_rejects_unsupported_batches(model)


@pytest.fixture(scope="module")
def model():
    """Provide a module-scoped ``Phi3Vision`` built with ``max_new_tokens=30``."""
    generation_config = GenConf(max_new_tokens=30)
    phi3 = Phi3Vision(generation_config)
    yield phi3


def test_padleft():
    """Check that ``pad_left`` left-pads shorter sequences to the longest length.

    Three 1-D tensors of lengths 3, 2 and 1 are expected to come back as a
    single 3x3 tensor with the pad value filling the leading positions.
    """
    pad = -1
    ragged = [torch.tensor(values) for values in ([1, 2, 3], [4, 5], [6])]

    want = torch.tensor(
        [
            [1, 2, 3],
            [pad, 4, 5],
            [pad, pad, 6],
        ]
    )
    got = pad_left(ragged, pad)

    assert torch.equal(got, want)