"""Tests for llmlib's Bundler (multimodal-vibe-check/tests/test_bundler.py)."""
from dataclasses import dataclass

import pytest

from llmlib.base_llm import LLM, Message
from llmlib.bundler import Bundler
from llmlib.bundler_request import BundlerRequest
from llmlib.model_registry import ModelEntry, ModelRegistry


def test_model_id_on_gpu():
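    """The bundler reports no model on the GPU until one is explicitly set."""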
b = Bundler(filled_model_registry())
assert b.id_of_model_on_gpu() is None
b.set_model_on_gpu(GpuLLM.model_id)
assert b.id_of_model_on_gpu() == GpuLLM.model_id


def test_get_response():
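    """Answering a request loads the requested model onto the GPU."""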
b = Bundler(filled_model_registry())
msgs = [Message(role="user", msg="hello")]
request = BundlerRequest(model_id=GpuLLM.model_id, msgs=msgs)
expected_response = GpuLLM().complete_msgs2(msgs)
actual_response: str = b.get_response(request)
assert actual_response == expected_response
assert b.id_of_model_on_gpu() == GpuLLM.model_id


def test_bundler_multiple_responses():
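    """Each request is routed to its own model; the GPU ends up holding the last GPU-exclusive one."""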
b = Bundler(filled_model_registry())
models = [GpuLLM(), GpuLLM2(), NonGpuLLM()]
msgs = [Message(role="user", msg="hello")]
expected_responses = [m.complete_msgs2(msgs) for m in models]
assert expected_responses[0] != expected_responses[1]
actual_responses = [
b.get_response(BundlerRequest(model_id=m.model_id, msgs=msgs)) for m in models
]
assert actual_responses == expected_responses
last_gpu_model = [m for m in models if m.requires_gpu_exclusively][-1]
assert b.id_of_model_on_gpu() == last_gpu_model.model_id


def test_set_model_on_gpu():
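    """An invalid model id is rejected, and a non-GPU model does not evict the loaded GPU model."""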
b = Bundler(filled_model_registry())
b.set_model_on_gpu(GpuLLM.model_id)
assert b.id_of_model_on_gpu() == GpuLLM.model_id
with pytest.raises(AssertionError):
b.set_model_on_gpu("invalid")
assert b.id_of_model_on_gpu() == GpuLLM.model_id
b.set_model_on_gpu(NonGpuLLM.model_id)
gpu_model_is_still_loaded: bool = b.id_of_model_on_gpu() == GpuLLM.model_id
assert gpu_model_is_still_loaded


def filled_model_registry() -> ModelRegistry:
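    """Build a registry holding the three fake models defined below."""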
model_entries = [
ModelEntry.from_cls_with_id(GpuLLM),
ModelEntry.from_cls_with_id(GpuLLM2),
ModelEntry.from_cls_with_id(NonGpuLLM),
]
return ModelRegistry(model_entries)


@dataclass
class GpuLLM(LLM):
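    """Fake GPU-exclusive model that returns a fixed reply."""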
model_id = "gpu-llm-model"
requires_gpu_exclusively = True
def complete_msgs2(self, msgs: list[Message]) -> str:
return "gpu msg"


@dataclass
class GpuLLM2(LLM):
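    """Second fake GPU-exclusive model, with a reply distinct from GpuLLM's."""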
model_id = "gpu-llm-model-2"
requires_gpu_exclusively = True
def complete_msgs2(self, msgs: list[Message]) -> str:
return "gpu msg 2"


@dataclass
class NonGpuLLM(LLM):
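    """Fake model that does not require exclusive use of the GPU."""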
model_id = "non-gpu-llm-model"
requires_gpu_exclusively = False
def complete_msgs2(self, msgs: list[Message]) -> str:
return "non-gpu message"