File size: 7,938 Bytes
82ea528 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
from typing import Union
import torch
from torch import Tensor
import math
class TensorInterp:
    """Namespace of string identifiers for the tensor interpolation modes ("lerp" and "slerp")."""

    LERP: str = "lerp"
    SLERP: str = "slerp"
    # every mode identifier above, in declaration order
    _LIST = [LERP, SLERP]
class SelectError(Exception):
    """Raised when an index or selection string cannot be parsed or falls outside the allowed range."""
def lerp_tensors(tensor_from: Tensor, tensor_to: Tensor, strength_to: Tensor):
    """Linearly blend two tensors (a basic weighted average, used to combine conds).

    Computes ``tensor_from * (1 - strength_to) + tensor_to * strength_to``, so a
    ``strength_to`` of 0 reproduces ``tensor_from`` and 1 reproduces ``tensor_to``.
    ``strength_to`` is broadcast against the inputs by ordinary tensor arithmetic.
    """
    # TODO: see how far we can generalize this, and if some params need to change
    weight_from = 1.0 - strength_to
    return tensor_from * weight_from + tensor_to * strength_to
# https://matilabs.ai/2024/03/05/slerp-model-merging-primer/#slerp-code
# https://medium.com/@akp83540/slerp-algorithm-a4ce1bacee4a
def slerp_tensors(tensor_from: Tensor, tensor_to: Tensor, strength_to: Tensor, dot_threshold=0.9995):
    """Spherically interpolate between two tensors, treating each as one flattened vector.

    The angle between the tensors is measured over all elements (``.norm()`` and
    ``.sum()`` are taken without a ``dim``). The blend itself is applied to the
    original, un-normalized tensors, so magnitudes are interpolated along with direction.

    Falls back to ``lerp_tensors`` when the angle is undefined or numerically unstable:
    - either tensor has zero norm (no direction; normalizing would divide by 0 and
      propagate NaN through the whole result);
    - ``|cos(angle)| > dot_threshold`` (nearly parallel or anti-parallel), where
      ``sin(omega)`` in the denominator approaches 0.

    Args:
        tensor_from: value returned at ``strength_to == 0``.
        tensor_to: value returned at ``strength_to == 1``.
        strength_to: interpolation weight toward ``tensor_to``.
        dot_threshold: cosine magnitude above which lerp is used instead.

    Returns:
        The interpolated tensor.
    """
    norm_from = tensor_from.norm()
    norm_to = tensor_to.norm()
    # zero-norm input: direction undefined, and dividing by 0 would yield NaNs
    if norm_from == 0 or norm_to == 0:
        return lerp_tensors(tensor_from=tensor_from, tensor_to=tensor_to, strength_to=strength_to)
    # normalize tensors
    normal_from = tensor_from / norm_from
    normal_to = tensor_to / norm_to
    # dot product of unit vectors = cosine of the angle between the tensors
    dot = (normal_from * normal_to).sum()
    # nearly (anti-)parallel: sin(omega) ~ 0 makes the formula unstable, so simplify to lerp
    if dot.abs() > dot_threshold:
        return lerp_tensors(tensor_from=tensor_from, tensor_to=tensor_to, strength_to=strength_to)
    # omega (Ω): angle between the tensors
    omega = dot.acos()
    # apply formula:
    # q(t) = (q₀ * sin((1 - t) * Ω)) / sin(Ω) + (q₁ * sin(t * Ω)) / sin(Ω)
    # simplified to (extract sin(Ω)):
    # q(t) = ((q₀ * sin((1 - t) * Ω)) + (q₁ * sin(t * Ω))) / sin(Ω)
    sin_from = ((1.0 - strength_to) * omega).sin()
    sin_to = (strength_to * omega).sin()
    return (tensor_from * sin_from + tensor_to * sin_to) / omega.sin()
