|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod |
|
import numpy as np |
|
import torch |
|
from torch import Tensor |
|
|
|
from comfy.ldm.modules.diffusionmodules.util import ( |
|
zero_module, |
|
timestep_embedding, |
|
) |
|
|
|
from comfy.cldm.cldm import ControlNet as ControlNetCLDM |
|
from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential |
|
from comfy.model_patcher import ModelPatcher |
|
from comfy.patcher_extension import PatcherInjection |
|
|
|
from .dinklink import (InterfaceAnimateDiffInfo, InterfaceAnimateDiffModel, |
|
get_CreateMotionModelPatcher, get_AnimateDiffModel, get_AnimateDiffInfo) |
|
from .logger import logger |
|
from .utils import (BIGMAX, AbstractPreprocWrapper, disable_weight_init_clean_groupnorm, WrapperConsts) |
|
|
|
|
|
class SparseMotionModelPatcher(ModelPatcher): |
|
'''Class only used for IDE type hints.''' |
|
def __init__(self, *args, **kwargs): |
|
self.model = InterfaceAnimateDiffModel |
|
|
|
|
|
class SparseConst: |
|
HINT_MULT = "sparse_hint_mult" |
|
NONHINT_MULT = "sparse_nonhint_mult" |
|
MASK_MULT = "sparse_mask_mult" |
|
|
|
|
|
class SparseControlNet(ControlNetCLDM): |
|
def __init__(self, *args,**kwargs): |
|
super().__init__(*args, **kwargs) |
|
hint_channels = kwargs.get("hint_channels") |
|
operations: disable_weight_init_clean_groupnorm = kwargs.get("operations", disable_weight_init_clean_groupnorm) |
|
device = kwargs.get("device", None) |
|
self.use_simplified_conditioning_embedding = kwargs.get("use_simplified_conditioning_embedding", False) |
|
if self.use_simplified_conditioning_embedding: |
|
self.input_hint_block = TimestepEmbedSequential( |
|
zero_module(operations.conv_nd(self.dims, hint_channels, self.model_channels, 3, padding=1, dtype=self.dtype, device=device)), |
|
) |
|
|
|
def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs): |
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) |
|
emb = self.time_embed(t_emb) |
|
|
|
|
|
x = torch.zeros_like(x) |
|
guided_hint = self.input_hint_block(hint, emb, context) |
|
|
|
out_output = [] |
|
out_middle = [] |
|
|
|
hs = [] |
|
if self.num_classes is not None: |
|
assert y.shape[0] == x.shape[0] |
|
emb = emb + self.label_emb(y) |
|
|
|
h = x |
|
for module, zero_conv in zip(self.input_blocks, self.zero_convs): |
|
if guided_hint is not None: |
|
h = module(h, emb, context) |
|
h += guided_hint |
|
guided_hint = None |
|
else: |
|
h = module(h, emb, context) |
|
out_output.append(zero_conv(h, emb, context)) |
|
|
|
h = self.middle_block(h, emb, context) |
|
out_middle.append(self.middle_block_out(h, emb, context)) |
|
|
|
return {"middle": out_middle, "output": out_output} |
|
|
|
|
|
def load_sparsectrl_motionmodel(ckpt_path: str, motion_data: dict[str, Tensor], ops=None) -> InterfaceAnimateDiffModel: |
|
mm_info: InterfaceAnimateDiffInfo = get_AnimateDiffInfo()("SD1.5", "AnimateDiff", "v3", ckpt_path) |
|
init_kwargs = { |
|
"ops": ops, |
|
"get_unet_func": _get_unet_func, |
|
} |
|
motion_model: InterfaceAnimateDiffModel = get_AnimateDiffModel()(mm_state_dict=motion_data, mm_info=mm_info, init_kwargs=init_kwargs) |
|
missing, unexpected = motion_model.load_state_dict(motion_data) |
|
if len(missing) > 0 or len(unexpected) > 0: |
|
logger.info(f"SparseCtrl MotionModel: {missing}, {unexpected}") |
|
return motion_model |
|
|
|
|
|
def create_sparse_modelpatcher(model, motion_model, load_device, offload_device): |
|
patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device) |
|
if motion_model is not None: |
|
_motionpatcher = _create_sparse_motionmodelpatcher(motion_model, load_device, offload_device) |
|
patcher.set_additional_models(WrapperConsts.ACN, [_motionpatcher]) |
|
patcher.set_injections(WrapperConsts.ACN, |
|
[PatcherInjection(inject=_inject_motion_models, eject=_eject_motion_models)]) |
|
return patcher |
|
|
|
def _create_sparse_motionmodelpatcher(motion_model, load_device, offload_device) -> SparseMotionModelPatcher: |
|
return get_CreateMotionModelPatcher()(motion_model, load_device, offload_device) |
|
|
|
|
|
def _inject_motion_models(patcher: ModelPatcher): |
|
motion_models: list[SparseMotionModelPatcher] = patcher.get_additional_models_with_key(WrapperConsts.ACN) |
|
for mm in motion_models: |
|
mm.model.inject(patcher) |
|
|
|
def _eject_motion_models(patcher: ModelPatcher): |
|
motion_models: list[SparseMotionModelPatcher] = patcher.get_additional_models_with_key(WrapperConsts.ACN) |
|
for mm in motion_models: |
|
mm.model.eject(patcher) |
|
|
|
def _get_unet_func(wrapper, model: ModelPatcher): |
|
return model.model |
|
|
|
|
|
class PreprocSparseRGBWrapper(AbstractPreprocWrapper): |
|
error_msg = error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input." |
|
def __init__(self, condhint: Tensor): |
|
super().__init__(condhint) |
|
|
|
|
|
class SparseContextAware: |
|
NEAREST_HINT = "nearest_hint" |
|
OFF = "off" |
|
|
|
LIST = [NEAREST_HINT, OFF] |
|
|
|
|
|
class SparseSettings: |
|
def __init__(self, sparse_method: 'SparseMethod', use_motion: bool=True, motion_strength=1.0, motion_scale=1.0, merged=False, |
|
sparse_mask_mult=1.0, sparse_hint_mult=1.0, sparse_nonhint_mult=1.0, context_aware=SparseContextAware.NEAREST_HINT): |
|
|
|
|
|
if type(sparse_method) == str: |
|
logger.warn("Outdated Steerable-Motion workflow detected; attempting to auto-convert indexes input. If you experience an error here, consult Steerable-Motion github, NOT Advanced-ControlNet.") |
|
sparse_method = SparseIndexMethod(get_idx_list_from_str(sparse_method)) |
|
self.sparse_method = sparse_method |
|
self.use_motion = use_motion |
|
self.motion_strength = motion_strength |
|
self.motion_scale = motion_scale |
|
self.merged = merged |
|
self.sparse_mask_mult = float(sparse_mask_mult) |
|
self.sparse_hint_mult = float(sparse_hint_mult) |
|
self.sparse_nonhint_mult = float(sparse_nonhint_mult) |
|
self.context_aware = context_aware |
|
|
|
def is_context_aware(self): |
|
return self.context_aware != SparseContextAware.OFF |
|
|
|
@classmethod |
|
def default(cls): |
|
return SparseSettings(sparse_method=SparseSpreadMethod(), use_motion=True) |
|
|
|
|
|
class SparseMethod(ABC): |
|
SPREAD = "spread" |
|
INDEX = "index" |
|
def __init__(self, method: str): |
|
self.method = method |
|
|
|
@abstractmethod |
|
def _get_indexes(self, hint_length: int, full_length: int) -> list[int]: |
|
pass |
|
|
|
def get_indexes(self, hint_length: int, full_length: int, sub_idxs: list[int]=None) -> tuple[list[int], list[int]]: |
|
returned_idxs = self._get_indexes(hint_length, full_length) |
|
if sub_idxs is None: |
|
return returned_idxs, None |
|
|
|
index_mapping = {} |
|
for i, value in enumerate(returned_idxs): |
|
index_mapping[value] = i |
|
def get_mapped_idxs(idxs: list[int]): |
|
return [index_mapping[idx] for idx in idxs] |
|
|
|
fitting_idxs = [] |
|
for sub_idx in sub_idxs: |
|
if sub_idx in returned_idxs: |
|
fitting_idxs.append(sub_idx) |
|
|
|
if len(fitting_idxs) > 0: |
|
return fitting_idxs, get_mapped_idxs(fitting_idxs) |
|
|
|
|
|
def get_closest_idx(target_idx: int, idxs: list[int]): |
|
min_idx = -1 |
|
min_dist = BIGMAX |
|
for idx in idxs: |
|
new_dist = abs(idx-target_idx) |
|
if new_dist < min_dist: |
|
min_idx = idx |
|
min_dist = new_dist |
|
if min_dist == 1: |
|
return min_idx, min_dist |
|
return min_idx, min_dist |
|
start_closest_idx, start_dist = get_closest_idx(sub_idxs[0], returned_idxs) |
|
end_closest_idx, end_dist = get_closest_idx(sub_idxs[-1], returned_idxs) |
|
|
|
if hint_length == 1: |
|
|
|
if start_dist == end_dist: |
|
|
|
center_idx = sub_idxs[np.linspace(0, len(sub_idxs)-1, 3, endpoint=True, dtype=int)[1]] |
|
return [center_idx], get_mapped_idxs([start_closest_idx]) |
|
|
|
if start_dist < end_dist: |
|
return [sub_idxs[0]], get_mapped_idxs([start_closest_idx]) |
|
return [sub_idxs[-1]], get_mapped_idxs([end_closest_idx]) |
|
|
|
|
|
if start_dist == end_dist: |
|
return [sub_idxs[0], sub_idxs[-1]], get_mapped_idxs([start_closest_idx, end_closest_idx]) |
|
|
|
if start_dist < end_dist: |
|
return [sub_idxs[0]], get_mapped_idxs([start_closest_idx]) |
|
return [sub_idxs[-1]], get_mapped_idxs([end_closest_idx]) |
|
|
|
|
|
class SparseSpreadMethod(SparseMethod): |
|
UNIFORM = "uniform" |
|
STARTING = "starting" |
|
ENDING = "ending" |
|
CENTER = "center" |
|
|
|
LIST = [UNIFORM, STARTING, ENDING, CENTER] |
|
|
|
def __init__(self, spread=UNIFORM): |
|
super().__init__(self.SPREAD) |
|
self.spread = spread |
|
|
|
def _get_indexes(self, hint_length: int, full_length: int) -> list[int]: |
|
|
|
if hint_length >= full_length: |
|
return list(range(full_length)) |
|
|
|
if hint_length == 1: |
|
if self.spread in [self.UNIFORM, self.STARTING]: |
|
return [0] |
|
elif self.spread == self.ENDING: |
|
return [full_length-1] |
|
elif self.spread == self.CENTER: |
|
|
|
return [np.linspace(0, full_length-1, 3, endpoint=True, dtype=int)[1]] |
|
else: |
|
raise ValueError(f"Unrecognized spread: {self.spread}") |
|
|
|
if self.spread == self.UNIFORM: |
|
return list(np.linspace(0, full_length-1, hint_length, endpoint=True, dtype=int)) |
|
elif self.spread == self.STARTING: |
|
|
|
return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1] |
|
elif self.spread == self.ENDING: |
|
|
|
return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[1:] |
|
elif self.spread == self.CENTER: |
|
|
|
if full_length-hint_length < 3: |
|
return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1] |
|
|
|
return list(np.linspace(0, full_length-1, hint_length+2, endpoint=True, dtype=int))[1:-1] |
|
return ValueError(f"Unrecognized spread: {self.spread}") |
|
|
|
|
|
class SparseIndexMethod(SparseMethod): |
|
def __init__(self, idxs: list[int]): |
|
super().__init__(self.INDEX) |
|
self.idxs = idxs |
|
|
|
def _get_indexes(self, hint_length: int, full_length: int) -> list[int]: |
|
orig_hint_length = hint_length |
|
if hint_length > full_length: |
|
hint_length = full_length |
|
|
|
if len(self.idxs) < hint_length: |
|
err_msg = f"There are not enough indexes ({len(self.idxs)}) provided to fit the usable {hint_length} input images." |
|
if orig_hint_length != hint_length: |
|
err_msg = f"{err_msg} (original input images: {orig_hint_length})" |
|
raise ValueError(err_msg) |
|
|
|
idxs = self.idxs[:hint_length] |
|
new_idxs = [] |
|
real_idxs = set() |
|
for idx in idxs: |
|
if idx < 0: |
|
real_idx = full_length+idx |
|
if real_idx in real_idxs: |
|
raise ValueError(f"Index '{idx}' maps to '{real_idx}' and is duplicate - indexes in Sparse Index Method must be unique.") |
|
else: |
|
real_idx = idx |
|
if real_idx in real_idxs: |
|
raise ValueError(f"Index '{idx}' is duplicate (or a negative index is equivalent) - indexes in Sparse Index Method must be unique.") |
|
real_idxs.add(real_idx) |
|
new_idxs.append(real_idx) |
|
return new_idxs |
|
|
|
|
|
def get_idx_list_from_str(indexes: str) -> list[int]: |
|
idxs = [] |
|
unique_idxs = set() |
|
|
|
str_idxs = [x.strip() for x in indexes.strip().split(",")] |
|
for str_idx in str_idxs: |
|
try: |
|
idx = int(str_idx) |
|
if idx in unique_idxs: |
|
raise ValueError(f"'{idx}' is duplicated; indexes must be unique.") |
|
idxs.append(idx) |
|
unique_idxs.add(idx) |
|
except ValueError: |
|
raise ValueError(f"'{str_idx}' is not a valid integer index.") |
|
if len(idxs) == 0: |
|
raise ValueError(f"No indexes were listed in Sparse Index Method.") |
|
return idxs |
|
|