|
from typing import Union |
|
|
|
import torch |
|
from torch import Tensor |
|
import math |
|
|
|
|
|
class TensorInterp:
    """String identifiers for the supported tensor interpolation modes."""

    # Individual mode names.
    LERP = "lerp"
    SLERP = "slerp"

    # Every mode name, in declaration order.
    _LIST = [LERP, SLERP]
|
|
|
|
|
class SelectError(Exception):
    """Raised when an index or index-selection string is malformed or out of bounds."""
|
|
|
|
|
def lerp_tensors(tensor_from: Tensor, tensor_to: Tensor, strength_to: Tensor):
    """Linearly blend ``tensor_from`` toward ``tensor_to``.

    Args:
        tensor_from: Value returned when ``strength_to`` is 0.
        tensor_to: Value returned when ``strength_to`` is 1.
        strength_to: Weight given to ``tensor_to`` (a tensor or a Python scalar);
            ``tensor_from`` receives the complementary weight ``1 - strength_to``.

    Returns:
        ``tensor_from * (1 - strength_to) + tensor_to * strength_to``.
    """
    weight_from = 1.0 - strength_to
    weighted_from = tensor_from * weight_from
    weighted_to = tensor_to * strength_to
    return weighted_from + weighted_to
|
|
|
|
|
|
|
|
|
def slerp_tensors(tensor_from: Tensor, tensor_to: Tensor, strength_to: Tensor, dot_threshold=0.9995):
    """Spherically interpolate from ``tensor_from`` toward ``tensor_to``.

    Each tensor is treated as a single flat vector (``Tensor.norm()`` with no
    ``dim`` reduces over every element). The interpolation follows the arc
    between the two directions, so ``strength_to == 0`` returns ``tensor_from``
    and ``strength_to == 1`` returns ``tensor_to``.

    Falls back to ``lerp_tensors`` when slerp is numerically ill-defined:
      * either input has zero norm (it has no direction), or
      * the inputs are nearly parallel or nearly opposite
        (``|cos(angle)| > dot_threshold``), where ``sin(angle)`` is ~0.

    Args:
        tensor_from: Start value.
        tensor_to: End value.
        strength_to: Interpolation weight toward ``tensor_to`` (tensor or scalar).
        dot_threshold: Cosine magnitude above which linear interpolation is used.

    Returns:
        The interpolated tensor.
    """
    norm_from = tensor_from.norm()
    norm_to = tensor_to.norm()

    # Normalizing a zero tensor divides by zero and yields NaN, which would slip
    # past the threshold check below (NaN comparisons are False) and poison the
    # whole result. Lerp is well-defined here, so use it.
    if norm_from == 0 or norm_to == 0:
        return lerp_tensors(tensor_from=tensor_from, tensor_to=tensor_to, strength_to=strength_to)

    # Cosine of the angle between the two inputs.
    dot = ((tensor_from / norm_from) * (tensor_to / norm_to)).sum()

    # Near-parallel or near-opposite inputs make sin(omega) ~ 0 in the divisor.
    if dot.abs() > dot_threshold:
        return lerp_tensors(tensor_from=tensor_from, tensor_to=tensor_to, strength_to=strength_to)

    omega = dot.acos()
    sin_from = ((1.0 - strength_to) * omega).sin()
    sin_to = (strength_to * omega).sin()
    return (tensor_from * sin_from + tensor_to * sin_to) / omega.sin()
|
|
|
|
|
def validate_index(raw_index: Union[str, int, float], length: int=0, is_range: bool=False, allow_negative=False, allow_missing=False, allow_decimal=False) -> int:
    """Parse ``raw_index`` and resolve it against a sequence of ``length`` items.

    Args:
        raw_index: An int, or a string holding an int. A string containing
            ``'.'`` is a fractional position in ``[0.0, 1.0]`` of ``length``
            (requires ``allow_decimal`` and a non-zero ``length``).
        length: Number of items; 0 means "unknown" and disables the
            upper-bound check.
        is_range: Treat the value as a slice bound. Negative values count from
            the end and are clamped at 0; no other bounds checks are applied.
        allow_negative: Permit negative (from-the-end) indexes.
        allow_missing: Skip the out-of-range errors. A negative index that
            still falls before the start is then returned as a negative number.
        allow_decimal: Permit fractional string indexes.

    Returns:
        The resolved integer index.

    Raises:
        SelectError: If the value cannot be parsed or violates the constraints above.
    """
    index = _parse_raw_index(raw_index, length, allow_decimal)

    if is_range:
        # Slice bounds: wrap negatives from the end, never below 0.
        if index < 0:
            index = max(length + index, 0)
        return index

    if length > 0 and index > length - 1 and not allow_missing:
        raise SelectError(f"Index '{index}' out of range for {length} item(s).")

    if index < 0:
        if not allow_negative:
            raise SelectError(f"Negative indeces not allowed, but was '{index}'.")
        conv_index = length + index
        if conv_index < 0 and not allow_missing:
            raise SelectError(f"Index '{index}', converted to '{conv_index}' out of range for {length} item(s).")
        index = conv_index

    return index


