from typing import Union
import torch
from torch import Tensor
import math


class TensorInterp:
    """Names of the supported tensor interpolation modes."""
    LERP = "lerp"
    SLERP = "slerp"
    _LIST = [LERP, SLERP]


class SelectError(Exception):
    """Raised when an index, or a string of indexes, cannot be parsed or is out of range."""
    pass


def lerp_tensors(tensor_from: Tensor, tensor_to: Tensor, strength_to: Tensor):
    """Linearly interpolate between two tensors.

    Returns ``tensor_from * (1 - strength_to) + tensor_to * strength_to``.
    """
    # basic weighted average to combine conds
    # TODO: see how far we can generalize this, and if some params need to change
    return torch.mul(tensor_from, (1.0 - strength_to)) + torch.mul(tensor_to, strength_to)


# https://matilabs.ai/2024/03/05/slerp-model-merging-primer/#slerp-code
# https://medium.com/@akp83540/slerp-algorithm-a4ce1bacee4a
def slerp_tensors(tensor_from: Tensor, tensor_to: Tensor, strength_to: Tensor, dot_threshold=0.9995):
    """Spherically interpolate between two tensors, each treated as a single vector.

    Falls back to ``lerp_tensors`` when the angle between the tensors cannot be
    used: either tensor has zero norm, or the absolute value of the cosine of
    the angle exceeds ``dot_threshold`` (nearly parallel or anti-parallel).
    """
    norm_from = tensor_from.norm()
    norm_to = tensor_to.norm()
    # a zero-norm tensor has no direction; normalizing it would divide by zero
    # and turn the whole result into NaN, so use plain lerp instead
    if norm_from == 0 or norm_to == 0:
        return lerp_tensors(tensor_from=tensor_from, tensor_to=tensor_to, strength_to=strength_to)
    # normalize tensors
    normal_from = tensor_from / norm_from
    normal_to = tensor_to / norm_to
    # dot product of the normalized tensors is the cosine of the angle between them
    dot = (normal_from * normal_to).sum()
    # if tensors (vectors) nearly parallel (|dot product| ~ 1.0), simplify to lerp
    if dot.abs() > dot_threshold:
        return lerp_tensors(tensor_from=tensor_from, tensor_to=tensor_to, strength_to=strength_to)
    # omega (Ω) is the angle between the two vectors
    omega = dot.acos()
    # apply formula:
    # q(t) = (q₀ * sin((1 - t) * Ω)) / sin(Ω) + (q₁ * sin(t * Ω)) / sin(Ω)
    # simplified to (extract sin(Ω)):
    # q(t) = ((q₀ * sin((1 - t) * Ω)) + (q₁ * sin(t * Ω))) / sin(Ω)
    sin_from = ((1.0 - strength_to) * omega).sin()
    sin_to = (strength_to * omega).sin()
    return (tensor_from * sin_from + tensor_to * sin_to) / omega.sin()


def validate_index(raw_index: Union[str, int, float], length: int = 0, is_range: bool = False,
                   allow_negative=False, allow_missing=False, allow_decimal=False) -> int:
    """Convert a raw index to an int and validate it against ``length``.

    A string containing ``'.'`` is treated as a decimal fraction in
    ``[0.0, 1.0]`` of ``length``; anything else is passed through ``int()``
    (so a non-string float is truncated, not treated as a fraction).

    Args:
        raw_index: value to convert.
        length: number of items being indexed; ``0`` means unknown.
        is_range: the index is one end of a range. Negative values are
            converted by adding ``length`` and clamped to 0; no out-of-range
            validation is performed.
        allow_negative: permit negative indexes (converted by adding ``length``).
        allow_missing: do not raise for indexes outside ``length``.
        allow_decimal: permit decimal (fractional) string indexes.

    Raises:
        SelectError: the value cannot be parsed, or fails validation.
    """
    is_decimal = isinstance(raw_index, str) and '.' in raw_index
    if is_decimal:
        if not allow_decimal:
            raise SelectError(f"Index '{raw_index}' contains a decimal, but decimal inputs are not allowed.")
        if length == 0:
            raise SelectError(f"Decimal indexes are not allowed when no explicit length ({length}) is provided.")
        try:
            index_float = float(raw_index)
        except ValueError as e:
            # chain the cause instead of stuffing it into the exception args
            raise SelectError(f"Decimal index '{raw_index}' isn't a valid float. ") from e
        if index_float < 0.0 or index_float > 1.0:
            raise SelectError(f"Decimal index must be between 0.0 and 1.0, but was '{index_float}'.")
        # 1.0 maps to the last item rather than one past the end
        if math.isclose(index_float, 1.0):
            index = length - 1
        else:
            index = int(index_float * length)
    else:
        try:
            index = int(raw_index)
        except ValueError as e:
            raise SelectError(f"Index '{raw_index}' must be an integer.") from e
    # if part of range, only convert negatives (clamped to 0); skip validation
    if is_range:
        if index < 0:
            index = max(length + index, 0)
        return index
    # otherwise, validate index
    # validate not out of range - only when a length is passed in
    if length > 0 and index > length - 1 and not allow_missing:
        raise SelectError(f"Index '{index}' out of range for {length} item(s).")
    # if negative, validate not out of range
    if index < 0:
        if not allow_negative:
            raise SelectError(f"Negative indeces not allowed, but was '{index}'.")
        conv_index = length + index
        if conv_index < 0 and not allow_missing:
            raise SelectError(f"Index '{index}', converted to '{conv_index}' out of range for {length} item(s).")
        index = conv_index
    return index


