# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from typing import Any, Dict, List, Optional

import torch
from pydantic import BaseModel

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.merge_methods.slerp import slerp
from mergekit.tokenizer import BuildTokenizer, TokenizerInfo
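

# Merges embedding-like tensors (token embeddings, LM head) across models
# whose vocabularies differ: each model's rows are first permuted into the
# unified vocabulary produced by BuildTokenizer, then combined by a masked
# weighted average, or by SLERP when exactly two models are given.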
class TokenizerPermutationMergeTask(Task[torch.Tensor]):
    tokenizer_task: BuildTokenizer
    gather_tensors: GatherTensors
    base_model: Optional[ModelReference]
    use_slerp: bool
    slerp_t: Optional[float]
    tensor_parameters: ImmutableMap[ModelReference, Any]

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tokenizer_info": self.tokenizer_task, "tensors": self.gather_tensors}

    def execute(
        self, tokenizer_info: TokenizerInfo, tensors: Dict[ModelReference, torch.Tensor]
    ) -> torch.Tensor:
        if not tensors:
            return None
        # Only one model contributes this tensor; nothing to merge.
        if len(tensors) == 1:
            return list(tensors.values())[0]

        if self.use_slerp and self.slerp_t is None:
            raise RuntimeError("Must set t to use embed_slerp")
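
        # Expand each model's tensor into the unified vocabulary: row out_idx
        # of `xp` is copied from the donor row p[out_idx], and `mask` records
        # which rows this model actually supplies (a negative index means the
        # token does not exist in the model's original vocabulary).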
        models = []
        expanded = []
        masks = []
        weights = []
        for model in tensors:
            models.append(model)

            x = tensors[model]
            p = tokenizer_info.permutations[model]

            xp = torch.zeros((len(p), x.shape[-1]), dtype=x.dtype, device=x.device)
            mask = torch.zeros((len(p),), dtype=torch.bool, device=x.device)
            for out_idx in p:
                in_idx = p[out_idx]
                if in_idx < 0:
                    continue

                xp[out_idx, :] = x[in_idx, :]
                mask[out_idx] = 1

            expanded.append(xp)
            masks.append(mask)

            is_base = model == self.base_model
            if self.use_slerp:
                weight = (1.0 - self.slerp_t) if is_base else self.slerp_t
            else:
                weight = self.tensor_parameters[model]["weight"]

            weights.append(weight)
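
        # Combine as a masked weighted average over models; rows that no
        # model covers get a scale of zero instead of a division by (near-)zero.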
        expanded = torch.stack(expanded, dim=0)
        masks = torch.stack(masks, dim=0).unsqueeze(-1)
        weights = (
            torch.tensor(weights, dtype=expanded.dtype, device=expanded.device)
            .unsqueeze(-1)
            .unsqueeze(-1)
        )

        total_weight = (masks * weights).sum(dim=0)
        scale = 1 / total_weight
        scale[total_weight.abs() < 1e-8] = 0

        linear_merged = (expanded * weights * masks).sum(dim=0) * scale
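
        # SLERP is only defined between two endpoints; rows that both models
        # supply are spherically interpolated, and any row missing from one
        # model falls back to the linear result computed above.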
        if self.use_slerp:
            if expanded.shape[0] != 2:
                raise RuntimeError("SLERP takes exactly two models")

            if models[0] == self.base_model:
                v0 = expanded[0, ...]
                v1 = expanded[1, ...]
            else:
                v0 = expanded[1, ...]
                v1 = expanded[0, ...]

            res = slerp(self.slerp_t, v0, v1)

            need_linear = (masks.sum(dim=0) != 2).squeeze(dim=-1)
            res[need_linear, :] = linear_merged[need_linear, :].to(
                device=res.device, dtype=res.dtype
            )
            return res

        return linear_merged
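

# MergeMethod wrapper that wires the task above into mergekit's config
# system: `t` and `embed_slerp` are global parameters, and `weight` is a
# per-model parameter used by the linear path.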
class TokenizerPermutationMerge(MergeMethod, BaseModel):
    tokenizer_task: BuildTokenizer

    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="t", required=False),
            ConfigParameterDef(name="embed_slerp", required=False, default_value=False),
        ]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="weight", required=False),
        ]

    def make_task(
        self,
        *,
        tensors: GatherTensors,
        parameters: Dict[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        base_model: Optional[ModelReference],
        **_kwargs,
    ) -> Task:
        return TokenizerPermutationMergeTask(
            base_model=base_model,
            tokenizer_task=self.tokenizer_task,
            gather_tensors=tensors,
            use_slerp=parameters["embed_slerp"],
            slerp_t=parameters["t"],
            tensor_parameters=tensor_parameters,
        )
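

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of mergekit): how the masked weighted average
# in `execute` behaves on toy data. All shapes and values here are invented.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Two models, a unified vocabulary of 3 tokens, hidden size 2.
    # Model A supplies tokens 0 and 1; model B supplies tokens 0 and 2.
    xp_a = torch.tensor([[1.0, 1.0], [2.0, 2.0], [0.0, 0.0]])
    xp_b = torch.tensor([[3.0, 3.0], [0.0, 0.0], [4.0, 4.0]])
    mask_a = torch.tensor([True, True, False])
    mask_b = torch.tensor([True, False, True])

    demo_expanded = torch.stack([xp_a, xp_b], dim=0)                     # (2, 3, 2)
    demo_masks = torch.stack([mask_a, mask_b], dim=0).unsqueeze(-1)      # (2, 3, 1)
    demo_weights = torch.tensor([0.5, 0.5]).unsqueeze(-1).unsqueeze(-1)  # (2, 1, 1)

    total_weight = (demo_masks * demo_weights).sum(dim=0)
    scale = 1 / total_weight
    scale[total_weight.abs() < 1e-8] = 0
    merged = (demo_expanded * demo_weights * demo_masks).sum(dim=0) * scale

    # Token 0 averages both models; tokens 1 and 2 each fall back to the
    # single model that defines them (the mask renormalizes the weights).
    print(merged)  # tensor([[2., 2.], [2., 2.], [4., 4.]])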