"""Kolors compatibility patch for ComfyUI ControlNet / ControlLora objects."""

from types import MethodType

import torch
from torch.nn import Linear

import comfy.model_management
import comfy.samplers
import comfy.controlnet
from comfy.cldm.cldm import ControlNet
from comfy.controlnet import ControlLora


def patch_controlnet(model, control_net):
    """Patch ``control_net`` in place so it can run against a Kolors model.

    The patched control model projects ``context`` through the base model's
    ``encoder_hid_proj`` before delegating to the stock ``ControlNet.forward``.

    Args:
        model: Model wrapper exposing ``model.model.diffusion_model`` with an
            ``encoder_hid_proj`` callable (and ``label_emb`` for the plain
            ControlNet case).
        control_net: A ``ControlLora`` or ``comfy.controlnet.ControlNet``
            instance. It is mutated in place.

    Returns:
        The same ``control_net`` object that was passed in.

    Raises:
        NotImplementedError: If ``control_net`` is neither a ``ControlLora``
            nor a ``comfy.controlnet.ControlNet``.
    """
    base_forward = ControlNet.forward

    def kolors_controlnet_forward(self, x, hint, timesteps, context, **kwargs):
        # Project the text context with the base model's encoder_hid_proj,
        # then hand over to the stock forward, all under CUDA autocast.
        with torch.autocast(device_type="cuda", enabled=True):
            context = model.model.diffusion_model.encoder_hid_proj(context)
            return base_forward(self, x, hint, timesteps, context, **kwargs)

    # ControlLora must be tested first: the second check uses the broader
    # comfy.controlnet.ControlNet type.
    if isinstance(control_net, ControlLora):
        # Drop the LoRA's own "label_emb.0.0." weights.
        # NOTE(review): presumably so the base model's label embedding is
        # used instead of the LoRA's -- confirm against ControlLora.pre_run.
        stale_keys = [
            key for key in control_net.control_weights
            if key.startswith("label_emb.0.0.")
        ]
        for key in stale_keys:
            control_net.control_weights.pop(key)

        base_pre_run = ControlLora.pre_run
        base_copy = ControlLora.copy

        def kolors_controllora_pre_run(self, *args, **kwargs):
            result = base_pre_run(self, *args, **kwargs)
            # The forward override is installed only once pre_run has left a
            # control_model on the instance.
            if hasattr(self, "control_model"):
                self.control_model.forward = MethodType(
                    kolors_controlnet_forward, self.control_model)
            return result

        def kolors_controllora_copy(self, *args, **kwargs):
            clone = base_copy(self, *args, **kwargs)
            clone.pre_run = MethodType(kolors_controllora_pre_run, clone)
            # Re-install the copy override as well; otherwise a copy of a
            # copy would silently lose the Kolors patch.
            clone.copy = MethodType(kolors_controllora_copy, clone)
            return clone

        control_net.pre_run = MethodType(
            kolors_controllora_pre_run, control_net)
        control_net.copy = MethodType(kolors_controllora_copy, control_net)

    elif isinstance(control_net, comfy.controlnet.ControlNet):
        # Share the base model's label embedding with both the control model
        # and the model held by its wrapper.
        model_label_emb = model.model.diffusion_model.label_emb
        control_net.control_model.label_emb = model_label_emb
        control_net.control_model_wrapped.model.label_emb = model_label_emb

        control_net.control_model.forward = MethodType(
            kolors_controlnet_forward, control_net.control_model)

    else:
        raise NotImplementedError(
            f"Type {type(control_net).__name__} not supported for "
            "KolorsControlNetPatch")

    return control_net