daquanzhou
merge github repos and lfs track ckpt/path/safetensors/pt
613c9ab
raw
history blame
5.47 kB
from collections.abc import Iterable
from typing import Union
import torch
from torch import Tensor
from .utils_motion import linear_conversion, normalize_min_max, extend_to_batch_size
class ScaleType:
    """String options for how MultivalScaledMaskNode rescales a mask."""

    # selects linear_conversion in MultivalScaledMaskNode.create_multival
    ABSOLUTE: str = "absolute"
    # selects normalize_min_max in MultivalScaledMaskNode.create_multival
    RELATIVE: str = "relative"
    # all accepted options, in order; used as the "scaling" input's choices
    LIST: list = [ABSOLUTE, RELATIVE]
class MultivalDynamicNode:
    """Build a MULTIVAL from a float (or list of floats) and an optional mask.

    The produced MULTIVAL is either a plain float (scalar input, no mask) or a
    tensor; see ``create_multival`` for the exact rules.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.001},),
            },
            "optional": {
                "mask_optional": ("MASK",)
            }
        }

    RETURN_TYPES = ("MULTIVAL",)
    # NOTE(review): this category text looks like UTF-8 emoji decoded through a
    # legacy codepage; confirm it matches the other AnimateDiff nodes before changing it.
    CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/multival"
    FUNCTION = "create_multival"

    def create_multival(self, float_val: Union[float, list[float]]=1.0, mask_optional: Tensor=None):
        """Combine ``float_val`` and ``mask_optional`` into a single MULTIVAL.

        Args:
            float_val: a single number, or an iterable of numbers with one value
                per mask batch entry (index 0 of the mask).
            mask_optional: optional mask tensor. Assumes (batch, height, width),
                matching the dummy mask built below — TODO confirm against callers.

        Returns:
            A 1-tuple containing one of:
            - ``float_val`` unchanged, when it is a scalar and no mask is given;
            - ``mask_optional * float_val`` when a mask is given. For a list,
              entry i scales mask batch entry i. A list shorter than the batch is
              padded by repeating its last value. A longer list extends the mask
              via ``extend_to_batch_size``;
            - a tensor of shape (len(float_val), 1, 1) when a list is given
              without a mask.

        Raises:
            ValueError: if ``float_val`` is an empty iterable and must be padded
                to match a non-empty mask batch.
        """
        # first, normalize inputs
        # if float_val is iterable, treat it as a list and assume its items are numbers
        float_is_iterable = isinstance(float_val, Iterable)
        if float_is_iterable:
            float_val = list(float_val)
            # if a mask is present, make the float list and the mask batch the same length
            if mask_optional is not None:
                batch_size = mask_optional.shape[0]
                if len(float_val) < batch_size:
                    if not float_val:
                        raise ValueError("float_val list must not be empty when mask_optional is provided.")
                    # repeat the last entry enough times to match the mask batch size
                    float_val = float_val + [float_val[-1]] * (batch_size - len(float_val))
                elif batch_size < len(float_val):
                    # presumably repeats mask entries up to the requested batch size — confirm in utils_motion
                    mask_optional = extend_to_batch_size(mask_optional, len(float_val))
                # keep values aligned with however many mask entries we actually have
                float_val = float_val[:mask_optional.shape[0]]
            # shape (n, 1, 1) so each value broadcasts over everything after dim 0 of the mask
            float_val = torch.tensor(float_val).unsqueeze(-1).unsqueeze(-1)

        # now that inputs are normalized, figure out what value to actually return
        if mask_optional is not None:
            if float_is_iterable:
                float_val = float_val.to(device=mask_optional.device, dtype=mask_optional.dtype)
            # multiplication allocates a new tensor, so the caller's mask is never modified
            return (mask_optional * float_val,)

        if not float_is_iterable:
            return (float_val,)

        # no mask: create a dummy mask of b,h,w = float_len,1,1 (single pixel) so a float
        # list can flow through mask-handling code without special cases
        dummy_mask = torch.ones((float_val.shape[0], 1, 1))
        return (dummy_mask * float_val.to(device=dummy_mask.device, dtype=dummy_mask.dtype),)
class MultivalScaledMaskNode:
    """Rescale a mask into the range [min_float_val, max_float_val] and emit it as a MULTIVAL.

    The rescaling helper is chosen by ``scaling`` (see ScaleType). The result is
    passed through MultivalDynamicNode.create_multival with its default float_val.
    """

    RETURN_TYPES = ("MULTIVAL",)
    CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/multival"
    FUNCTION = "create_multival"

    @classmethod
    def INPUT_TYPES(cls):
        def float_input(default_value):
            # both bounds share the same widget constraints; only the default differs
            return ("FLOAT", {"default": default_value, "min": 0.0, "step": 0.001})

        required_inputs = {
            "min_float_val": float_input(0.0),
            "max_float_val": float_input(1.0),
            "mask": ("MASK",),
        }
        optional_inputs = {"scaling": (ScaleType.LIST,)}
        return {"required": required_inputs, "optional": optional_inputs}

    def create_multival(self, min_float_val: float, max_float_val: float, mask: Tensor, scaling: str=ScaleType.ABSOLUTE):
        """Return a 1-tuple holding ``mask`` rescaled to [min_float_val, max_float_val].

        Raises:
            ValueError: if either bound is iterable, or ``scaling`` is not a ScaleType option.
        """
        # TODO: allow min_float_val and max_float_val to be list[float]
        if isinstance(min_float_val, Iterable):
            raise ValueError(f"min_float_val must be type float (no lists allowed here), not {type(min_float_val).__name__}.")
        if isinstance(max_float_val, Iterable):
            raise ValueError(f"max_float_val must be type float (no lists allowed here), not {type(max_float_val).__name__}.")

        is_absolute = scaling == ScaleType.ABSOLUTE
        if not is_absolute and scaling != ScaleType.RELATIVE:
            raise ValueError(f"scaling '{scaling}' not recognized.")
        rescale = linear_conversion if is_absolute else normalize_min_max

        # operate on a copy so the caller's mask tensor is left as-is
        rescaled_mask = rescale(mask.clone(), new_min=min_float_val, new_max=max_float_val)
        return MultivalDynamicNode.create_multival(self, mask_optional=rescaled_mask)
class MultivalDynamicFloatInputNode:
    """Like MultivalDynamicNode, but float_val is declared with ``"forceInput": True``.

    ``"forceInput": True`` presumably makes the UI expose float_val as a connected
    input rather than a widget; verify against the ComfyUI version in use. All
    value handling is delegated to MultivalDynamicNode.create_multival.
    """

    RETURN_TYPES = ("MULTIVAL",)
    CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/multival"
    FUNCTION = "create_multival"

    @classmethod
    def INPUT_TYPES(cls):
        float_options = {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "forceInput": True}
        required_inputs = {"float_val": ("FLOAT", float_options)}
        optional_inputs = {"mask_optional": ("MASK",)}
        return {"required": required_inputs, "optional": optional_inputs}

    def create_multival(self, float_val: Union[float, list[float]]=None, mask_optional: Tensor=None):
        """Forward both inputs unchanged to MultivalDynamicNode.create_multival."""
        delegate = MultivalDynamicNode.create_multival
        return delegate(self, float_val=float_val, mask_optional=mask_optional)
class MultivalFloatNode:
    """Emit a MULTIVAL from a single float widget value; this node has no mask input.

    Value handling is delegated to MultivalDynamicNode.create_multival.
    """

    RETURN_TYPES = ("MULTIVAL",)
    CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/multival"
    FUNCTION = "create_multival"

    @classmethod
    def INPUT_TYPES(cls):
        float_options = {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}
        return {"required": {"float_val": ("FLOAT", float_options)}}

    def create_multival(self, float_val: Union[float, list[float]]=None):
        """Forward float_val to MultivalDynamicNode.create_multival without a mask."""
        delegate = MultivalDynamicNode.create_multival
        return delegate(self, float_val=float_val)