File size: 5,008 Bytes
a164e13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
class SlerpTask(Task[torch.Tensor]):
    """Graph task that merges one parameter of two models with ``slerp``."""

    # Upstream task whose result is the mapping passed to ``execute``.
    gather_tensors: GatherTensors
    # Model whose tensor is used as the interpolation start (``v0``).
    base_model: ModelReference
    # Interpolation factor forwarded unchanged to ``slerp``.
    t: float
    # Name of the weight being merged; forwarded to ``rectify_embed_sizes``.
    parameter_name: str

    def uses_accelerator(self) -> bool:
        """Always report ``True``."""
        return True

    def arguments(self) -> Dict[str, Task]:
        """Expose the gather task under the ``"tensors"`` argument name."""
        return {"tensors": self.gather_tensors}

    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
        """
        Interpolate between the base model's tensor and the other model's.

        Args:
            tensors: Mapping from model reference to that model's tensor.

        Returns:
            The lone tensor when only one model supplied it; otherwise the
            ``slerp`` result converted to the dtype and device of the base
            model's tensor.

        Raises:
            RuntimeError: If the mapping holds neither one nor two entries,
                or if it holds two and ``base_model`` is not among them.
        """
        model_count = len(tensors)
        if model_count == 1:
            # Nothing to interpolate with: hand back the only tensor.
            return next(iter(tensors.values()))
        if model_count != 2:
            raise RuntimeError("Slerp merge expects exactly two models")
        if self.base_model not in tensors:
            raise RuntimeError("Base model not in input tensors")

        first, second = tensors.items()
        if first[0] != self.base_model:
            # Make sure the base model's tensor ends up in slot 0.
            first, second = second, first

        pair = [first[1], second[1]]
        # NOTE(review): presumably resizes mismatched embedding tensors in
        # place inside ``pair`` -- confirm against mergekit.common.
        rectify_embed_sizes(self.parameter_name, pair)

        # Read the operands back from the list in case the call above
        # replaced its entries.
        start, end = pair[0], pair[1]
        merged = slerp(self.t, start, end)
        merged = merged.to(start.dtype)
        return merged.to(start.device)
class SlerpMerge(MergeMethod):
    """Merge method that creates a ``SlerpTask`` for each output weight."""

    def parameters(self) -> List[ConfigParameterDef]:
        """Declare the single, required configuration parameter ``t``."""
        interpolation_factor = ConfigParameterDef(name="t", required=True)
        return [interpolation_factor]

    def make_task(
        self,
        *,
        output_weight: WeightInfo,
        tensors: GatherTensors,
        parameters: ImmutableMap[str, Any],
        base_model: Optional[ModelReference],
        **_kwargs,
    ) -> Task:
        """
        Build the task that merges ``output_weight``.

        Args:
            output_weight: Weight being produced; only its ``name`` is used.
            tensors: Task that gathers the input tensors.
            parameters: Configuration values; ``parameters["t"]`` is read.
            base_model: Model used as the interpolation start.
            **_kwargs: Accepted and ignored.

        Returns:
            A ``SlerpTask`` configured from the arguments above.
        """
        weight_name = output_weight.name
        factor = parameters["t"]
        return SlerpTask(
            gather_tensors=tensors,
            base_model=base_model,
            t=factor,
            parameter_name=weight_name,
        )
def lerp(
    t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    """
    Linearly interpolate between ``v0`` and ``v1``.

    Args:
        t: Weight given to ``v1``; ``v0`` is weighted by ``1 - t``.
        v0: Starting value.
        v1: Final value.

    Returns:
        ``(1 - t) * v0 + t * v1``.
    """
    start_weight = 1 - t
    return start_weight * v0 + t * v1
def slerp(
    t: Union[float, np.ndarray],
    v0: Union[np.ndarray, torch.Tensor],
    v1: Union[np.ndarray, torch.Tensor],
    DOT_THRESHOLD: float = 0.9995,
    eps: float = 1e-8,
):
    """
    Spherical linear interpolation between two arrays.

    From: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c

    Args:
        t (float/np.ndarray): Float value between 0.0 and 1.0
        v0 (np.ndarray/torch.Tensor): Starting vector
        v1 (np.ndarray/torch.Tensor): Final vector
        DOT_THRESHOLD (float): Threshold on the absolute dot product of the
            normalized inputs above which they are treated as colinear and
            plain linear interpolation is used. Not recommended to alter this.
        eps (float): Inputs whose norm does not exceed this value are left
            un-normalized when the angle is computed.

    Returns:
        Interpolation between v0 and v1. A CPU torch.Tensor if either input
        was not an np.ndarray, otherwise an np.ndarray.
    """
    # Anything that is not already a NumPy array is treated as a torch
    # tensor and converted to a float32 array on the CPU.
    is_torch = False
    operands = []
    for operand in (v0, v1):
        if not isinstance(operand, np.ndarray):
            is_torch = True
            operand = operand.detach().cpu().float().numpy()
        operands.append(operand)

    # The un-normalized values are what actually gets blended.
    start = np.copy(operands[0])
    end = np.copy(operands[1])

    # Unit-length versions are used only to measure the angle.
    unit_start = normalize(operands[0], eps)
    unit_end = normalize(operands[1], eps)

    # Sum of the elementwise product: a dot product over every entry that
    # works for any array rank (np.dot would do a matrix product on 2-D).
    cos_theta = np.sum(unit_start * unit_end)

    # Nearly (anti-)parallel inputs: fall back to linear interpolation.
    if np.abs(cos_theta) > DOT_THRESHOLD:
        return maybe_torch(lerp(t, start, end), is_torch)

    # Angle between the inputs, and the portion of it covered at t.
    theta_0 = np.arccos(cos_theta)
    sin_theta_0 = np.sin(theta_0)
    theta_t = theta_0 * t

    # Slerp weights for the start and end values.
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = np.sin(theta_t) / sin_theta_0

    return maybe_torch(s0 * start + s1 * end, is_torch)
def maybe_torch(v: np.ndarray, is_torch: bool):
    """
    Optionally wrap a NumPy array as a torch tensor.

    Args:
        v: Array to return or wrap.
        is_torch: When true, return ``torch.from_numpy(v)``.

    Returns:
        ``v`` itself when ``is_torch`` is false, otherwise a tensor created
        from it with ``torch.from_numpy``.
    """
    return torch.from_numpy(v) if is_torch else v
def normalize(v: np.ndarray, eps: float):
    """
    Scale ``v`` to unit norm unless its norm is too small.

    Args:
        v: Array to normalize.
        eps: The division is performed only when the norm is strictly
            greater than this value.

    Returns:
        ``v`` divided by ``np.linalg.norm(v)``, or ``v`` itself (the same
        object) when the norm does not exceed ``eps``.
    """
    magnitude = np.linalg.norm(v)
    return v / magnitude if magnitude > eps else v
|