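"""Generalized task arithmetic merge method.

Merges a set of models by manipulating "task vectors": the per-weight deltas
between each model and a shared base model. Each delta can optionally be
sparsified and passed through TIES-style sign consensus before the weighted
mix is added back onto the base weights.

Illustrative YAML sketch (the parameter names come from this module's
ConfigParameterDefs; the surrounding schema is defined elsewhere in mergekit):

    merge_method: ties
    base_model: <base model>
    models:
      - model: <fine-tuned model>
        parameters:
          weight: 0.5
          density: 0.5
    parameters:
      normalize: true
      int8_mask: true
"""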
import logging
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import torch
from pydantic import BaseModel
from typing_extensions import Literal

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.sparsify import SparsificationMethod, sparsify


class ConsensusMethod(str, Enum):
    count = "count"
    sum = "sum"


class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel, frozen=True):
    consensus_method: Optional[ConsensusMethod]
    sparsification_method: Optional[SparsificationMethod]
    default_normalize: bool

    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="int8_mask", required=False, default_value=False),
            ConfigParameterDef(
                name="normalize", required=False, default_value=self.default_normalize
            ),
        ]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="weight", required=True),
            ConfigParameterDef(name="density", required=False, default_value=1.0),
        ]

    def make_task(
        self,
        output_weight: WeightInfo,
        tensors: GatherTensors,
        base_model: Optional[ModelReference],
        parameters: ImmutableMap[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
    ) -> Task:
        return GTATask(
            method=self,
            tensors=tensors,
            base_model=base_model,
            tensor_parameters=tensor_parameters,
            int8_mask=parameters["int8_mask"],
            normalize=parameters["normalize"],
            out_tensor_name=output_weight.name,
        )


class GTATask(Task[torch.Tensor]):
    method: GeneralizedTaskArithmeticMerge
    tensors: GatherTensors
    base_model: ModelReference
    out_tensor_name: str
    tensor_parameters: ImmutableMap[ModelReference, Any]
    int8_mask: bool
    normalize: bool

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.tensors}

    def execute(
        self,
        tensors: Dict[ModelReference, torch.Tensor],
        **_kwargs,
    ) -> torch.Tensor:
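        # collect task vectors: each remaining model's delta from the base model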
        tvs, base = get_task_vectors(
            self.out_tensor_name,
            self.base_model,
            tensors,
            tensor_parameters=self.tensor_parameters.data,
        )
        if not tvs:
            return base
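
        # sparsify each task vector down to its requested density, if configured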
        if self.method.sparsification_method:
            for tv_info in tvs:
                tv_info["delta"] = sparsify(
                    tv_info["delta"],
                    density=tv_info["density"],
                    method=self.method.sparsification_method,
                )

        deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
        weights = torch.tensor(
            [tv["weight"] for tv in tvs], dtype=deltas.dtype, device=deltas.device
        )
        while len(deltas.shape) > len(weights.shape):
            weights.unsqueeze_(-1)

        weighted_deltas = deltas * weights
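
        # elect a per-parameter majority sign (TIES) and mix only agreeing deltas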
        if self.method.consensus_method:
            mask_dtype = torch.int8 if self.int8_mask else base.dtype
            mask = get_mask(
                weighted_deltas,
                method=self.method.consensus_method,
                mask_dtype=mask_dtype,
            )
            mixed_delta = (weighted_deltas * mask).sum(dim=0)
            divisor = (weights * mask).sum(dim=0)
            divisor[divisor == 0] = 1  # avoid division by zero where no model is kept
        else:
            mixed_delta = weighted_deltas.sum(dim=0)
            divisor = weights.sum(dim=0)
            divisor[divisor.abs() < 1e-8] = 1

        if self.normalize:
            mixed_delta /= divisor

        return (base + mixed_delta).to(base.dtype)


def get_task_vectors(
    parameter_name: str,
    base_model: ModelReference,
    tensors: Dict[ModelReference, torch.Tensor],
    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
) -> Tuple[List[Dict[str, Any]], torch.Tensor]:
    keys = list(tensors.keys())
    base = tensors[base_model]

    res = []
    for model in keys:
        if model == base_model:
            continue

        x = tensors[model].to(base.dtype)
        if x.shape != base.shape:
            if "lm_head" in parameter_name or "embed_tokens" in parameter_name:
                # vocab size mismatches are common; use the overlapping submatrix
                x = x[: base.shape[0], : base.shape[1]]
                logging.warning(f"Using submatrix of {model}:{parameter_name}")
            else:
                logging.warning(
                    f"skipping {model}:{parameter_name} due to size mismatch"
                )
                continue

        delta = x - base
        del x
        del tensors[model]  # release the original tensor as soon as possible

        d = {}
        d["model"] = model
        d["delta"] = delta
        for p in tensor_parameters[model]:
            d[p] = tensor_parameters[model][p]
        res.append(d)
    return res, base


def get_mask(
    delta: torch.Tensor,
    method: Literal["sum", "count"] = "sum",
    mask_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Returns a mask determining which delta vectors should be merged
    into the final model.

    For the methodology described in the TIES paper use 'sum'. For a
    simpler naive count of signs, use 'count'."""
    if mask_dtype is None:
        mask_dtype = delta.dtype

    sign = delta.sign().to(mask_dtype)

    if method == "sum":
        # elect the majority sign by total magnitude along the model dimension
        sign_weight = delta.sum(dim=0)
        majority_sign = (sign_weight >= 0).to(mask_dtype) * 2 - 1
        del sign_weight
    elif method == "count":
        # elect the majority sign by a simple vote of per-model signs
        majority_sign = (sign.sum(dim=0) >= 0).to(mask_dtype) * 2 - 1
    else:
        raise RuntimeError(f'Unimplemented mask method "{method}"')

    return sign == majority_sign
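

if __name__ == "__main__":
    # Minimal illustration (not part of the mergekit API): sign election on
    # synthetic deltas stacked along dim 0. Column sums are [3.0, -2.0], so the
    # elected majority signs are [+1, -1]; only agreeing entries stay True.
    example_deltas = torch.tensor([[1.0, -2.0], [3.0, 1.0], [-1.0, -1.0]])
    print(get_mask(example_deltas, method="sum"))
    # tensor([[ True,  True],
    #         [ True, False],
    #         [False,  True]])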