# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from typing import Any, Dict, List, Optional

import torch

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod


class ModelStockMergeTask(Task[torch.Tensor]):
    gather_tensors: GatherTensors
    base_model: ModelReference
    parameter_name: str
    filter_wise: bool = False

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.gather_tensors}

    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
        if len(tensors) == 1 and self.base_model in tensors:
            return tensors[self.base_model]
        if len(tensors) < 3:
            raise ValueError(
                "ModelStockMerge requires at least 3 models (base plus two+ others)"
            )

        w_0, ws = self.get_rectified_weights(tensors)
        out_shape = w_0.shape

        if self.filter_wise:
            if w_0.dim() == 1:
                # bias (or other single-vector) parameters should be treated as row vectors
                w_0 = w_0.unsqueeze(0)
                ws = [w.unsqueeze(0) for w in ws]
        else:
            w_0 = w_0.view(-1)
            ws = [w.view(-1) for w in ws]

        offsets = [w - w_0 for w in ws]

        # now there is a question of how to come up with a value for theta.
        # in the two-vector case, we can get an exact angle between the two vectors
        # but the paper doesn't explicitly say what to do in the multi-vector case -
        # they keep using a singular theta value and don't elaborate on how to
        # calculate it. i'm going to assume an average of pairwise angles for now? i guess?
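        #
        # for reference, the quantity computed below is the paper's interpolation
        # ratio:
        #   t = N * cos(theta) / (1 + (N - 1) * cos(theta))
        # and the merged weight is w_h = t * w_avg + (1 - t) * w_0, where w_avg is
        # the simple average of the N fine-tuned weights and w_0 is the base model.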
        cos_thetas = []
        for i, w_0_offset in enumerate(offsets):
            for j in range(i + 1, len(offsets)):
                w_1_offset = offsets[j]

                norm_product = torch.norm(w_0_offset, dim=-1) * torch.norm(
                    w_1_offset, dim=-1
                )
                cos_theta = (
                    (w_0_offset * w_1_offset).sum(dim=-1)
                    / norm_product.clamp(min=1e-6)
                ).clamp(-1, 1)
                cos_thetas.append(cos_theta)

        cos_theta = torch.stack(cos_thetas).mean(dim=0).unsqueeze(-1)
        N = len(ws)
        t = (N * cos_theta) / (1 + (N - 1) * cos_theta)

        w_avg = sum(ws) / len(ws)
        w_h = t * w_avg + (1 - t) * w_0

        return w_h.reshape(out_shape)

    def get_rectified_weights(self, tensors: Dict[ModelReference, torch.Tensor]):
        if self.base_model not in tensors:
            raise ValueError("Base model tensor not found")

        all_weights = [tensors[self.base_model]] + [
            tensors[k] for k in tensors if k != self.base_model
        ]
        rectify_embed_sizes(self.parameter_name, all_weights)
        w_0 = all_weights[0]
        ws = all_weights[1:]
        return w_0, ws


class ModelStockMerge(MergeMethod):
    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="filter_wise", required=False, default_value=False)
        ]

    def make_task(
        self,
        *,
        output_weight: WeightInfo,
        tensors: GatherTensors,
        base_model: Optional[ModelReference],
        parameters: ImmutableMap[str, Any],
        **_kwargs,
    ) -> Task:
        return ModelStockMergeTask(
            gather_tensors=tensors,
            base_model=base_model,
            parameter_name=output_weight.name,
            filter_wise=parameters["filter_wise"],
        )
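

# Illustrative usage (a sketch, not part of this module): in a mergekit YAML
# config this method would be selected by its registered name, assumed here to
# be "model_stock", with ``filter_wise`` passed as a global parameter:
#
#   merge_method: model_stock
#   base_model: path/to/base-model
#   models:
#     - model: path/to/finetune-a
#     - model: path/to/finetune-b
#   parameters:
#     filter_wise: false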