|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from mergekit.architecture import WeightInfo |
|
from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes |
|
from mergekit.graph import Task |
|
from mergekit.io.tasks import GatherTensors |
|
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod |
|
|
|
|
|
class SlerpTask(Task[torch.Tensor]):
    """Task that merges one weight from two models with ``slerp``.

    The tensor belonging to ``base_model`` is always passed to ``slerp`` as
    the starting vector; the other model's tensor is the end vector.
    """

    # Task whose result is the mapping of model reference -> tensor.
    gather_tensors: GatherTensors
    # Model whose tensor is used as the interpolation start point.
    base_model: ModelReference
    # Interpolation factor forwarded to ``slerp``.
    t: float
    # Name of the weight being merged; forwarded to ``rectify_embed_sizes``.
    parameter_name: str

    def uses_accelerator(self) -> bool:
        """Always returns True."""
        return True

    def arguments(self) -> Dict[str, Task]:
        """Declare the gather task as the single input, keyed ``"tensors"``."""
        return {"tensors": self.gather_tensors}

    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
        """Interpolate between the base model's tensor and the other tensor.

        Args:
            tensors: Mapping from model reference to that model's tensor.

        Returns:
            The single input tensor unchanged when only one is supplied;
            otherwise the slerp result, converted to the dtype and device of
            the base model's tensor.

        Raises:
            RuntimeError: If the number of tensors is neither one nor two, or
                if ``base_model`` is not among the keys.
        """
        count = len(tensors)
        if count == 1:
            # Nothing to interpolate against; pass the lone tensor through.
            return next(iter(tensors.values()))
        if count != 2:
            raise RuntimeError("Slerp merge expects exactly two models")
        if self.base_model not in tensors:
            raise RuntimeError("Base model not in input tensors")

        (first_ref, first), (_, second) = tensors.items()
        if first_ref != self.base_model:
            # Put the base model's tensor in the first slot.
            first, second = second, first
        pair = [first, second]

        # NOTE(review): presumably adjusts the list entries in place for
        # embedding weights of differing size -- confirm against
        # mergekit.common.rectify_embed_sizes.
        rectify_embed_sizes(self.parameter_name, pair)

        start = pair[0]
        end = pair[1]
        merged = slerp(self.t, start, end)
        # Restore the base tensor's dtype first, then its device.
        in_dtype = merged.to(start.dtype)
        return in_dtype.to(start.device)
|
|
|
|
|
class SlerpMerge(MergeMethod):
    """Merge method that creates a ``SlerpTask`` for each output weight."""

    def parameters(self) -> List[ConfigParameterDef]:
        """Declare the one required configuration parameter, ``t``."""
        t_param = ConfigParameterDef(name="t", required=True)
        return [t_param]

    def make_task(
        self,
        *,
        output_weight: WeightInfo,
        tensors: GatherTensors,
        parameters: ImmutableMap[str, Any],
        base_model: Optional[ModelReference],
        **_kwargs,
    ) -> Task:
        """Build the ``SlerpTask`` for ``output_weight``.

        Args:
            output_weight: Weight being produced; its ``name`` becomes the
                task's ``parameter_name``.
            tensors: Gather task supplying the input tensors.
            parameters: Configuration values; ``parameters["t"]`` is used as
                the interpolation factor.
            base_model: Model whose tensor is the interpolation start point.
            **_kwargs: Accepted and ignored.

        Returns:
            The configured ``SlerpTask``.
        """
        weight_name = output_weight.name
        factor = parameters["t"]
        return SlerpTask(
            gather_tensors=tensors,
            base_model=base_model,
            parameter_name=weight_name,
            t=factor,
        )
|
|
|
|
|
def lerp(
    t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    """Linearly interpolate between ``v0`` and ``v1``.

    Args:
        t: Interpolation factor; 0 yields ``v0`` and 1 yields ``v1``.
        v0: Starting value.
        v1: Final value.

    Returns:
        ``(1 - t) * v0 + t * v1``.
    """
    start_term = (1 - t) * v0
    end_term = t * v1
    return start_term + end_term
|
|
|
|
|
def slerp(
    t: Union[float, np.ndarray],
    v0: Union[np.ndarray, torch.Tensor],
    v1: Union[np.ndarray, torch.Tensor],
    DOT_THRESHOLD: float = 0.9995,
    eps: float = 1e-8,
) -> Union[np.ndarray, torch.Tensor]:
    """
    Spherical linear interpolation

    From: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c

    Args:
        t (float/np.ndarray): Float value between 0.0 and 1.0
        v0 (np.ndarray/torch.Tensor): Starting vector
        v1 (np.ndarray/torch.Tensor): Final vector
        DOT_THRESHOLD (float): Threshold for considering the two vectors as
            colinear. Not recommended to alter this.
        eps (float): Norm at or below which an input is left unnormalized
            when measuring the angle between the vectors.
    Returns:
        v2 (np.ndarray/torch.Tensor): Interpolation vector between v0 and v1.
            A torch tensor (built from a float32 CPU array) is returned if
            either input was not an np.ndarray; otherwise an np.ndarray.
    """
    is_torch = False
    if not isinstance(v0, np.ndarray):
        is_torch = True
        v0 = v0.detach().cpu().float().numpy()
    if not isinstance(v1, np.ndarray):
        is_torch = True
        v1 = v1.detach().cpu().float().numpy()

    # The unit-length versions are only used to measure the angle; the
    # interpolation itself runs on the unnormalized inputs. Nothing below
    # writes to v0/v1 in place, so no defensive copies of them are required.
    v0_unit = normalize(v0, eps)
    v1_unit = normalize(v1, eps)

    # Cosine of the angle between the two (flattened) vectors.
    dot = np.sum(v0_unit * v1_unit)

    # Nearly colinear vectors: fall back to plain linear interpolation, which
    # also avoids dividing by a sine that is close to zero.
    if np.abs(dot) > DOT_THRESHOLD:
        return maybe_torch(lerp(t, v0, v1), is_torch)

    # Angle between the input vectors.
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)

    # Angle between v0 and the interpolated result.
    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)

    # Weights applied to the original (unnormalized) vectors.
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    res = s0 * v0 + s1 * v1

    return maybe_torch(res, is_torch)
|
|
|
|
|
def maybe_torch(v: np.ndarray, is_torch: bool):
    """Return ``v`` wrapped by ``torch.from_numpy`` if ``is_torch``, else ``v`` itself."""
    return torch.from_numpy(v) if is_torch else v
|
|
|
|
|
def normalize(v: np.ndarray, eps: float):
    """Divide ``v`` by its norm when that norm exceeds ``eps``.

    The norm is ``np.linalg.norm(v)`` with default arguments. If it does not
    exceed ``eps``, ``v`` is returned as-is (same object).
    """
    magnitude = np.linalg.norm(v)
    return v / magnitude if magnitude > eps else v
|
|