Spaces:
Sleeping
Sleeping
File size: 2,718 Bytes
41d24d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
from dataclasses import dataclass
from llmlib.bundler import Bundler
from llmlib.bundler_request import BundlerRequest
from llmlib.base_llm import LLM, Message
import pytest
from llmlib.model_registry import ModelEntry, ModelRegistry
def test_model_id_on_gpu():
    """A fresh Bundler reports no GPU model until one is explicitly set."""
    bundler = Bundler(filled_model_registry())
    # Nothing has been placed on the GPU yet.
    assert bundler.id_of_model_on_gpu() is None

    target_id = GpuLLM.model_id
    bundler.set_model_on_gpu(target_id)
    assert bundler.id_of_model_on_gpu() == target_id
def test_get_response():
    """Bundler.get_response returns what the requested model itself produces."""
    bundler = Bundler(filled_model_registry())
    conversation = [Message(role="user", msg="hello")]

    # Reference answer, obtained by calling the model directly.
    reference = GpuLLM().complete_msgs2(conversation)
    reply: str = bundler.get_response(
        BundlerRequest(model_id=GpuLLM.model_id, msgs=conversation)
    )

    assert reply == reference
    # After serving the request, the bundler reports that model as on the GPU.
    assert bundler.id_of_model_on_gpu() == GpuLLM.model_id
def test_bundler_multiple_responses():
    """Requests to several models in a row each get that model's own answer."""
    bundler = Bundler(filled_model_registry())
    llms = [GpuLLM(), GpuLLM2(), NonGpuLLM()]
    prompt = [Message(role="user", msg="hello")]

    # Reference answers, obtained by calling every model directly.
    wanted = [llm.complete_msgs2(prompt) for llm in llms]
    # Sanity check: the two GPU models must be distinguishable by their output.
    assert wanted[0] != wanted[1]

    requests = [BundlerRequest(model_id=llm.model_id, msgs=prompt) for llm in llms]
    received = [bundler.get_response(request) for request in requests]
    assert received == wanted

    # The bundler should report the last GPU-exclusive model that was requested.
    gpu_llms = [llm for llm in llms if llm.requires_gpu_exclusively]
    assert bundler.id_of_model_on_gpu() == gpu_llms[-1].model_id
def test_set_model_on_gpu():
    """Check set_model_on_gpu for a GPU model id, a bad id, and a non-GPU model id."""
    bundler = Bundler(filled_model_registry())
    gpu_id = GpuLLM.model_id

    bundler.set_model_on_gpu(gpu_id)
    assert bundler.id_of_model_on_gpu() == gpu_id

    # The id "invalid" belongs to none of the fixture models, so the call must
    # fail with an AssertionError ...
    with pytest.raises(AssertionError):
        bundler.set_model_on_gpu("invalid")
    # ... and the failed call must leave the loaded model as it was.
    assert bundler.id_of_model_on_gpu() == gpu_id

    # Requesting a model that does not need the GPU exclusively must not evict
    # the GPU model that is already loaded.
    bundler.set_model_on_gpu(NonGpuLLM.model_id)
    gpu_model_is_still_loaded: bool = bundler.id_of_model_on_gpu() == gpu_id
    assert gpu_model_is_still_loaded
def filled_model_registry() -> ModelRegistry:
    """Build a ModelRegistry with entries for the three fake models in this module."""
    fake_model_classes = (GpuLLM, GpuLLM2, NonGpuLLM)
    return ModelRegistry(
        [ModelEntry.from_cls_with_id(cls) for cls in fake_model_classes]
    )
@dataclass
class GpuLLM(LLM):
    """Fake LLM flagged as needing the GPU exclusively; always answers "gpu msg"."""

    # Plain class attributes (no annotations), so @dataclass adds no fields for them.
    requires_gpu_exclusively = True
    model_id = "gpu-llm-model"

    def complete_msgs2(self, msgs: list[Message]) -> str:
        """Return a fixed reply, ignoring the content of ``msgs``."""
        return "gpu msg"
@dataclass
class GpuLLM2(LLM):
    """Second fake GPU-exclusive LLM; always answers "gpu msg 2"."""

    # Plain class attributes (no annotations), so @dataclass adds no fields for them.
    requires_gpu_exclusively = True
    model_id = "gpu-llm-model-2"

    def complete_msgs2(self, msgs: list[Message]) -> str:
        """Return a fixed reply, ignoring the content of ``msgs``."""
        return "gpu msg 2"
@dataclass
class NonGpuLLM(LLM):
    """Fake LLM that does not need the GPU exclusively; answers "non-gpu message"."""

    # Plain class attributes (no annotations), so @dataclass adds no fields for them.
    requires_gpu_exclusively = False
    model_id = "non-gpu-llm-model"

    def complete_msgs2(self, msgs: list[Message]) -> str:
        """Return a fixed reply, ignoring the content of ``msgs``."""
        return "non-gpu message"
|