def _parse_raw_index(raw_index: Union[str, int, float], length: int, allow_decimal: bool) -> int:
    """Convert ``raw_index`` to an int; a string containing '.' is a fraction of ``length``."""
    if not (isinstance(raw_index, str) and '.' in raw_index):
        try:
            return int(raw_index)
        except ValueError as e:
            raise SelectError(f"Index '{raw_index}' must be an integer.") from e

    if not allow_decimal:
        raise SelectError(f"Index '{raw_index}' contains a decimal, but decimal inputs are not allowed.")
    if length == 0:
        raise SelectError(f"Decimal indexes are not allowed when no explicit length ({length}) is provided.")

    try:
        fraction = float(raw_index)
    except ValueError as e:
        raise SelectError(f"Decimal index '{raw_index}' isn't a valid float.") from e

    if fraction < 0.0 or fraction > 1.0:
        raise SelectError(f"Decimal index must be between 0.0 and 1.0, but was '{fraction}'.")

    # 1.0 * length would be one past the end, so pin (approximately) 1.0 to the last item.
    if math.isclose(fraction, 1.0):
        return length - 1
    return int(fraction * length)
|
|
|
|
|
def convert_to_index_int(raw_index: Union[str, int, float], length: int=0, is_range: bool=False, allow_negative=False, allow_missing=False, allow_decimal=False) -> int:
    """Convert one raw index to an int by delegating to ``validate_index``.

    ``raw_index`` accepts the same types as ``validate_index``: callers may pass
    back an index that has already been converted to an int, not only strings.
    All keyword arguments are forwarded unchanged.
    """
    return validate_index(raw_index, length=length, is_range=is_range, allow_negative=allow_negative, allow_missing=allow_missing, allow_decimal=allow_decimal)
|
|
|
|
|
def convert_str_to_indexes(indexes_str: str, length: int=0, allow_range=True, allow_missing=False, fix_reverse=False, same_is_one=False, allow_decimal=False) -> list[int]:
    """Expand a comma-separated index specification into a list of ints.

    Each comma-separated group is either a single index (``"3"``, ``"-1"``) or a
    range ``start:end[:step]``. An empty ``start`` means 0, an empty ``end`` means
    ``length``, and an empty ``step`` means 1. A range is applied as "slice
    ``[start:end]`` first, then take every ``step``-th item", so a negative step
    walks the sliced span backwards (``"::-1"`` reverses everything).

    Args:
        indexes_str: The specification; empty/falsy input yields ``[]``.
        length: Number of items being indexed; 0 means unknown. Negative
            indexes are only allowed when ``length > 0``.
        allow_range: If False, any ``:`` group raises SelectError.
        allow_missing: Do not clip ranges to ``length`` nor reject indexes
            beyond it.
        fix_reverse: Swap ``start``/``end`` when ``end < start`` so the range
            is non-empty (the output stays in ascending order).
        same_is_one: Make ``"a:a"`` select the single index ``a`` instead of
            producing an empty range.
        allow_decimal: Permit fractional string positions (see validate_index).

    Returns:
        The selected indexes, in specification order (duplicates kept).

    Raises:
        SelectError: On malformed groups, disallowed ranges, a zero step, or
            out-of-range indexes.
    """
    if not indexes_str:
        return []

    allow_negative = length > 0
    chosen_indexes = []

    for group in (part.strip() for part in indexes_str.split(",")):
        if ':' in group:
            if not allow_range:
                raise SelectError("Ranges (:) not allowed for this input.")
            chosen_indexes.extend(_range_group_to_indexes(group, length, allow_negative, allow_missing, fix_reverse, same_is_one, allow_decimal))
        else:
            chosen_indexes.append(convert_to_index_int(group, length=length, allow_negative=allow_negative, allow_missing=allow_missing, allow_decimal=allow_decimal))

    return chosen_indexes


