# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from typing import Any, Dict, List

import torch

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod


class LinearMergeTask(Task[torch.Tensor]):
    gather_tensors: GatherTensors
    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
    normalize: bool
    parameter_name: str

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.gather_tensors}

    def execute(
        self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs
    ) -> torch.Tensor:
        # Collect the input tensors and their per-model merge weights in a
        # consistent order.
        keys = list(tensors.keys())
        tensors = [tensors[key] for key in keys]
        weights = [self.tensor_parameters[key]["weight"] for key in keys]

        # Pad embedding/output matrices that differ only in vocabulary size.
        rectify_embed_sizes(self.parameter_name, tensors)

        unique_shapes = set(t.shape for t in tensors)
        if len(unique_shapes) != 1:
            raise RuntimeError(
                f"Tensor size mismatch for {self.parameter_name}, sizes: {list(unique_shapes)}"
            )

        # Stack the tensors along a new model dimension and broadcast the
        # weights so each model's tensor is scaled by its weight.
        tensors = torch.stack(tensors, dim=0)
        weights = torch.tensor(weights, dtype=tensors.dtype, device=tensors.device)
        while len(weights.shape) < len(tensors.shape):
            weights.unsqueeze_(-1)

        res = (weights * tensors).sum(dim=0)
        if self.normalize:
            # Divide by the total weight so the result is a weighted average.
            res /= weights.sum(dim=0)

        return res


class LinearMerge(MergeMethod):
    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="normalize", required=False, default_value=True),
        ]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        return [ConfigParameterDef(name="weight", required=True)]

    def make_task(
        self,
        *,
        output_weight: WeightInfo,
        tensors: GatherTensors,
        parameters: Dict[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        **_kwargs,
    ) -> Task:
        return LinearMergeTask(
            gather_tensors=tensors,
            tensor_parameters=tensor_parameters,
            normalize=parameters["normalize"],
            parameter_name=output_weight.name,
        )
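

# --- Illustration only; not part of the original module ---
# A minimal sketch of the weighted-average math performed in
# LinearMergeTask.execute, run on made-up toy tensors instead of gathered
# model weights (the tensor values and merge weights below are assumptions
# for demonstration, not mergekit defaults).
if __name__ == "__main__":
    toy_tensors = [torch.ones(2, 3), torch.full((2, 3), 3.0)]
    toy_weights = [0.25, 0.75]

    stacked = torch.stack(toy_tensors, dim=0)           # shape: (2, 2, 3)
    w = torch.tensor(toy_weights, dtype=stacked.dtype)  # shape: (2,)
    while w.dim() < stacked.dim():
        w = w.unsqueeze(-1)                             # shape: (2, 1, 1)

    # Normalized weighted sum, mirroring the normalize=True branch above.
    merged = (w * stacked).sum(dim=0) / w.sum(dim=0)
    print(merged)  # every element is 0.25 * 1.0 + 0.75 * 3.0 = 2.5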