# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import logging
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import torch
from pydantic import BaseModel
from typing_extensions import Literal

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.sparsify import SparsificationMethod, sparsify


class ConsensusMethod(str, Enum):
    count = "count"
    sum = "sum"


class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel, frozen=True):
    consensus_method: Optional[ConsensusMethod]
    sparsification_method: Optional[SparsificationMethod]
    default_normalize: bool

    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="int8_mask", required=False, default_value=False),
            ConfigParameterDef(
                name="normalize", required=False, default_value=self.default_normalize
            ),
        ]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="weight", required=True),
            ConfigParameterDef(name="density", required=False, default_value=1.0),
        ]

    def make_task(
        self,
        output_weight: WeightInfo,
        tensors: GatherTensors,
        base_model: Optional[ModelReference],
        parameters: ImmutableMap[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
    ) -> Task:
        return GTATask(
            method=self,
            tensors=tensors,
            base_model=base_model,
            tensor_parameters=tensor_parameters,
            int8_mask=parameters["int8_mask"],
            normalize=parameters["normalize"],
            out_tensor_name=output_weight.name,
        )


class GTATask(Task[torch.Tensor]):
    method: GeneralizedTaskArithmeticMerge
    tensors: GatherTensors
    base_model: ModelReference
    out_tensor_name: str
    tensor_parameters: ImmutableMap[ModelReference, Any]
    int8_mask: bool
    normalize: bool

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.tensors}

    def execute(
        self,
        tensors: Dict[ModelReference, torch.Tensor],
        **_kwargs,
    ) -> torch.Tensor:
        # collect task vectors
        tvs, base = get_task_vectors(
            self.out_tensor_name,
            self.base_model,
            tensors,
            tensor_parameters=self.tensor_parameters.data,
        )
        if not tvs:
            return base

        # sparsify
        if self.method.sparsification_method:
            for tv_info in tvs:
                tv_info["delta"] = sparsify(
                    tv_info["delta"],
                    density=tv_info["density"],
                    method=self.method.sparsification_method,
                )

        deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
        weights = torch.tensor(
            [tv["weight"] for tv in tvs], dtype=deltas.dtype, device=deltas.device
        )
        while len(deltas.shape) > len(weights.shape):
            weights.unsqueeze_(-1)

        weighted_deltas = deltas * weights

        # get sign consensus and mix deltas
        if self.method.consensus_method:
            mask_dtype = torch.int8 if self.int8_mask else base.dtype
            mask = get_mask(
                weighted_deltas,
                method=self.method.consensus_method,
                mask_dtype=mask_dtype,
            )
            mixed_delta = (weighted_deltas * mask).sum(dim=0)
            divisor = (weights * mask).sum(dim=0)
            divisor[divisor == 0] = 1
        else:
            mixed_delta = weighted_deltas.sum(dim=0)
            divisor = weights.sum(dim=0)
            divisor[divisor.abs() < 1e-8] = 1

        if self.normalize:
            mixed_delta /= divisor

        return (base + mixed_delta).to(base.dtype)
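
# Shape notes for the mixing step in GTATask.execute above (illustrative,
# assuming two models and a 2D weight tensor): `deltas` has shape
# [2, rows, cols], while the per-model `weights` start as shape [2] and are
# unsqueezed to [2, 1, 1] so they broadcast elementwise across each delta.
# With hypothetical weights [1.0, 0.5] and normalize=True, the mixed delta is
# divided by 1.5 when no consensus method is set; with sign consensus, the
# divisor at each parameter is the weight sum of only the models whose sign
# matches the elected majority sign there, so masked-out models do not dilute
# the result.
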
def get_task_vectors(
    parameter_name: str,
    base_model: ModelReference,
    # Note: this must be a mutable mapping, not an ImmutableMap -- entries are
    # deleted below to free memory as each delta is computed.
    tensors: Dict[ModelReference, torch.Tensor],
    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
) -> Tuple[List[Dict[str, Any]], torch.Tensor]:
    keys = list(tensors.keys())
    base = tensors[base_model]

    res = []
    for model in keys:
        if model == base_model:
            continue

        x = tensors[model].to(base.dtype)
        if x.shape != base.shape:
            if "lm_head" in parameter_name or "embed_tokens" in parameter_name:
                x = x[: base.shape[0], : base.shape[1]]
                logging.warning(f"Using submatrix of {model}:{parameter_name}")
            else:
                logging.warning(
                    f"skipping {model}:{parameter_name} due to size mismatch"
                )
                continue

        delta = x - base
        del x
        del tensors[model]

        d = {}
        d["model"] = model
        d["delta"] = delta
        for p in tensor_parameters[model]:
            d[p] = tensor_parameters[model][p]
        res.append(d)
    return res, base


def get_mask(
    delta: torch.Tensor,
    method: Literal["sum", "count"] = "sum",
    mask_dtype: Optional[torch.dtype] = None,
):
    """Returns a mask determining which delta vectors should be merged
    into the final model.

    For the methodology described in the TIES paper use 'sum'. For a
    simpler naive count of signs, use 'count'."""
    if mask_dtype is None:
        mask_dtype = delta.dtype

    sign = delta.sign().to(mask_dtype)

    if method == "sum":
        sign_weight = delta.sum(dim=0)
        majority_sign = (sign_weight >= 0).to(mask_dtype) * 2 - 1
        del sign_weight
    elif method == "count":
        majority_sign = (sign.sum(dim=0) >= 0).to(mask_dtype) * 2 - 1
    else:
        raise RuntimeError(f'Unimplemented mask method "{method}"')

    return sign == majority_sign
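

if __name__ == "__main__":
    # Minimal sanity check for get_mask, using hypothetical deltas. This demo
    # is illustrative only and not part of mergekit's merge pipeline. Each row
    # is one model's task vector; each column is one parameter.
    demo_deltas = torch.tensor(
        [
            [1.0, -2.0, 0.5],
            [3.0, 1.0, -0.5],
        ]
    )
    # 'sum' (TIES-style): the magnitude-weighted sums are [4.0, -1.0, 0.0],
    # so the elected signs are [+1, -1, +1] and the second row's last two
    # entries are masked out.
    print(get_mask(demo_deltas, method="sum"))
    # 'count': a naive vote over signs; the sign sums are [2, 0, 0], so the
    # elected signs are [+1, +1, +1].
    print(get_mask(demo_deltas, method="count"))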