import torch
from torch import Tensor

from comfy.cldm.cldm import ControlNet as ControlNetCLDM
import comfy.model_detection
import comfy.model_management
import comfy.ops
import comfy.utils

from comfy.ldm.modules.diffusionmodules.util import (
    zero_module,
    timestep_embedding,
)

from .control import ControlNetAdvanced
from .utils import TimestepKeyframeGroup
from .logger import logger
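
# CtrLoRA support: a ControlNet variant without an input hint block, plus
# loaders that build it from a shared base checkpoint and patch per-condition
# LoRA weights onto the wrapped control model.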


class ControlNetCtrLoRA(ControlNetCLDM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # CtrLoRA does not use a hint encoder; the hint is fed straight through
        # the input blocks in forward(), so drop the hint block created by the
        # parent class.
        del self.input_hint_block

    def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        out_output = []
        out_middle = []

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        # unlike the standard ControlNet forward, the hint itself (not x) is
        # what flows through the blocks
        h = hint.to(dtype=x.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            h = module(h, emb, context)
            out_output.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        out_middle.append(self.middle_block_out(h, emb, context))

        return {"middle": out_middle, "output": out_output}
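
# Note: the dict returned by ControlNetCtrLoRA.forward above ("output" holds
# one residual per input block, "middle" the middle-block residual) follows the
# layout ComfyUI's control merging expects, so the residuals are added to the
# matching UNet blocks during sampling.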


class CtrLoRAAdvanced(ControlNetAdvanced):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.require_vae = True
        self.mult_by_ratio_when_vae = False

    def pre_run_advanced(self, model, percent_to_timestep_function):
        super().pre_run_advanced(model, percent_to_timestep_function)
        self.latent_format = model.latent_format

    def cleanup_advanced(self):
        super().cleanup_advanced()
        if self.latent_format is not None:
            del self.latent_format
            self.latent_format = None

    def copy(self):
        c = CtrLoRAAdvanced(self.control_model, self.timestep_keyframes,
                            global_average_pooling=self.global_average_pooling,
                            load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
        c.control_model = self.control_model
        c.control_model_wrapped = self.control_model_wrapped
        self.copy_to(c)
        self.copy_to_advanced(c)
        return c


def load_ctrlora(base_path: str, lora_path: str,
                 base_data: dict[str, Tensor]=None, lora_data: dict[str, Tensor]=None,
                 timestep_keyframe: TimestepKeyframeGroup=None, model=None, model_options={}):
    if base_data is None:
        base_data = comfy.utils.load_torch_file(base_path, safe_load=True)
    controlnet_data = base_data

    # a valid CtrLoRA base model must contain lora_layer keys
    contains_lora_layers = False
    for key in base_data:
        if "lora_layer" in key:
            contains_lora_layers = True
    if not contains_lora_layers:
        raise Exception(f"File '{base_path}' is not a valid CtrLoRA base model; does not contain any lora_layer keys.")

    controlnet_config = None
    supported_inference_dtypes = None

    pth_key = 'control_model.zero_convs.0.0.weight'
    pth = False
    key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        key = pth_key
        prefix = "control_model."
    elif key in controlnet_data:
        prefix = ""
    else:
        raise Exception(f"File '{base_path}' is not a valid CtrLoRA base model; could not find '{pth_key}' or '{key}' in its state dict.")

    if controlnet_config is None:
        model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
        supported_inference_dtypes = list(model_config.supported_inference_dtypes)
        controlnet_config = model_config.unet_config

    unet_dtype = model_options.get("dtype", None)
    if unet_dtype is None:
        weight_dtype = comfy.utils.weight_dtype(controlnet_data)

        if supported_inference_dtypes is None:
            supported_inference_dtypes = [comfy.model_management.unet_dtype()]

        if weight_dtype is not None:
            supported_inference_dtypes.append(weight_dtype)

        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)

    load_device = comfy.model_management.get_torch_device()

    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    operations = model_options.get("custom_operations", None)
    if operations is None:
        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)

    controlnet_config["operations"] = operations
    controlnet_config["dtype"] = unet_dtype
    controlnet_config["device"] = comfy.model_management.unet_offload_device()
    controlnet_config.pop("out_channels")
    controlnet_config["hint_channels"] = 3

    control_model = ControlNetCtrLoRA(**controlnet_config)

    if pth:
        if 'difference' in controlnet_data:
            # 'difference' checkpoints store weights as a diff against the loaded
            # model, so the model's own weights have to be added back in
            if model is not None:
                comfy.model_management.load_models_gpu([model])
                model_sd = model.model_state_dict()
                for x in controlnet_data:
                    c_m = "control_model."
                    if x.startswith(c_m):
                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
                        if sd_key in model_sd:
                            cd = controlnet_data[x]
                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
            else:
                logger.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)

    if len(missing) > 0:
        logger.warning("missing controlnet keys: {}".format(missing))

    if len(unexpected) > 0:
        logger.debug("unexpected controlnet keys: {}".format(unexpected))

    global_average_pooling = model_options.get("global_average_pooling", False)
    control = CtrLoRAAdvanced(control_model, timestep_keyframe, global_average_pooling=global_average_pooling,
                              load_device=load_device, manual_cast_dtype=manual_cast_dtype)

    if lora_path is not None:
        load_lora_data(control, lora_path, loaded_data=lora_data)

    return control
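

# Illustrative key layout the loaders assume (example names, not taken from an
# actual checkpoint): a CtrLoRA lora file pairs keys such as
#   "...in_layers.2.lora_layer.down.weight" and "...in_layers.2.lora_layer.up.weight",
# and load_lora_data() below maps each pair onto the base model key
#   "...in_layers.2.weight"
# by stripping the "lora_layer.down." part (and any "control_model." prefix).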

def load_lora_data(control: CtrLoRAAdvanced, lora_path: str, loaded_data: dict[str, Tensor]=None, lora_strength=1.0):
    if loaded_data is None:
        loaded_data = comfy.utils.load_torch_file(lora_path, safe_load=True)

    contains_lora_layers = False
    for key in loaded_data:
        if "lora_layer" in key:
            contains_lora_layers = True
    if not contains_lora_layers:
        raise Exception(f"File '{lora_path}' is not a valid CtrLoRA lora model; does not contain any lora_layer keys.")

    # split the checkpoint into plain weights (applied as "set" patches) and
    # lora_layer weights (applied as "lora" patches)
    data_set: dict[str, Tensor] = {}
    data_lora: dict[str, Tensor] = {}

    for key in list(loaded_data.keys()):
        if 'lora_layer' in key:
            data_lora[key] = loaded_data.pop(key)
        else:
            data_set[key] = loaded_data.pop(key)

    if len(loaded_data) > 0:
        logger.warning("Not all keys from CtrLoRA lora model's loaded data were parsed!")

    patches = {}

    # non-lora weights replace the base model's weights outright
    for key, value in data_set.items():
        model_key = key.replace("control_model.", "")
        patches[model_key] = ("set", (value,))

    # lora weights come in down/up pairs; iterate the down keys and look up the
    # matching up key
    for down_key in data_lora:
        if ".up." in down_key:
            continue
        up_key = down_key.replace(".down.", ".up.")

        model_key = down_key.replace("lora_layer.down.", "").replace("control_model.", "")

        weight_down = data_lora[down_key]
        weight_up = data_lora[up_key]

        patches[model_key] = ("lora", (weight_up, weight_down, None, None, None, None,
                                       None, None, None, None, None, None, None, None))

    control.control_model_wrapped.add_patches(patches, strength_patch=lora_strength)
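

# Example usage sketch (the filenames are placeholders, not files shipped with
# this repo; the loader nodes elsewhere in the package are what normally call this):
#
#   control = load_ctrlora("ctrlora_base_sd15.safetensors",
#                          "ctrlora_canny_lora.safetensors")
#   # `control` is a CtrLoRAAdvanced instance, ready to be used like any other
#   # Advanced-ControlNet control object, with the lora already patched onto its
#   # wrapped control model at strength 1.0.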