Spaces:
Sleeping
Sleeping
import os
import re
from pathlib import Path

import PIL
import PIL.Image
import pytest

from llmlib.base_llm import LLM, Message
def assert_model_knows_capital_of_france(model: LLM) -> None:
    """Assert that ``model`` names Paris as the capital of France.

    Sends a single user message through ``complete_msgs2`` and checks the
    reply case-insensitively.
    """
    response: str = model.complete_msgs2(
        msgs=[Message(role="user", msg="What is the capital of France?")]
    )
    # Attach the reply to the failure: LLM output varies, so the actual text
    # is what is needed to diagnose a failing run.
    assert "paris" in response.lower(), response
def assert_model_can_answer_batch_of_text_prompts(model: LLM) -> None:
    """Assert that ``model.complete_batch`` answers several text-only prompts.

    Each prompt becomes a one-message conversation. The batch must return one
    response per conversation, in order, and each response must contain the
    expected keyword (case-insensitive).
    """
    # (prompt, keyword its answer must contain)
    cases = [
        ("What is the capital of France?", "paris"),
        ("What continent is south of Europe?", "africa"),
        ("What are the two tallest mountains in the world?", "everest"),
    ]
    batch = [[Message.from_prompt(prompt)] for prompt, _ in cases]
    responses = model.complete_batch(batch=batch)
    # Derive the expected count from the batch instead of hard-coding it.
    assert len(responses) == len(batch), responses
    for response, (_, keyword) in zip(responses, cases):
        assert keyword in response.lower(), response
def assert_model_can_answer_batch_of_img_prompts(model: LLM) -> None:
    """Assert that ``model.complete_batch`` answers several image prompts.

    Uses the pyramid, forest and fish test images, each as a one-message
    conversation. The batch must return one response per conversation, in
    order, and each response must mention its expected keyword
    (case-insensitive).
    """
    # (single image message, keyword its answer must contain)
    cases = [
        (pyramid_message(), "pyramid"),
        (forest_message(), "forest"),
        (fish_message(), "fish"),
    ]
    batch = [[message] for message, _ in cases]
    responses = model.complete_batch(batch=batch)
    # Derive the expected count from the batch instead of hard-coding it.
    assert len(responses) == len(batch), responses
    for response, (_, keyword) in zip(responses, cases):
        assert keyword in response.lower(), response
def assert_model_rejects_unsupported_batches(model: LLM) -> None:
    """Assert that ``complete_batch`` rejects a batch mixing text-only and image entries.

    The model must raise ``ValueError`` carrying exactly the message below.
    """
    mixed_textonly_and_img_batch = [
        [Message.from_prompt("What is the capital of France?")],
        [pyramid_message()],
    ]
    err_msg = "Batch must contain an image in every entry or none at all."
    # pytest.raises(match=...) applies re.search, so escape the literal message;
    # unescaped, the trailing "." would match any character.
    with pytest.raises(ValueError, match=re.escape(err_msg)):
        model.complete_batch(batch=mixed_textonly_and_img_batch)
def assert_model_recognizes_pyramid_in_image(model: LLM):
    """Assert that ``model`` mentions a pyramid when shown the pyramid test image."""
    msg = pyramid_message()
    answer: str = model.complete_msgs2(msgs=[msg])
    # Attach the answer to the failure so a wrong description is visible.
    assert "pyramid" in answer.lower(), answer
def assert_model_recognizes_afd_in_video(model: LLM):
    """Assert that ``model`` names the AfD party when describing the test video.

    Passes the ``video.mp4`` test file and a description request to
    ``video_prompt``. The answer must contain the party's full name
    (case-insensitive). The answer is shown on failure.
    """
    answer: str = model.video_prompt(
        file_for_test("video.mp4"), "Describe the video in english"
    )
    expected_phrase = "alternative für deutschland"
    assert expected_phrase in answer.lower(), answer
def get_mona_lisa_completion(model: LLM) -> str:
    """Return ``model``'s reply to the Mona Lisa image prompt, without checking it."""
    return model.complete_msgs2(msgs=[mona_lisa_message()])
def mona_lisa_message() -> Message:
    """Build a user message asking what is in the Mona Lisa test image."""
    img = mona_lisa_filename_and_img()[1]
    # NOTE(review): the known filename is dropped and img_name is left empty;
    # confirm this is intended.
    return Message(role="user", msg="What is in the image?", img=img, img_name="")
def pyramid_message() -> Message:
    """Build a user message asking what is in the ``pyramid.jpg`` test image."""
    return Message(
        role="user",
        msg="What is in the image?",
        img=get_test_img("pyramid.jpg"),
        # NOTE(review): img_name is left empty even though the filename is
        # known; confirm this is intended.
        img_name="",
    )
def forest_message() -> Message:
    """Build a user message asking for a description of the ``forest.jpg`` test image."""
    picture = get_test_img("forest.jpg")
    prompt = "Describe what you see in the picture."
    return Message(role="user", msg=prompt, img=picture, img_name="")
def fish_message() -> Message:
    """Build a user message asking which animal ``fish.jpg`` shows and where it lives."""
    picture = get_test_img("fish.jpg")
    prompt = "What animal is depicted and where does it live?"
    return Message(role="user", msg=prompt, img=picture, img_name="")
def mona_lisa_filename_and_img() -> tuple[str, PIL.Image.Image]:
    """Return the Mona Lisa test image's filename paired with the loaded image."""
    filename = "mona-lisa.png"
    return filename, get_test_img(filename)
def get_test_img(name: str) -> PIL.Image.Image:
    """Load the test image ``name`` from the shared ``test-files`` directory.

    The pixel data is read eagerly and the underlying file is closed before
    returning. Callers get a fully loaded image that holds no open file handle.

    Raises:
        FileNotFoundError: if ``name`` does not exist there (raised by
            ``PIL.Image.open``).
    """
    path = file_for_test(name)
    # Image.open is lazy and keeps the file open until the data is loaded.
    # Loading inside the context manager releases the handle deterministically
    # instead of leaking it until garbage collection.
    with PIL.Image.open(path) as img:
        img.load()
    return img
def file_for_test(name: str) -> Path:
    """Return the path of ``name`` inside the ``test-files`` directory.

    ``test-files`` is resolved as a sibling of the directory containing this
    module. The returned path is not checked for existence.
    """
    test_files_dir = Path(__file__).parent.parent.joinpath("test-files")
    return test_files_dir / name
def is_ci() -> bool:
    """Report whether the process appears to be running under CI.

    Returns False when the ``CI`` environment variable is unset or equals
    "false" (case-insensitive). Any other value, including an empty string,
    counts as CI.
    """
    ci_flag = os.environ.get("CI")
    if ci_flag is None:
        return False
    return ci_flag.lower() != "false"