llm-studio / tests /src /test_nesting.py
qinfeng722's picture
Upload 322 files
5caedb4 verified
import unittest
import pytest
from llm_studio.src.nesting import Dependency, Nesting
class TestDependency:
@pytest.mark.parametrize(
"key, value, is_set",
[
("personalize", True, True),
("validation_strategy", "automatic", True),
("deepspeed_method", "ZeRO2", True),
("lora", False, False),
],
)
def test_dependency_init(self, key, value, is_set):
dep = Dependency(key=key, value=value, is_set=is_set)
assert dep.key == key
assert dep.value == value
assert dep.is_set == is_set
@pytest.mark.parametrize(
"dep, dependency_values, expected",
[
(Dependency("tkey", value=True, is_set=True), [True], True),
(Dependency("tkey", value=True, is_set=True), [False], False),
(Dependency("tkey", value=True, is_set=False), [True], False),
(Dependency("tkey", value=True, is_set=False), [False], True),
(Dependency("tkey", value=False, is_set=True), [False], True),
(Dependency("tkey", value=False, is_set=True), [True], False),
(Dependency("tkey", value=False, is_set=False), [False], False),
(Dependency("tkey", value=False, is_set=False), [True], True),
(Dependency("tkey", value="value", is_set=True), ["value"], True),
(Dependency("tkey", value="value", is_set=True), ["other_value"], False),
(Dependency("tkey", value="value", is_set=False), ["value"], False),
(Dependency("tkey", value="value", is_set=False), ["other_value"], True),
(Dependency("tkey", value=None, is_set=True), [], False),
(Dependency("tkey", value=None, is_set=True), ["value"], False),
(Dependency("tkey", value=None, is_set=False), [], False),
(Dependency("tkey", value=None, is_set=False), ["value"], True),
],
)
def test_dependency_check(self, dep, dependency_values, expected):
assert dep.check(dependency_values) == expected
class TestNesting(unittest.TestCase):
def setUp(self):
self.nesting = Nesting()
def test_nesting_init(self):
self.assertEqual(len(self.nesting.dependencies), 0)
self.assertEqual(len(self.nesting.triggers), 0)
def test_nesting_add(self):
keys = ["key1", "key2"]
dependencies = [
Dependency("dep1", value=True, is_set=True),
Dependency("dep2", value=True, is_set=True),
]
self.nesting.add(keys, dependencies)
self.assertEqual(len(self.nesting.dependencies), 2)
self.assertEqual(len(self.nesting.triggers), 2)
self.assertIn("dep1", self.nesting.triggers)
self.assertIn("dep2", self.nesting.triggers)
def test_nesting_add_duplicate_keys(self):
keys = ["key1", "key1"]
dependencies = [Dependency("dep1", value=True, is_set=True)]
with self.assertRaises(ValueError):
self.nesting.add(keys, dependencies)
def test_nesting_multiple_adds(self):
self.nesting.add(["key1"], [Dependency("dep1", value=True, is_set=True)])
self.nesting.add(["key2"], [Dependency("dep2", value=True, is_set=True)])
self.nesting.add(
["key1", "key2"], [Dependency("dep3", value=True, is_set=True)]
)
self.assertEqual(len(self.nesting.dependencies), 2)
self.assertEqual(len(self.nesting.triggers), 3)
self.assertEqual(len(self.nesting.dependencies["key1"]), 2)
self.assertEqual(len(self.nesting.dependencies["key2"]), 2)