Spaces:
Running
Running
File size: 5,470 Bytes
613c9ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from collections.abc import Iterable
from typing import Union
import torch
from torch import Tensor
from .utils_motion import linear_conversion, normalize_min_max, extend_to_batch_size
class ScaleType:
    """String names of the mask-rescaling modes accepted by MultivalScaledMaskNode."""

    # Selects utils_motion.linear_conversion in MultivalScaledMaskNode.
    ABSOLUTE = "absolute"
    # Selects utils_motion.normalize_min_max in MultivalScaledMaskNode.
    RELATIVE = "relative"

    # Ordered combo-box choices (ABSOLUTE first); kept as a plain list.
    LIST = list((ABSOLUTE, RELATIVE))
class MultivalDynamicNode:
    """ComfyUI node producing a MULTIVAL from a float value, optionally applied to a mask.

    The MULTIVAL is either the float itself (scalar input, no mask) or a tensor
    obtained by multiplying the float value(s) into a mask.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001},),
            },
            "optional": {
                "mask_optional": ("MASK",)
            }
        }

    RETURN_TYPES = ("MULTIVAL",)
    CATEGORY = "Animate Diff 🎭🅐🅓/multival"
    FUNCTION = "create_multival"

    def create_multival(self, float_val: Union[float, list[float]]=1.0, mask_optional: Tensor=None):
        """Combine ``float_val`` and ``mask_optional`` into a single MULTIVAL.

        Args:
            float_val: A single float, or any iterable of floats (one value per
                batch entry of the mask).
            mask_optional: Optional mask tensor. Its first dimension is treated as the
                batch; the float values are shaped (len, 1, 1) so they broadcast over
                the remaining (h, w) dimensions.

        Returns:
            A 1-tuple holding:
            - ``float_val`` unchanged, when it is a scalar and no mask is given;
            - ``mask_optional * float_val``, when it is a scalar and a mask is given;
            - for an iterable ``float_val``, a tensor whose batch entry ``i`` is the
              mask (or a 1x1 mask of ones when no mask is given) multiplied by
              value ``i``. With a mask, a shorter list is padded by repeating its last
              value, and a longer list extends the mask via ``extend_to_batch_size``.

        Raises:
            ValueError: if ``float_val`` is an empty iterable and a mask is given.
        """
        # NOTE: other node classes call this as MultivalDynamicNode.create_multival(self, ...)
        # with their own instance, so this method must not use instance state or self-helpers.
        float_is_iterable = isinstance(float_val, Iterable)
        if float_is_iterable:
            values = list(float_val)
            if mask_optional is not None:
                batch_size = mask_optional.shape[0]
                if not values:
                    raise ValueError("float_val must contain at least one value when mask_optional is provided.")
                if len(values) < batch_size:
                    # repeat the last entry (as a list element) so every mask frame gets a value
                    values = values + [values[-1]] * (batch_size - len(values))
                elif batch_size < len(values):
                    mask_optional = extend_to_batch_size(mask_optional, len(values))
                # keep exactly one value per mask frame
                values = values[:mask_optional.shape[0]]
            # shape (len, 1, 1): each value broadcasts over one (h, w) frame
            float_val = torch.tensor(values).unsqueeze(-1).unsqueeze(-1)

        if mask_optional is not None:
            if float_is_iterable:
                # match the mask's dtype and device before multiplying
                float_val = float_val.to(dtype=mask_optional.dtype, device=mask_optional.device)
            # out-of-place multiply returns a new tensor; the caller's mask is not modified
            return (mask_optional * float_val,)

        if not float_is_iterable:
            return (float_val,)

        # No mask: multiply into a dummy (len, 1, 1) mask of ones (a single pixel per
        # entry) so list input yields the same kind of tensor as the mask path.
        dummy_mask = torch.ones((float_val.shape[0], 1, 1))
        return (dummy_mask * float_val.to(dtype=dummy_mask.dtype, device=dummy_mask.device),)
class MultivalScaledMaskNode:
    """ComfyUI node that rescales a mask's values and returns the result as a MULTIVAL.

    Rescaling uses ``new_min=min_float_val`` and ``new_max=max_float_val`` with a
    helper chosen by ``scaling`` (see ScaleType).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "min_float_val": ("FLOAT", {"default": 0.0, "min": 0.0, "step": 0.001}),
                "max_float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001}),
                "mask": ("MASK",),
            },
            "optional": {
                "scaling": (ScaleType.LIST,),
            }
        }

    RETURN_TYPES = ("MULTIVAL",)
    CATEGORY = "Animate Diff 🎭🅐🅓/multival"
    FUNCTION = "create_multival"

    def create_multival(self, min_float_val: float, max_float_val: float, mask: Tensor, scaling: str=ScaleType.ABSOLUTE):
        """Rescale ``mask`` and wrap the result as a MULTIVAL.

        Args:
            min_float_val: Passed as ``new_min`` to the scaling helper.
            max_float_val: Passed as ``new_max`` to the scaling helper.
            mask: Mask tensor to rescale. It is cloned before being handed to the
                helper, so the caller's tensor is not modified.
            scaling: ``ScaleType.ABSOLUTE`` uses ``linear_conversion``;
                ``ScaleType.RELATIVE`` uses ``normalize_min_max``. Exact semantics
                live in utils_motion; presumably ABSOLUTE maps a fixed input range
                while RELATIVE uses the mask's own min/max. TODO confirm.

        Returns:
            A 1-tuple holding the rescaled mask, passed through
            MultivalDynamicNode.create_multival with its default float value.

        Raises:
            ValueError: if either bound is an iterable, or ``scaling`` is not recognized.
        """
        # TODO: allow min_float_val and max_float_val to be list[float]
        for arg_name, bound in (("min_float_val", min_float_val), ("max_float_val", max_float_val)):
            if isinstance(bound, Iterable):
                raise ValueError(f"{arg_name} must be type float (no lists allowed here), not {type(bound).__name__}.")
        if scaling == ScaleType.ABSOLUTE:
            mask = linear_conversion(mask.clone(), new_min=min_float_val, new_max=max_float_val)
        elif scaling == ScaleType.RELATIVE:
            mask = normalize_min_max(mask.clone(), new_min=min_float_val, new_max=max_float_val)
        else:
            raise ValueError(f"scaling '{scaling}' not recognized.")
        # delegate so the output has the same form as MultivalDynamicNode's mask path
        return MultivalDynamicNode.create_multival(self, mask_optional=mask)
