tomasruiz's picture
Include code of llmapp from Github
41d24d2
raw
history blame
3.59 kB
import os
from pathlib import Path
import PIL
from llmlib.base_llm import LLM, Message
import pytest
def assert_model_knows_capital_of_france(model: LLM) -> None:
response: str = model.complete_msgs2(
msgs=[Message(role="user", msg="What is the capital of France?")]
)
assert "paris" in response.lower()
def assert_model_can_answer_batch_of_text_prompts(model: LLM) -> None:
prompts = [
"What is the capital of France?",
"What continent is south of Europe?",
"What are the two tallest mountains in the world?",
]
batch = [[Message.from_prompt(prompt)] for prompt in prompts]
responses = model.complete_batch(batch=batch)
assert len(responses) == 3
assert "paris" in responses[0].lower()
assert "africa" in responses[1].lower()
assert "everest" in responses[2].lower()
def assert_model_can_answer_batch_of_img_prompts(model: LLM) -> None:
batch = [
[pyramid_message()],
[forest_message()],
[fish_message()],
]
responses = model.complete_batch(batch=batch)
assert len(responses) == 3
assert "pyramid" in responses[0].lower()
assert "forest" in responses[1].lower()
assert "fish" in responses[2].lower()
def assert_model_rejects_unsupported_batches(model: LLM) -> None:
mixed_textonly_and_img_batch = [
[Message.from_prompt("What is the capital of France?")],
[pyramid_message()],
]
err_msg = "Batch must contain an image in every entry or none at all."
with pytest.raises(ValueError, match=err_msg):
model.complete_batch(mixed_textonly_and_img_batch)
def assert_model_recognizes_pyramid_in_image(model: LLM):
msg = pyramid_message()
answer: str = model.complete_msgs2(msgs=[msg])
assert "pyramid" in answer.lower()
def assert_model_recognizes_afd_in_video(model: LLM):
video_path = file_for_test("video.mp4")
question = "Describe the video in english"
answer: str = model.video_prompt(video_path, question)
assert "alternative für deutschland" in answer.lower(), answer
def get_mona_lisa_completion(model: LLM) -> str:
msg: Message = mona_lisa_message()
answer: str = model.complete_msgs2(msgs=[msg])
return answer
def mona_lisa_message() -> Message:
_, img = mona_lisa_filename_and_img()
prompt = "What is in the image?"
msg = Message(role="user", msg=prompt, img=img, img_name="")
return msg
def pyramid_message() -> Message:
img_name = "pyramid.jpg"
img = get_test_img(img_name)
msg = Message(role="user", msg="What is in the image?", img=img, img_name="")
return msg
def forest_message() -> Message:
img_name = "forest.jpg"
img = get_test_img(img_name)
msg = Message(
role="user", msg="Describe what you see in the picture.", img=img, img_name=""
)
return msg
def fish_message() -> Message:
img_name = "fish.jpg"
img = get_test_img(img_name)
msg = Message(
role="user",
msg="What animal is depicted and where does it live?",
img=img,
img_name="",
)
return msg
def mona_lisa_filename_and_img() -> tuple[str, PIL.Image.Image]:
img_name = "mona-lisa.png"
img = get_test_img(img_name)
return img_name, img
def get_test_img(name: str) -> PIL.Image.Image:
path = file_for_test(name)
return PIL.Image.open(path)
def file_for_test(name: str) -> Path:
return Path(__file__).parent.parent / "test-files" / name
def is_ci() -> bool:
is_ci_str: str = os.environ.get("CI", "false").lower()
return is_ci_str != "false"