def _range_group_to_indexes(group: str, length: int, allow_negative: bool, allow_missing: bool, fix_reverse: bool, same_is_one: bool, allow_decimal: bool) -> list[int]:
    """Expand one ``start:end[:step]`` group into concrete indexes."""
    # maxsplit=2 keeps any extra ':' inside the step, which then fails to parse.
    parts = [p.strip() for p in group.split(":", 2)]
    raw_start, raw_end = parts[0], parts[1]
    raw_step = parts[2] if len(parts) > 2 else ""

    if raw_start:
        start_index = convert_to_index_int(raw_start, length=length, is_range=True, allow_negative=allow_negative, allow_missing=allow_missing, allow_decimal=allow_decimal)
    else:
        start_index = 0

    if raw_end:
        end_index = convert_to_index_int(raw_end, length=length, is_range=True, allow_negative=allow_negative, allow_missing=allow_missing, allow_decimal=allow_decimal)
    else:
        end_index = length

    step = _parse_range_step(raw_step)

    if same_is_one and start_index == end_index:
        return [convert_to_index_int(start_index, length=length, allow_negative=allow_negative, allow_missing=allow_missing, allow_decimal=allow_decimal)]

    if fix_reverse and end_index < start_index:
        start_index, end_index = end_index, start_index

    if length > 0 and not allow_missing:
        # Slicing a range clips the bounds to [0, length), exactly like list slicing.
        span = range(length)[start_index:end_index]
    else:
        # Length unknown or missing items allowed: take the bounds verbatim.
        span = range(start_index, end_index)

    # Same "slice, then step" rule on both paths, so negative steps agree.
    return list(span[::step])


def _parse_range_step(raw_step: str) -> int:
    """Parse the optional step of a range group; blank means 1, zero is rejected."""
    if not raw_step:
        return 1
    try:
        step = int(raw_step)
    except ValueError as e:
        raise SelectError(f"Range step '{raw_step}' must be an integer.") from e
    if step == 0:
        raise SelectError("Range step cannot be zero.")
    return step
|
|
|
|
|
def select_indexes(input_obj: Union[Tensor, list], idxs: list):
    """Pick the items of ``input_obj`` at positions ``idxs``.

    Tensors (including subclasses such as ``torch.nn.Parameter``) are indexed
    in one advanced-indexing call along dim 0 and return a tensor. Any other
    indexable object returns a new Python list of the selected items.
    """
    # isinstance (not `type(...) ==`) so Tensor subclasses stay on the tensor
    # path instead of degrading into a list of 0-d tensors.
    if isinstance(input_obj, Tensor):
        return input_obj[idxs]
    return [input_obj[i] for i in idxs]
|
|
|
|
|
def select_indexes_from_str(input_obj: Union[Tensor, list], indexes: str, allow_range=True, err_if_missing=True, err_if_empty=True):
    """Select items from ``input_obj`` using an index specification string.

    ``indexes`` uses the comma/range syntax of ``convert_str_to_indexes``
    (e.g. ``"0, 2, 4:"``), resolved against ``len(input_obj)``.

    Args:
        input_obj: A tensor (indexed along dim 0) or a list.
        indexes: The index specification.
        allow_range: Permit ``start:end[:step]`` groups.
        err_if_missing: Raise on indexes outside ``input_obj`` (passed on as
            ``allow_missing=not err_if_missing``).
        err_if_empty: Raise if the specification selects nothing.

    Returns:
        The selected items (a tensor for tensor input, otherwise a list).

    Raises:
        SelectError: On an invalid specification, or an empty selection when
            ``err_if_empty`` is set.
    """
    real_idxs = convert_str_to_indexes(indexes, len(input_obj), allow_range=allow_range, allow_missing=not err_if_missing)

    if err_if_empty and not real_idxs:
        # SelectError subclasses Exception, so existing `except Exception`
        # handlers keep working while callers can now catch it specifically.
        raise SelectError(f"Nothing was selected based on indexes found in '{indexes}'.")

    return select_indexes(input_obj, real_idxs)
|
|