# mergekit/merge_methods/generalized_task_arithmetic.py
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.
import logging
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import torch
from pydantic import BaseModel
from typing_extensions import Literal

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.sparsify import SparsificationMethod, sparsify


class ConsensusMethod(str, Enum):
count = "count"
sum = "sum"


class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel, frozen=True):
    """Task-arithmetic merge with optional sparsification and sign consensus.

    Generalizes plain task arithmetic: each non-base model contributes a
    weighted task vector (its delta from the base model), which may be
    sparsified and sign-filtered before the deltas are mixed back in.
    """
consensus_method: Optional[ConsensusMethod]
sparsification_method: Optional[SparsificationMethod]
default_normalize: bool

    def parameters(self) -> List[ConfigParameterDef]:
return [
ConfigParameterDef(name="int8_mask", required=False, default_value=False),
ConfigParameterDef(
name="normalize", required=False, default_value=self.default_normalize
),
]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
return [
ConfigParameterDef(name="weight", required=True),
ConfigParameterDef(name="density", required=False, default_value=1.0),
]

    def make_task(
self,
output_weight: WeightInfo,
tensors: GatherTensors,
base_model: Optional[ModelReference],
parameters: ImmutableMap[str, Any],
tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
) -> Task:
return GTATask(
method=self,
tensors=tensors,
base_model=base_model,
tensor_parameters=tensor_parameters,
int8_mask=parameters["int8_mask"],
normalize=parameters["normalize"],
out_tensor_name=output_weight.name,
)
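
# For orientation, a sketch of the kind of mergekit YAML config that exercises
# the parameters declared above (per-model `weight` and `density`, global
# `normalize` and `int8_mask`). The model names and values are illustrative
# assumptions, not taken from this repository:
#
#     models:
#       - model: microsoft/phi-2               # base model; contributes no delta
#       - model: example-org/phi-2-finetune-a  # hypothetical fine-tune
#         parameters:
#           weight: 0.6
#           density: 0.5
#       - model: example-org/phi-2-finetune-b  # hypothetical fine-tune
#         parameters:
#           weight: 0.4
#           density: 0.5
#     merge_method: ties
#     base_model: microsoft/phi-2
#     parameters:
#       normalize: true
#       int8_mask: true
#     dtype: float16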


class GTATask(Task[torch.Tensor]):
    """Graph task that merges one output tensor via generalized task arithmetic."""

method: GeneralizedTaskArithmeticMerge
tensors: GatherTensors
base_model: ModelReference
out_tensor_name: str
tensor_parameters: ImmutableMap[ModelReference, Any]
int8_mask: bool
normalize: bool

    def uses_accelerator(self) -> bool:
return True

    def arguments(self) -> Dict[str, Task]:
return {"tensors": self.tensors}

    def execute(
self,
tensors: Dict[ModelReference, torch.Tensor],
**_kwargs,
) -> torch.Tensor:
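        # With task vectors delta_i = theta_i - theta_base and per-model
        # weights w_i, this computes (per element):
        #
        #     merged = theta_base + sum_i(w_i * m_i * delta_i) / sum_i(w_i * m_i)
        #
        # where m_i is the optional sign-consensus mask from get_mask() and
        # the division is applied only when `normalize` is set (the divisor
        # is clamped away from zero first).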
# collect task vectors
tvs, base = get_task_vectors(
self.out_tensor_name,
self.base_model,
tensors,
tensor_parameters=self.tensor_parameters.data,
)
if not tvs:
return base
# sparsify
if self.method.sparsification_method:
for tv_info in tvs:
tv_info["delta"] = sparsify(
tv_info["delta"],
density=tv_info["density"],
method=self.method.sparsification_method,
)
deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
weights = torch.tensor(
[tv["weight"] for tv in tvs], dtype=deltas.dtype, device=deltas.device
)
        # Broadcast each model's scalar weight across the tensor dimensions.
        while len(deltas.shape) > len(weights.shape):
            weights.unsqueeze_(-1)
weighted_deltas = deltas * weights
# get sign consensus and mix deltas
if self.method.consensus_method:
mask_dtype = torch.int8 if self.int8_mask else base.dtype
mask = get_mask(
weighted_deltas,
method=self.method.consensus_method,
mask_dtype=mask_dtype,
)
mixed_delta = (weighted_deltas * mask).sum(dim=0)
divisor = (weights * mask).sum(dim=0)
divisor[divisor == 0] = 1
else:
mixed_delta = weighted_deltas.sum(dim=0)
divisor = weights.sum(dim=0)
divisor[divisor.abs() < 1e-8] = 1
if self.normalize:
mixed_delta /= divisor
return (base + mixed_delta).to(base.dtype)


def get_task_vectors(
    parameter_name: str,
    base_model: ModelReference,
    tensors: Dict[ModelReference, torch.Tensor],
    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
) -> Tuple[List[Dict[str, Any]], torch.Tensor]:
keys = list(tensors.keys())
base = tensors[base_model]
res = []
for model in keys:
if model == base_model:
continue
x = tensors[model].to(base.dtype)
        if x.shape != base.shape:
            if "lm_head" in parameter_name or "embed_tokens" in parameter_name:
                # Vocabulary sizes can differ between models; keep only the
                # rows/columns that overlap with the base model's tensor.
                x = x[: base.shape[0], : base.shape[1]]
                logging.warning(f"Using submatrix of {model}:{parameter_name}")
else:
logging.warning(
f"skipping {model}:{parameter_name} due to size mismatch"
)
continue
delta = x - base
del x
        del tensors[model]  # drop the reference so the source tensor can be freed
d = {}
d["model"] = model
d["delta"] = delta
for p in tensor_parameters[model]:
d[p] = tensor_parameters[model][p]
res.append(d)
return res, base
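
# For reference, get_task_vectors returns a structure of the form (field names
# follow the code above; weight/density come from tensor_parameters):
#
#     (
#         [{"model": <ModelReference>, "delta": <Tensor>, "weight": ..., "density": ...}, ...],
#         <base tensor>,
#     )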


def get_mask(
    delta: torch.Tensor,
    method: Literal["sum", "count"] = "sum",
    mask_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Returns a mask determining which delta vectors should be merged
    into the final model.

    For the methodology described in the TIES paper, use 'sum'. For a
    simpler naive count of signs, use 'count'.
    """
if mask_dtype is None:
mask_dtype = delta.dtype
sign = delta.sign().to(mask_dtype)
if method == "sum":
sign_weight = delta.sum(dim=0)
majority_sign = (sign_weight >= 0).to(mask_dtype) * 2 - 1
del sign_weight
elif method == "count":
majority_sign = (sign.sum(dim=0) >= 0).to(mask_dtype) * 2 - 1
else:
raise RuntimeError(f'Unimplemented mask method "{method}"')
return sign == majority_sign
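

# A minimal, self-contained sketch of the sign-consensus step on toy tensors;
# everything below is illustrative and not part of the library. Run the file
# directly to see the mask and the mixed delta.
if __name__ == "__main__":
    # Two task vectors, stacked along dim 0 (one row per model); they agree in
    # sign on the first element and disagree on the second.
    toy_deltas = torch.tensor([[0.5, -1.0], [0.25, 2.0]])
    toy_mask = get_mask(toy_deltas, method="sum")
    # Elementwise sums are [0.75, 1.0], so the majority sign is positive in
    # both positions and the -1.0 entry is masked out before mixing.
    print(toy_mask)                            # [[True, False], [True, True]]
    print((toy_deltas * toy_mask).sum(dim=0))  # tensor([0.7500, 2.0000])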