import pytest

from mergekit.common import ModelPath, ModelReference


class TestModelReference:
    """Tests for ``ModelReference.parse`` and its string round-trip."""

    def test_parse_simple(self):
        """A bare model path parses with no revision and no LoRA, and round-trips."""
        text = "hf_user/model"
        mr = ModelReference.parse(text)
        assert mr.model == ModelPath(path="hf_user/model", revision=None)
        assert mr.lora is None
        assert str(mr) == text

    def test_parse_lora(self):
        """``model+lora`` splits into a base model path and a LoRA path."""
        text = "hf_user/model+hf_user/lora"
        mr = ModelReference.parse(text)
        assert mr.model == ModelPath(path="hf_user/model", revision=None)
        assert mr.lora == ModelPath(path="hf_user/lora", revision=None)
        assert str(mr) == text

    def test_parse_revision(self):
        """``path@rev`` attaches the revision to the model path."""
        text = "hf_user/model@v0.0.1"
        mr = ModelReference.parse(text)
        assert mr.model == ModelPath(path="hf_user/model", revision="v0.0.1")
        assert mr.lora is None
        assert str(mr) == text

    def test_parse_lora_plus_revision(self):
        """Both the model and the LoRA may each carry their own revision."""
        text = "hf_user/model@v0.0.1+hf_user/lora@main"
        mr = ModelReference.parse(text)
        assert mr.model == ModelPath(path="hf_user/model", revision="v0.0.1")
        assert mr.lora == ModelPath(path="hf_user/lora", revision="main")
        assert str(mr) == text

    @pytest.mark.parametrize(
        "text",
        [
            "@@@@@",
            "a+b+c",
            "a+b+c@d+e@f@g",
        ],
    )
    def test_parse_bad(self, text):
        """Malformed references make ``ModelReference.parse`` raise ``RuntimeError``.

        Parametrized so each malformed input is checked and reported on its
        own. Sequential ``pytest.raises`` blocks in one test would stop at the
        first input that fails to raise and silently skip the rest.
        """
        with pytest.raises(RuntimeError):
            ModelReference.parse(text)