def validate_index(raw_index: Union[str, int, float], length: int=0, is_range: bool=False, allow_negative=False, allow_missing=False, allow_decimal=False) -> int:
    """Convert ``raw_index`` to an int index and validate it against ``length``.

    Args:
        raw_index: index to convert.
            - A string containing ``'.'`` is a decimal fraction of ``length`` in
              [0.0, 1.0]. It requires ``allow_decimal`` and a non-zero ``length``.
              ``1.0`` maps to ``length - 1``; other values map to ``int(fraction * length)``.
            - Anything else is converted with ``int()``.
        length: number of available items; 0 means unknown, which disables the
            upper-bound check and the decimal form.
        is_range: the index is a range bound. Negative values are offset by ``length``
            and clamped at 0; no other validation is done (slicing clips the rest).
        allow_negative: permit negative indexes, resolved as ``length + index``.
        allow_missing: skip the errors for indexes outside ``[0, length)``.
        allow_decimal: permit the decimal-fraction form.

    Returns:
        The resolved int index.

    Raises:
        SelectError: if the index cannot be parsed or violates one of the rules above.
            The underlying conversion error, if any, is chained as ``__cause__``.
    """
    is_decimal = isinstance(raw_index, str) and '.' in raw_index
    if is_decimal:
        if not allow_decimal:
            raise SelectError(f"Index '{raw_index}' contains a decimal, but decimal inputs are not allowed.")
        if length == 0:
            raise SelectError(f"Decimal indexes are not allowed when no explicit length ({length}) is provided.")
        try:
            index_float = float(raw_index)
        except ValueError as e:
            # chain instead of passing e as a second arg, which made str(err) a tuple
            raise SelectError(f"Decimal index '{raw_index}' isn't a valid float. ") from e
        if index_float < 0.0 or index_float > 1.0:
            raise SelectError(f"Decimal index must be between 0.0 and 1.0, but was '{index_float}'.")
        # 1.0 would otherwise compute `length`, one past the last item
        # NOTE(review): this also applies to exclusive range end bounds, where 1.0 then excludes the last item — confirm intended
        if math.isclose(index_float, 1.0):
            index = length - 1
        else:
            index = int(index_float * length)
    else:
        try:
            index = int(raw_index)
        except (TypeError, ValueError, OverflowError) as e:
            # TypeError: e.g. None; OverflowError: e.g. float('inf') — report all as SelectError
            raise SelectError(f"Index '{raw_index}' must be an integer.") from e
    # range bounds: only resolve negatives (clamped at 0); slicing handles the rest
    if is_range:
        if index < 0:
            index = max(length + index, 0)
        return index
    # validate upper bound - only when the item count is known
    if length > 0 and index > length - 1 and not allow_missing:
        raise SelectError(f"Index '{index}' out of range for {length} item(s).")
    # resolve negative indexes relative to the end, validating the result
    if index < 0:
        if not allow_negative:
            raise SelectError(f"Negative indeces not allowed, but was '{index}'.")
        conv_index = length + index
        if conv_index < 0 and not allow_missing:
            raise SelectError(f"Index '{index}', converted to '{conv_index}' out of range for {length} item(s).")
        index = conv_index
    return index
def convert_to_index_int(raw_index: Union[str, int, float], length: int=0, is_range: bool=False, allow_negative=False, allow_missing=False, allow_decimal=False) -> int:
    """Convert ``raw_index`` to a validated int index; thin wrapper over ``validate_index``.

    Takes the same arguments and raises the same ``SelectError``s as ``validate_index``.
    ``raw_index`` is annotated to match because callers may pass an already-converted int
    (``convert_str_to_indexes`` re-validates a resolved range start this way).
    """
    return validate_index(raw_index, length=length, is_range=is_range, allow_negative=allow_negative, allow_missing=allow_missing, allow_decimal=allow_decimal)
def _parse_range_step(raw_step: str) -> int:
    """Parse the step field of a ``start:end:step`` group; an empty field means 1.

    The step is a plain signed integer, not an index. It must not be offset by the
    item count the way negative range bounds are; doing so turned ``::-1`` into a
    step of ``length - 1`` instead of a reversal.

    Raises:
        SelectError: if the step is not an integer, or is 0 (slicing rejects a zero step).
    """
    if not raw_step:
        return 1
    try:
        step = int(raw_step)
    except ValueError as e:
        raise SelectError(f"Step '{raw_step}' must be an integer.") from e
    if step == 0:
        raise SelectError("Step cannot be 0.")
    return step