def convert_to_index_int(raw_index: Union[str, int, float], length: int = 0, is_range: bool = False,
                         allow_negative=False, allow_missing=False, allow_decimal=False) -> int:
    """Thin wrapper around ``validate_index`` with the same arguments and result."""
    return validate_index(raw_index, length=length, is_range=is_range, allow_negative=allow_negative,
                          allow_missing=allow_missing, allow_decimal=allow_decimal)


def _parse_step(raw_step: str) -> int:
    """Parse the step part of a range into a non-zero, possibly negative, int."""
    # length=0 with allow_negative/allow_missing leaves the sign untouched;
    # a step is a stride, not a position, so it must never be offset by the length
    step = validate_index(raw_step, length=0, allow_negative=True, allow_missing=True)
    if step == 0:
        raise SelectError("Range step cannot be zero.")
    return step


def convert_str_to_indexes(indexes_str: str, length: int = 0, allow_range=True, allow_missing=False,
                           fix_reverse=False, same_is_one=False, allow_decimal=False) -> list[int]:
    """Parse a comma-separated string of indexes and ranges into a list of ints.

    Each group is either a single index (``"3"``, ``"-1"``) or a range
    ``start:end[:step]`` with optional parts. A range first selects
    ``start`` (inclusive) to ``end`` (exclusive) and then applies ``step`` to
    that selection, so a negative step reverses the selected span.

    Args:
        indexes_str: the string to parse; empty/None yields ``[]``.
        length: number of items being indexed; ``0`` means unknown. Negative
            single indexes are only allowed when ``length > 0``.
        allow_range: permit ``:`` ranges.
        allow_missing: do not restrict results to ``0..length-1``.
        fix_reverse: swap start and end of a range when end < start.
        same_is_one: treat a range whose start equals its end as that single index.
        allow_decimal: permit decimal (fractional) indexes.

    Raises:
        SelectError: a group cannot be parsed or fails validation.
    """
    if not indexes_str:
        return []
    known_indexes = range(0, length)
    allow_negative = length > 0
    chosen_indexes = []
    # parse string - allow positive ints, negative ints, and ranges separated by ':'
    for group in (g.strip() for g in indexes_str.split(",")):
        # parse individual indices
        if ':' not in group:
            chosen_indexes.append(convert_to_index_int(group, length=length, allow_negative=allow_negative,
                                                       allow_missing=allow_missing, allow_decimal=allow_decimal))
            continue
        # parse range of indices (e.g. 2:16)
        if not allow_range:
            raise SelectError("Ranges (:) not allowed for this input.")
        parts = [p.strip() for p in group.split(":", 2)]
        if parts[0]:
            start_index = convert_to_index_int(parts[0], length=length, is_range=True,
                                               allow_negative=allow_negative, allow_missing=allow_missing,
                                               allow_decimal=allow_decimal)
        else:
            start_index = 0
        if parts[1]:
            end_index = convert_to_index_int(parts[1], length=length, is_range=True,
                                             allow_negative=allow_negative, allow_missing=allow_missing,
                                             allow_decimal=allow_decimal)
        else:
            end_index = length
        # support step as well, to allow things like reversing, every-other, etc.
        step = 1
        if len(parts) > 2 and parts[2]:
            step = _parse_step(parts[2])
        # if supposed to treat same start and end as one entry, do so
        if same_is_one and start_index == end_index:
            chosen_indexes.append(convert_to_index_int(start_index, length=length, allow_negative=allow_negative,
                                                       allow_missing=allow_missing, allow_decimal=allow_decimal))
            continue
        # if should fix_reverse and reverse detected, then swap start and end indexes
        if fix_reverse and end_index < start_index:
            start_index, end_index = end_index, start_index
        if len(known_indexes) > 0 and not allow_missing:
            # length is known: restrict the span to the indexes that exist
            span = known_indexes[start_index:end_index]
        else:
            # otherwise, assume indices are valid
            span = range(start_index, end_index)
        # slicing a range is lazy; only the stepped result is materialized
        chosen_indexes.extend(span[::step])
    return chosen_indexes


def select_indexes(input_obj: Union[Tensor, list], idxs: list):
    """Pick the entries at ``idxs`` from a tensor (indexed in one call) or a list."""
    # isinstance so Tensor subclasses are indexed as tensors too
    if isinstance(input_obj, Tensor):
        return input_obj[idxs]
    return [input_obj[i] for i in idxs]


def select_indexes_from_str(input_obj: Union[Tensor, list], indexes: str, allow_range=True,
                            err_if_missing=True, err_if_empty=True):
    """Select entries of ``input_obj`` described by an index string.

    ``indexes`` is parsed by ``convert_str_to_indexes`` using
    ``len(input_obj)`` as the length.

    Raises:
        SelectError: the string is invalid, or nothing was selected while
            ``err_if_empty`` is true.
    """
    real_idxs = convert_str_to_indexes(indexes, len(input_obj), allow_range=allow_range,
                                       allow_missing=not err_if_missing)
    if err_if_empty and len(real_idxs) == 0:
        raise SelectError(f"Nothing was selected based on indexes found in '{indexes}'.")
    return select_indexes(input_obj, real_idxs)