import pytest

from llmlib.base_llm import LLM, Message
from llmlib.bundler import Bundler
from llmlib.bundler_request import BundlerRequest
from llmlib.model_registry import ModelEntry, ModelRegistry


def test_model_id_on_gpu():
    """The bundler starts with no model on the GPU; set_model_on_gpu loads one."""
    b = Bundler(filled_model_registry())
    assert b.id_of_model_on_gpu() is None
    b.set_model_on_gpu(GpuLLM.model_id)
    assert b.id_of_model_on_gpu() == GpuLLM.model_id


def test_get_response():
    """get_response answers with the requested model and loads it onto the GPU."""
    b = Bundler(filled_model_registry())
    msgs = [Message(role="user", msg="hello")]
    request = BundlerRequest(model_id=GpuLLM.model_id, msgs=msgs)
    expected_response = GpuLLM().complete_msgs2(msgs)
    actual_response: str = b.get_response(request)
    assert actual_response == expected_response
    assert b.id_of_model_on_gpu() == GpuLLM.model_id


def test_bundler_multiple_responses():
    """Each request is answered by its own model; the GPU ends up holding the
    last GPU-exclusive model, unaffected by the non-GPU request after it."""
    b = Bundler(filled_model_registry())
    models = [GpuLLM(), GpuLLM2(), NonGpuLLM()]
    msgs = [Message(role="user", msg="hello")]
    expected_responses = [m.complete_msgs2(msgs) for m in models]
    assert expected_responses[0] != expected_responses[1]
    actual_responses = [
        b.get_response(BundlerRequest(model_id=m.model_id, msgs=msgs)) for m in models
    ]
    assert actual_responses == expected_responses
    last_gpu_model = [m for m in models if m.requires_gpu_exclusively][-1]
    assert b.id_of_model_on_gpu() == last_gpu_model.model_id


def test_set_model_on_gpu():
    """Invalid model ids are rejected, and selecting a non-GPU model does not
    evict the model currently on the GPU."""
    b = Bundler(filled_model_registry())
    b.set_model_on_gpu(GpuLLM.model_id)
    assert b.id_of_model_on_gpu() == GpuLLM.model_id
    with pytest.raises(AssertionError):
        b.set_model_on_gpu("invalid")
    assert b.id_of_model_on_gpu() == GpuLLM.model_id
    b.set_model_on_gpu(NonGpuLLM.model_id)
    gpu_model_is_still_loaded: bool = b.id_of_model_on_gpu() == GpuLLM.model_id
    assert gpu_model_is_still_loaded


def filled_model_registry() -> ModelRegistry:
    """Registry with two GPU-exclusive stub models and one CPU-capable stub."""
    model_entries = [
        ModelEntry.from_cls_with_id(GpuLLM),
        ModelEntry.from_cls_with_id(GpuLLM2),
        ModelEntry.from_cls_with_id(NonGpuLLM),
    ]
    return ModelRegistry(model_entries)


class GpuLLM(LLM):
    """Stub model that requires exclusive GPU access."""

    model_id = "gpu-llm-model"
    requires_gpu_exclusively = True

    def complete_msgs2(self, msgs: list[Message]) -> str:
        return "gpu msg"


class GpuLLM2(LLM):
    """Second GPU-exclusive stub, distinguishable by its response."""

    model_id = "gpu-llm-model-2"
    requires_gpu_exclusively = True

    def complete_msgs2(self, msgs: list[Message]) -> str:
        return "gpu msg 2"


class NonGpuLLM(LLM):
    """Stub model that can answer without claiming the GPU."""

    model_id = "non-gpu-llm-model"
    requires_gpu_exclusively = False

    def complete_msgs2(self, msgs: list[Message]) -> str:
        return "non-gpu message"