def convert_str_to_indexes(indexes_str: str, length: int=0, allow_range=True, allow_missing=False, fix_reverse=False, same_is_one=False, allow_decimal=False) -> list[int]:
    """Parse a comma-separated selection string into a list of int indexes.

    Each comma-separated group is one of:
    - a single index (positive, negative, or decimal fraction; see ``convert_to_index_int``);
    - a range ``start:end[:step]``.

    A range works like a slice: ``end`` is exclusive, an empty ``start`` means 0, and an
    empty ``end`` means ``length``. The step is applied to the ``[start, end)`` window, so
    ``::-1`` reverses the window and ``::2`` takes every other item. The results of all
    groups are concatenated in order; duplicates are kept.

    Args:
        indexes_str: selection string, e.g. ``"0, 2:5, -1"``. Empty -> ``[]``.
        length: number of available items; 0 means unknown. Single negative indexes
            are only accepted when it is > 0.
        allow_range: when False, any group containing ``:`` raises.
        allow_missing: allow indexes outside ``[0, length)``. Ranges are then built from
            the bounds as given instead of being clipped to ``length``.
        fix_reverse: swap ``start`` and ``end`` when ``end < start``.
        same_is_one: treat a range whose start equals its end as that single index.
        allow_decimal: allow decimal fractions (e.g. ``"0.5"``) for indexes and range bounds.

    Returns:
        The selected indexes, in order.

    Raises:
        SelectError: on malformed groups, disallowed ranges/negatives/decimals,
            out-of-range indexes, or a non-integer/zero step.
    """
    if not indexes_str:
        return []
    # negative indexes can only be resolved against a known item count
    allow_negative = length > 0
    index_kwargs = dict(length=length, allow_negative=allow_negative, allow_missing=allow_missing, allow_decimal=allow_decimal)
    chosen_indexes: list[int] = []
    for group in (part.strip() for part in indexes_str.split(",")):
        # single index, e.g. "3", "-1", "0.5"
        if ':' not in group:
            chosen_indexes.append(convert_to_index_int(group, **index_kwargs))
            continue
        # range of indexes, e.g. "2:16" or "::-1"
        if not allow_range:
            raise SelectError("Ranges (:) not allowed for this input.")
        # maxsplit=2 -> start, end and step; any further ':' stays in the step and fails to parse
        fields = [field.strip() for field in group.split(":", 2)]
        start_index = convert_to_index_int(fields[0], is_range=True, **index_kwargs) if fields[0] else 0
        end_index = convert_to_index_int(fields[1], is_range=True, **index_kwargs) if fields[1] else length
        step = _parse_range_step(fields[2]) if len(fields) > 2 else 1
        # if supposed to treat same start and end as one entry, do so
        if same_is_one and start_index == end_index:
            chosen_indexes.append(convert_to_index_int(start_index, **index_kwargs))
            continue
        if fix_reverse and end_index < start_index:
            start_index, end_index = end_index, start_index
        if length > 0 and not allow_missing:
            # known item count: slice semantics clip the window to [0, length)
            chosen_indexes.extend(range(length)[start_index:end_index][::step])
        else:
            # unknown count or missing allowed: trust the bounds, same window-then-step semantics
            chosen_indexes.extend(range(start_index, end_index)[::step])
    return chosen_indexes
def select_indexes(input_obj: Union[Tensor, list], idxs: list):
    """Return the items of ``input_obj`` at positions ``idxs``, in order.

    A tensor, including subclasses such as ``torch.nn.Parameter``, is indexed with the
    whole index list at once and yields a tensor. Any other sequence yields a new list.
    """
    # isinstance, not type() ==: Tensor subclasses previously fell through to the
    # list path and came back as a list of slices instead of a tensor
    if isinstance(input_obj, Tensor):
        return input_obj[idxs]
    return [input_obj[i] for i in idxs]
def select_indexes_from_str(input_obj: Union[Tensor, list], indexes: str, allow_range=True, err_if_missing=True, err_if_empty=True):
    """Select items from ``input_obj`` using a selection string such as ``"0, 2:5, -1"``.

    Args:
        input_obj: tensor or list to select from; ``len(input_obj)`` bounds the indexes.
        indexes: selection string, parsed by ``convert_str_to_indexes``.
        allow_range: permit ``start:end[:step]`` groups.
        err_if_missing: raise on indexes outside the input. When False, such indexes
            are passed on to ``select_indexes`` and may raise there instead.
        err_if_empty: raise if the string selects nothing.

    Returns:
        The selection, as returned by ``select_indexes``.

    Raises:
        SelectError: if parsing fails, or if nothing is selected and ``err_if_empty``.
    """
    real_idxs = convert_str_to_indexes(indexes, len(input_obj), allow_range=allow_range, allow_missing=not err_if_missing)
    if err_if_empty and not real_idxs:
        # SelectError (an Exception subclass) matches the rest of the module while
        # still being caught by existing `except Exception` handlers
        raise SelectError(f"Nothing was selected based on indexes found in '{indexes}'.")
    return select_indexes(input_obj, real_idxs)
|