class MultivalDynamicFloatInputNode:
    """Variant of MultivalDynamicNode whose ``float_val`` is declared with ``forceInput``.

    In ComfyUI, ``forceInput`` presents the value as an input socket rather than a
    widget. Presumably this lets upstream nodes feed a float or a list of floats;
    confirm against the workflows that use it. Unlike MultivalDynamicNode, the
    widget range is capped at max 10.0.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "forceInput": True},),
            },
            "optional": {
                "mask_optional": ("MASK",)
            }
        }

    RETURN_TYPES = ("MULTIVAL",)
    CATEGORY = "Animate Diff 🎭🅐🅓/multival"
    FUNCTION = "create_multival"

    def create_multival(self, float_val: Union[float, list[float]]=None, mask_optional: Tensor=None):
        """Forward both arguments unchanged to MultivalDynamicNode.create_multival.

        Returns:
            The 1-tuple MULTIVAL produced by MultivalDynamicNode.create_multival.
            NOTE(review): with the ``None`` default and no mask, that is ``(None,)``.
        """
        return MultivalDynamicNode.create_multival(self, float_val=float_val, mask_optional=mask_optional)
class MultivalFloatNode:
    """ComfyUI node that turns a single float widget value into a MULTIVAL.

    There is no mask input; the value is forwarded to MultivalDynamicNode, and the
    widget range is 0.0 to 10.0.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001},),
            },
        }

    RETURN_TYPES = ("MULTIVAL",)
    CATEGORY = "Animate Diff 🎭🅐🅓/multival"
    FUNCTION = "create_multival"

    def create_multival(self, float_val: Union[float, list[float]]=None):
        """Forward ``float_val`` to MultivalDynamicNode.create_multival without a mask.

        Returns:
            The 1-tuple MULTIVAL produced by MultivalDynamicNode.create_multival.
            For a scalar float, that is the float itself.
        """
        return MultivalDynamicNode.create_multival(self, float_val=float_val)
|