# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from typing import Any, Dict, List, Optional

import torch
from pydantic import BaseModel

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.merge_methods.slerp import slerp
from mergekit.tokenizer import BuildTokenizer, TokenizerInfo


class TokenizerPermutationMergeTask(Task[torch.Tensor]):
    tokenizer_task: BuildTokenizer
    gather_tensors: GatherTensors
    base_model: Optional[ModelReference]
    use_slerp: bool
    slerp_t: Optional[float]
    tensor_parameters: ImmutableMap[ModelReference, Any]

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tokenizer_info": self.tokenizer_task, "tensors": self.gather_tensors}

    def execute(
        self,
        tokenizer_info: TokenizerInfo,
        tensors: Dict[ModelReference, torch.Tensor],
    ) -> Optional[torch.Tensor]:
        if not tensors:
            return None
        if len(tensors) == 1:
            return list(tensors.values())[0]

        if self.use_slerp and self.slerp_t is None:
            raise RuntimeError("Must set t to use embed_slerp")

        models = []
        expanded = []
        masks = []
        weights = []
        for model in tensors:
            models.append(model)
            x = tensors[model]
            p = tokenizer_info.permutations[model]

            # Remap each model's rows into the combined vocabulary. An input
            # index of -1 marks a token the model lacks; its row stays zero
            # and is excluded from the merge via the mask.
            xp = torch.zeros((len(p), x.shape[-1]), dtype=x.dtype, device=x.device)
            mask = torch.zeros((len(p),), dtype=torch.bool, device=x.device)
            for out_idx in p:
                in_idx = p[out_idx]
                if in_idx < 0:
                    continue

                xp[out_idx, :] = x[in_idx, :]
                mask[out_idx] = 1

            expanded.append(xp)
            masks.append(mask)

            is_base = model == self.base_model
            if self.use_slerp:
                weight = (1.0 - self.slerp_t) if is_base else self.slerp_t
            else:
                weight = self.tensor_parameters[model]["weight"]

            weights.append(weight)

        expanded = torch.stack(expanded, dim=0)
        masks = torch.stack(masks, dim=0).unsqueeze(-1)
        weights = (
            torch.tensor(weights, dtype=expanded.dtype, device=expanded.device)
            .unsqueeze(-1)
            .unsqueeze(-1)
        )

        # Weighted average over the models that actually have each token,
        # renormalized by the total weight present for that row. Rows no
        # model covers get a scale of zero instead of dividing by zero.
        total_weight = (masks * weights).sum(dim=0)
        scale = 1 / total_weight
        scale[total_weight.abs() < 1e-8] = 0

        linear_merged = (expanded * weights * masks).sum(dim=0) * scale

        if self.use_slerp:
            if expanded.shape[0] != 2:
                raise RuntimeError("SLERP takes exactly two models")

            # Orient the interpolation so the base model is always v0.
            if models[0] == self.base_model:
                v0 = expanded[0, ...]
                v1 = expanded[1, ...]
            else:
                v0 = expanded[1, ...]
                v1 = expanded[0, ...]
            res = slerp(self.slerp_t, v0, v1)

            # Rows not present in both models can't be slerped; fall back to
            # the masked linear merge for those tokens.
            need_linear = (masks.sum(dim=0) != 2).squeeze(dim=-1)
            res[need_linear, :] = linear_merged[need_linear, :].to(
                device=res.device, dtype=res.dtype
            )
            return res

        return linear_merged


class TokenizerPermutationMerge(MergeMethod, BaseModel):
    tokenizer_task: BuildTokenizer

    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="t", required=False),
            ConfigParameterDef(name="embed_slerp", required=False, default_value=False),
        ]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="weight", required=False),
        ]

    def make_task(
        self,
        *,
        tensors: GatherTensors,
        parameters: Dict[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        base_model: Optional[ModelReference],
        **_kwargs,
    ) -> Task:
        return TokenizerPermutationMergeTask(
            base_model=base_model,
            tokenizer_task=self.tokenizer_task,
            gather_tensors=tensors,
            use_slerp=parameters["embed_slerp"],
            slerp_t=parameters["t"],
            tensor_parameters=tensor_parameters,
        )
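

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of mergekit's API: the toy example below
# mirrors the masked weighted-average math in
# TokenizerPermutationMergeTask.execute using plain tensors. The vocabulary
# sizes, permutations, and weights here are made up for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Two hypothetical "models" with 3-token vocabularies and 2-dim embeddings.
    x_a = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    x_b = torch.tensor([[2.0, 2.0], [0.0, 2.0], [2.0, 0.0]])

    # Hypothetical permutations into a 4-token combined vocabulary, in the
    # same shape as tokenizer_info.permutations: combined index -> model
    # index, with -1 marking a token the model does not have.
    perm_a = {0: 0, 1: 1, 2: 2, 3: -1}
    perm_b = {0: 1, 1: 0, 2: -1, 3: 2}

    expanded, masks = [], []
    for x, p in ((x_a, perm_a), (x_b, perm_b)):
        xp = torch.zeros((len(p), x.shape[-1]), dtype=x.dtype)
        mask = torch.zeros((len(p),), dtype=torch.bool)
        for out_idx, in_idx in p.items():
            if in_idx >= 0:
                xp[out_idx] = x[in_idx]
                mask[out_idx] = True
        expanded.append(xp)
        masks.append(mask)

    expanded = torch.stack(expanded)                   # (models, vocab, dim)
    masks = torch.stack(masks).unsqueeze(-1)           # (models, vocab, 1)
    weights = torch.tensor([0.5, 0.5]).view(-1, 1, 1)  # equal linear weights

    # Same renormalized average as execute(): tokens covered by both models
    # are averaged; tokens covered by one model keep that model's embedding.
    total_weight = (masks * weights).sum(dim=0)
    scale = 1 / total_weight
    scale[total_weight.abs() < 1e-8] = 0
    merged = (expanded * weights * masks).sum(dim=0) * scale
    print(merged)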