|
|
|
|
|
from typing import Union

import os
from collections import OrderedDict

import torch
import torch.nn as nn
from torch import Tensor

from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from comfy.cldm.cldm import ControlNet as ControlNetCLDM
from comfy.controlnet import ControlNet
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
import comfy.model_management
import comfy.model_detection
import comfy.utils

from .utils import (AdvancedControlBase, ControlWeights, ControlWeightType, TimestepKeyframeGroup, AbstractPreprocWrapper, Extras,
                    extend_to_batch_size, broadcast_image_to_extend)
from .logger import logger
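# ControlNet++ ("union") support: a single control model accepts several control images at once,
# each tagged with a control type and strength, and fuses them into one hint inside the model.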
|
|
|
|
|
class PlusPlusType:
    OPENPOSE = "openpose"
    DEPTH = "depth"
    THICKLINE = "hed/pidi/scribble/ted"
    THINLINE = "canny/lineart/mlsd"
    NORMAL = "normal"
    SEGMENT = "segment"
    TILE = "tile"
    REPAINT = "inpaint/outpaint"
    NONE = "none"
    _LIST_WITH_NONE = [OPENPOSE, DEPTH, THICKLINE, THINLINE, NORMAL, SEGMENT, TILE, REPAINT, NONE]
    _LIST = [OPENPOSE, DEPTH, THICKLINE, THINLINE, NORMAL, SEGMENT, TILE, REPAINT]
    _DICT = {OPENPOSE: 0, DEPTH: 1, THICKLINE: 2, THINLINE: 3, NORMAL: 4, SEGMENT: 5, TILE: 6, REPAINT: 7, NONE: -1}

    @classmethod
    def to_idx(cls, control_type: str):
        try:
            return cls._DICT[control_type]
        except KeyError:
            raise Exception(f"Unknown control type '{control_type}'.")
|
|
|
|
|
class PlusPlusInput:
    def __init__(self, image: Tensor, control_type: str, strength: float):
        self.image = image
        self.control_type = control_type
        self.strength = strength

    def clone(self):
        return PlusPlusInput(self.image, self.control_type, self.strength)
|
|
|
|
|
class PlusPlusInputGroup:
    def __init__(self):
        self.controls: dict[str, PlusPlusInput] = {}

    def add(self, pp_input: PlusPlusInput):
        if pp_input.control_type in self.controls:
            raise Exception(f"Control type '{pp_input.control_type}' is already present; ControlNet++ does not allow more than 1 of each type.")
        self.controls[pp_input.control_type] = pp_input

    def clone(self) -> 'PlusPlusInputGroup':
        cloned = PlusPlusInputGroup()
        for key, value in self.controls.items():
            cloned.controls[key] = value.clone()
        return cloned
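# Typical construction (a minimal sketch; the images are standard ComfyUI IMAGE tensors):
#   group = PlusPlusInputGroup()
#   group.add(PlusPlusInput(depth_image, PlusPlusType.DEPTH, strength=1.0))
#   group.add(PlusPlusInput(pose_image, PlusPlusType.OPENPOSE, strength=0.8))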
|
|
|
|
|
class PlusPlusImageWrapper(AbstractPreprocWrapper):
    error_msg = "Invalid use of ControlNet++ Image Wrapper. The output of ControlNet++ Image Wrapper is NOT a usual image, but an object holding the images and extra info - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."

    def __init__(self, condhint: PlusPlusInputGroup):
        super().__init__(condhint)
        self.condhint: PlusPlusInputGroup

    def movedim(self, source: int, destination: int):
        condhint = self.condhint.clone()
        for pp_input in condhint.controls.values():
            pp_input.image = pp_input.image.movedim(source, destination)
        return PlusPlusImageWrapper(condhint)
|
|
|
|
|
class OptimizedAttention(nn.Module):
    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead
        self.c = c

        # single fused projection produces q, k, and v in one matmul
        self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
        self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

    def forward(self, x):
        x = self.in_proj(x)
        q, k, v = x.split(self.c, dim=2)
        out = optimized_attention(q, k, v, self.heads)
        return self.out_proj(out)
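# Checkpoint keys that use the fused 'attn.in_proj_' naming (as in torch's nn.MultiheadAttention)
# are remapped to this module's 'attn.in_proj.' naming in load_controlnetplusplus() below.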
|
|
|
class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
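# QuickGELU is the cheap sigmoid approximation of GELU popularized by CLIP.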
|
|
|
class ResBlockUnionControlnet(nn.Module):
    def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
        self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
        self.mlp = nn.Sequential(
            OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)),
                         ("gelu", QuickGELU()),
                         ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
        self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)

    def attention(self, x: torch.Tensor):
        return self.attn(x)

    def forward(self, x: torch.Tensor):
        # pre-norm transformer block: attention and MLP residuals
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
|
|
|
|
|
class ControlAddEmbeddingAdv(nn.Module):
    def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations: comfy.ops.disable_weight_init=None):
        super().__init__()
        self.num_control_type = num_control_type
        self.in_dim = in_dim
        self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
        self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)

    def forward(self, control_type, dtype, device):
        if control_type is None:
            control_type = torch.zeros((self.num_control_type,), device=device)
        c_type = timestep_embedding(control_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
        return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
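# Each control-type strength is encoded with the same sinusoidal embedding used for timesteps,
# the per-type embeddings are concatenated, and a 2-layer MLP projects them into the
# time-embedding space so they can simply be added to emb in ControlNetPlusPlus.forward().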
|
|
|
|
|
class ControlNetPlusPlus(ControlNetCLDM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        operations: comfy.ops.disable_weight_init = kwargs.get("operations", comfy.ops.disable_weight_init)
        device = kwargs.get("device", None)

        time_embed_dim = self.model_channels * 4
        control_add_embed_dim = 256

        self.control_add_embedding = ControlAddEmbeddingAdv(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
|
|
|
    def union_controlnet_merge(self, hint: list[Tensor], control_type, emb, context):
        # active control types, taken from the first batch entry
        indexes = torch.nonzero(control_type[0])
        inputs = []
        condition_list = []

        for idx in range(indexes.shape[0]):
            controlnet_cond = self.input_hint_block(hint[indexes[idx][0]], emb, context)
            feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
            # tag the pooled features with the learned embedding of their control type
            feat_seq += self.task_embedding[indexes[idx][0]].to(dtype=feat_seq.dtype, device=feat_seq.device)

            inputs.append(feat_seq.unsqueeze(1))
            condition_list.append(controlnet_cond)

        x = torch.cat(inputs, dim=1)
        # 'transformer_layes' (sic) matches the attribute name inherited from comfy's cldm ControlNet
        x = self.transformer_layes(x)

        controlnet_cond_fuser = None
        for idx in range(indexes.shape[0]):
            alpha = self.spatial_ch_projs(x[:, idx])
            alpha = alpha.unsqueeze(-1).unsqueeze(-1)
            o = condition_list[idx] + alpha
            if controlnet_cond_fuser is None:
                controlnet_cond_fuser = o
            else:
                controlnet_cond_fuser += o
        return controlnet_cond_fuser
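    # Fusion summary: each active hint is embedded by input_hint_block, mean-pooled, tagged with
    # its task embedding, mixed across types by a small transformer, then projected back to a
    # per-hint channel bias; the biased hints are summed into one guided_hint for the control branch.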
|
|
|
    def forward(self, x: Tensor, hint: list[Tensor], timesteps, context, y: Tensor=None, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        guided_hint = None
        if self.control_add_embedding is not None:
            control_type = kwargs.get("control_type", None)

            # a control_type of None is embedded as all zeros (no active types)
            emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
            if control_type is not None:
                guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)

        if guided_hint is None:
            guided_hint = self.input_hint_block(hint[0], emb, context)

        out_output = []
        out_middle = []

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            out_output.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        out_middle.append(self.middle_block_out(h, emb, context))

        return {"middle": out_middle, "output": out_output}
|
|
|
|
|
class ControlNetPlusPlusAdvanced(ControlNet, AdvancedControlBase):
    def __init__(self, control_model: ControlNetPlusPlus, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, load_device=None, manual_cast_dtype=None):
        super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
        self.add_compatible_weight(ControlWeightType.CONTROLNETPLUSPLUS)

        self.control_model: ControlNetPlusPlus
        self.cond_hint_original: Union[PlusPlusImageWrapper, PlusPlusInputGroup]
        self.cond_hint: list[Union[Tensor, None]]
        self.cond_hint_shape: torch.Size = None
        self.cond_hint_types: Tensor = None

        # set when this object is used as a single-type ControlNet++ (non-union workflow)
        self.single_control_type: str = None
|
|
|
    def get_universal_weights(self) -> ControlWeights:
        def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
            if key == "middle":
                return 1.0 * self.weights.extras.get(Extras.MIDDLE_MULT, 1.0)
            c_len = len(control[key])
            # exponentially decaying multipliers across the outputs
            raw_weights = [(self.weights.base_multiplier ** float(c_len - i)) for i in range(c_len)]
            if key == "input":
                raw_weights.reverse()
            return raw_weights[idx]
        return self.weights.copy_with_new_weights(new_weight_func=cn_weights_func)
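    # e.g. with base_multiplier=0.825 and c_len=3, raw_weights == [0.825**3, 0.825**2, 0.825**1];
    # for "input" the list is reversed so the decay runs in the opposite direction.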
|
|
|
    def verify_control_type(self, model_name: str, pp_group: PlusPlusInputGroup=None):
        if pp_group is not None:
            for pp_input in pp_group.controls.values():
                if PlusPlusType.to_idx(pp_input.control_type) >= self.control_model.num_control_type:
                    raise Exception(f"ControlNet++ model '{model_name}' does not support control_type '{pp_input.control_type}'.")
        if self.single_control_type is not None:
            if PlusPlusType.to_idx(self.single_control_type) >= self.control_model.num_control_type:
                raise Exception(f"ControlNet++ model '{model_name}' does not support control_type '{self.single_control_type}'.")
|
|
|
    def set_cond_hint_inject(self, *args, **kwargs):
        to_return = super().set_cond_hint_inject(*args, **kwargs)

        if self.single_control_type is None:
            # multi-type workflow: input must come wrapped by the Load ControlNet++ Model node
            if not isinstance(self.cond_hint_original, PlusPlusImageWrapper):
                raise Exception("ControlNet++ (Multi) expects image input from the Load ControlNet++ Model node, NOT from anything else. Images are provided to that node via ControlNet++ Input nodes.")
            self.cond_hint_original = self.cond_hint_original.condhint.clone()
        else:
            # single-type workflow: wrap a plain image into a one-entry group
            if isinstance(self.cond_hint_original, PlusPlusImageWrapper):
                raise Exception("ControlNet++ (Single) expects usual image input, NOT the image input from a Load ControlNet++ Model (Multi) node.")
            pp_group = PlusPlusInputGroup()
            pp_input = PlusPlusInput(self.cond_hint_original, self.single_control_type, 1.0)
            pp_group.add(pp_input)
            self.cond_hint_original = pp_group
        return to_return
|
|
|
    def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number, transformer_options):
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)

        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                if control_prev is not None:
                    return control_prev
                else:
                    return None

        dtype = self.control_model.dtype
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        output_dtype = x_noisy.dtype

        # rebuild hints when sub_idxs are active or the latent resolution no longer matches the cache
        if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint_shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint_shape[3]:
            if self.cond_hint is not None:
                del self.cond_hint
            self.cond_hint = [None] * self.control_model.num_control_type
            self.cond_hint_types = torch.tensor([0.0] * self.control_model.num_control_type)
            self.cond_hint_shape = None
            compression_ratio = self.compression_ratio

            for pp_type, pp_input in self.cond_hint_original.controls.items():
                pp_idx = PlusPlusType.to_idx(pp_type)
                # NONE (-1) is stored in slot 0 but left inactive in cond_hint_types
                if pp_idx < 0:
                    pp_idx = 0
                else:
                    self.cond_hint_types[pp_idx] = pp_input.strength

                if self.sub_idxs is not None:
                    actual_cond_hint_orig = pp_input.image
                    if pp_input.image.size(0) < self.full_latent_length:
                        actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
                    self.cond_hint[pp_idx] = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
                else:
                    self.cond_hint[pp_idx] = comfy.utils.common_upscale(pp_input.image, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
                self.cond_hint[pp_idx] = self.cond_hint[pp_idx].to(device=x_noisy.device, dtype=dtype)
                self.cond_hint_shape = self.cond_hint[pp_idx].shape

            if self.cond_hint_types.count_nonzero() == 0:
                self.cond_hint_types = None
            else:
                self.cond_hint_types = self.cond_hint_types.unsqueeze(0).to(device=x_noisy.device, dtype=dtype).repeat(x_noisy.shape[0], 1)
        for i in range(len(self.cond_hint)):
            if self.cond_hint[i] is not None:
                if x_noisy.shape[0] != self.cond_hint[i].shape[0]:
                    self.cond_hint[i] = broadcast_image_to_extend(self.cond_hint[i], x_noisy.shape[0], batched_number)
        if self.cond_hint_types is not None and x_noisy.shape[0] != self.cond_hint_types.shape[0]:
            self.cond_hint_types = broadcast_image_to_extend(self.cond_hint_types, x_noisy.shape[0], batched_number, False)

        self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)

        context = cond.get('crossattn_controlnet', cond['c_crossattn'])
        y = cond.get('y', None)
        if y is not None:
            y = y.to(dtype)
        timestep = self.model_sampling_current.timestep(t)
        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, control_type=self.cond_hint_types)
        return self.control_merge(control, control_prev, output_dtype)
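    # Hints are cached per resolution/sub_idxs, so repeated sampler steps reuse the upscaled
    # images and only re-run the control model itself.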
|
|
|
    def copy(self):
        c = ControlNetPlusPlusAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
        self.copy_to(c)
        self.copy_to_advanced(c)
        c.single_control_type = self.single_control_type
        return c
|
|
|
|
|
def load_controlnetplusplus(ckpt_path: str, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
    controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)

    # ControlNet++ checkpoints are identified by their per-type task embeddings
    if "task_embedding" not in controlnet_data:
        raise Exception(f"'{ckpt_path}' is not a valid ControlNet++ model.")

    controlnet_config = None
    supported_inference_dtypes = None
|
|
|
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: |
|
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data) |
|
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config) |
|
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" |
|
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" |
|
|
|
count = 0 |
|
loop = True |
|
while loop: |
|
suffix = [".weight", ".bias"] |
|
for s in suffix: |
|
k_in = "controlnet_down_blocks.{}{}".format(count, s) |
|
k_out = "zero_convs.{}.0{}".format(count, s) |
|
if k_in not in controlnet_data: |
|
loop = False |
|
break |
|
diffusers_keys[k_in] = k_out |
|
count += 1 |
|
|
|
count = 0 |
|
loop = True |
|
while loop: |
|
suffix = [".weight", ".bias"] |
|
for s in suffix: |
|
if count == 0: |
|
k_in = "controlnet_cond_embedding.conv_in{}".format(s) |
|
else: |
|
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s) |
|
k_out = "input_hint_block.{}{}".format(count * 2, s) |
|
if k_in not in controlnet_data: |
|
k_in = "controlnet_cond_embedding.conv_out{}".format(s) |
|
loop = False |
|
diffusers_keys[k_in] = k_out |
|
count += 1 |
|
|
|
new_sd = {} |
|
for k in diffusers_keys: |
|
if k in controlnet_data: |
|
new_sd[diffusers_keys[k]] = controlnet_data.pop(k) |
|
|
|
if "control_add_embedding.linear_1.bias" in controlnet_data: |
|
controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0] |
|
for k in list(controlnet_data.keys()): |
|
new_k = k.replace('.attn.in_proj_', '.attn.in_proj.') |
|
new_sd[new_k] = controlnet_data.pop(k) |
|
|
|
leftover_keys = controlnet_data.keys() |
|
if len(leftover_keys) > 0: |
|
logger.warning("leftover ControlNet++ keys: {}".format(leftover_keys)) |
|
controlnet_data = new_sd |
|
elif "controlnet_blocks.0.weight" in controlnet_data: |
|
raise Exception("Unexpected SD3 diffusers format for ControlNet++ model. Something is very wrong.") |
|
|
|
    pth_key = 'control_model.zero_convs.0.0.weight'
    pth = False
    key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        key = pth_key
        prefix = "control_model."
    elif key in controlnet_data:
        prefix = ""
    else:
        raise Exception("Unexpected T2IAdapter format for ControlNet++ model. Something is very wrong.")
|
|
|
    if controlnet_config is None:
        model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
        supported_inference_dtypes = model_config.supported_inference_dtypes
        controlnet_config = model_config.unet_config

    load_device = comfy.model_management.get_torch_device()
    if supported_inference_dtypes is None:
        unet_dtype = comfy.model_management.unet_dtype()
    else:
        unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)

    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype is not None:
        controlnet_config["operations"] = comfy.ops.manual_cast
    controlnet_config["dtype"] = unet_dtype
    controlnet_config.pop("out_channels")
    # hint channels come straight from the first input_hint_block conv weight
    controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
    control_model = ControlNetPlusPlus(**controlnet_config)
|
|
|
    if pth:
        if 'difference' in controlnet_data:
            if model is not None:
                comfy.model_management.load_models_gpu([model])
                model_sd = model.model_state_dict()
                # a "diff" controlnet stores deltas; add the base model weights back in
                for x in controlnet_data:
                    c_m = "control_model."
                    if x.startswith(c_m):
                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
                        if sd_key in model_sd:
                            cd = controlnet_data[x]
                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
            else:
                logger.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
|
|
    if len(missing) > 0:
        logger.warning("missing ControlNet++ keys: {}".format(missing))
    if len(unexpected) > 0:
        logger.debug("unexpected ControlNet++ keys: {}".format(unexpected))

    global_average_pooling = False
    filename = os.path.splitext(ckpt_path)[0]
    if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"):
        global_average_pooling = True

    control = ControlNetPlusPlusAdvanced(control_model, timestep_keyframes=timestep_keyframe, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
    return control
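# Minimal usage sketch (the checkpoint path is hypothetical; node wiring normally does this):
#   control = load_controlnetplusplus("models/controlnet/controlnet_plusplus.safetensors")
#   control.single_control_type = PlusPlusType.DEPTH  # single-type mode; omit for multi/union mode
#   control.set_cond_hint_inject(...)  # performed by the Apply Advanced ControlNet node in practice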
|
|