from typing import Dict, Optional

import pytest
from common import make_picollama, run_and_check_merge
from transformers import AutoConfig

from mergekit.config import (
    InputModelDefinition,
    InputSliceDefinition,
    MergeConfiguration,
    OutputSliceDefinition,
    ParameterSetting,
)
from mergekit.io import LazyTensorLoader


@pytest.fixture(scope="session")
def model_a(tmp_path_factory):
    return make_picollama(tmp_path_factory.mktemp("model_a"))


@pytest.fixture(scope="session")
def model_b(tmp_path_factory):
    return make_picollama(tmp_path_factory.mktemp("model_b"))


@pytest.fixture(scope="session")
def model_c(tmp_path_factory):
    return make_picollama(tmp_path_factory.mktemp("model_c"))


class TestBasicMerges:
    def test_gpt2_copy(self):
        config = MergeConfiguration(
            merge_method="passthrough",
            models=[InputModelDefinition(model="gpt2")],
            dtype="bfloat16",
        )
        run_and_check_merge(config)

    def test_gpt2_stack(self):
        # Stacking the same 12-layer slice twice should produce a 24-layer model.
        config = MergeConfiguration(
            merge_method="passthrough",
            slices=[
                OutputSliceDefinition(
                    sources=[InputSliceDefinition(model="gpt2", layer_range=[0, 12])]
                )
            ]
            * 2,
            dtype="bfloat16",
        )

        def _check_config_layers(p: str):
            config = AutoConfig.from_pretrained(p)
            assert config.n_layer == 24

        run_and_check_merge(config, validate=_check_config_layers)

    def test_passthrough_scale(self, model_a):
        # A scale of 0 on o_proj tensors should zero them out; the catch-all
        # entry keeps every other tensor at scale 1.
        config = MergeConfiguration(
            merge_method="passthrough",
            models=[
                InputModelDefinition(
                    model=model_a,
                    parameters={
                        "scale": [
                            {"filter": "o_proj", "value": 0},
                            {"value": 1},
                        ]
                    },
                )
            ],
        )

        def _check_o_proj(p: str):
            loader = LazyTensorLoader.from_disk(p)
            saw_any = False
            for name in loader.index.tensor_paths:
                if "o_proj" in name:
                    param = loader.get_tensor(name)
                    assert (param == 0).all()
                    saw_any = True
                elif "lm_head" in name:
                    param = loader.get_tensor(name)
                    assert param.count_nonzero() > 0

            assert saw_any, "No o_proj parameters found"

        run_and_check_merge(config, validate=_check_o_proj)

    def test_linear_merge(self, model_a, model_b):
        config = self.two_model_config(model_a, model_b, merge_method="linear")
        run_and_check_merge(config)

    def test_slerp_merge(self, model_a, model_b):
        config = self.two_model_config(
            model_a, model_b, merge_method="slerp", base_model=model_a
        )
        config.parameters = {"t": 0.35}
        run_and_check_merge(config)

    def test_task_arithmetic_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_a, model_b, merge_method="task_arithmetic", base_model=model_c
        )
        run_and_check_merge(config)

    def test_ties_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_a,
            model_b,
            merge_method="ties",
            base_model=model_c,
            params={"density": 0.3},
        )
        run_and_check_merge(config)

    def test_dare_ties_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_a,
            model_b,
            merge_method="dare_ties",
            base_model=model_c,
            params={"density": 0.66},
        )
        run_and_check_merge(config)

    def test_model_stock_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_b, model_c, merge_method="model_stock", base_model=model_a
        )
        run_and_check_merge(config)

    def test_model_stock_filterwise_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_b,
            model_c,
            merge_method="model_stock",
            base_model=model_a,
            params={"filter_wise": True},
        )
        run_and_check_merge(config)

    def two_model_config(
        self,
        model_a,
        model_b,
        merge_method: str,
        base_model: Optional[str] = None,
        params: Optional[Dict[str, ParameterSetting]] = None,
    ):
        # Shared helper: a two-model merge with fixed 0.6/0.4 weights, used by
        # the method-specific tests above.
        config = MergeConfiguration(
            merge_method=merge_method,
            base_model=base_model,
            models=[
                InputModelDefinition(
                    model=model_a,
                    parameters={"weight": 0.6},
                ),
                InputModelDefinition(
                    model=model_b,
                    parameters={"weight": 0.4},
                ),
            ],
            dtype="bfloat16",
            parameters=params,
        )
        return config