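"""CtrLoRA support for Advanced-ControlNet.

A CtrLoRA model ships as a shared base ControlNet plus small per-task LoRA
files. The base is a standard CLDM ControlNet with its input hint block
removed (the cond hint is VAE-encoded and fed in latent space), and each
task LoRA contributes `lora_layer` up/down tensors plus a few full
replacement ("set") tensors, applied here as ComfyUI model patches.

Typical use (paths are illustrative):
    control = load_ctrlora("ctrlora_base.safetensors", "ctrlora_task.safetensors")
"""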
import torch
from torch import Tensor

from comfy.cldm.cldm import ControlNet as ControlNetCLDM
import comfy.model_detection
import comfy.model_management
import comfy.ops
import comfy.utils

from comfy.ldm.modules.diffusionmodules.util import (
    zero_module,
    timestep_embedding,
)

from .control import ControlNetAdvanced
from .utils import TimestepKeyframeGroup
from .logger import logger


class ControlNetCtrLoRA(ControlNetCLDM):
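    """CLDM ControlNet variant used by CtrLoRA.

    The stock input hint block is deleted: instead of adding a processed hint
    to the noised latent, the (already latent-space) hint itself is passed
    through the input blocks.
    """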
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # delete input hint block
        del self.input_hint_block
    
    def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs):
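        """Run the control model on the hint alone.

        `x` is used only for its dtype and batch-size check; the hint tensor
        takes the place of the usual `x + guided_hint` input stream. Returns
        the zero-conv outputs of each input block and the middle block output.
        """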
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        out_output = []
        out_middle = []

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)
        
        h = hint.to(dtype=x.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            h = module(h, emb, context)
            out_output.append(zero_conv(h, emb, context))
        
        h = self.middle_block(h, emb, context)
        out_middle.append(self.middle_block_out(h, emb, context))

        return {"middle": out_middle, "output": out_output}


class CtrLoRAAdvanced(ControlNetAdvanced):
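    """Advanced-ControlNet wrapper for CtrLoRA control models.

    Requires a VAE so the cond hint can be encoded into latent space (via the
    model's latent_format) before reaching the control model.
    """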
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.require_vae = True
        self.mult_by_ratio_when_vae = False
        self.latent_format = None  # assigned in pre_run_advanced; init here so cleanup_advanced is always safe

    def pre_run_advanced(self, model, percent_to_timestep_function):
        super().pre_run_advanced(model, percent_to_timestep_function)
        self.latent_format = model.latent_format  # LatentFormat object, used to process_in latent cond hint

    def cleanup_advanced(self):
        super().cleanup_advanced()
        if self.latent_format is not None:
            del self.latent_format
            self.latent_format = None

    def copy(self):
        c = CtrLoRAAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
        c.control_model = self.control_model
        c.control_model_wrapped = self.control_model_wrapped
        self.copy_to(c)
        self.copy_to_advanced(c)
        return c


def load_ctrlora(base_path: str, lora_path: str,
                 base_data: dict[str, Tensor]=None, lora_data: dict[str, Tensor]=None,
                 timestep_keyframe: TimestepKeyframeGroup=None, model=None, model_options={}):
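    """Load a CtrLoRA base ControlNet, optionally applying a task LoRA on top.

    Accepts file paths and/or preloaded state dicts (`base_data`/`lora_data`).
    Returns a CtrLoRAAdvanced object with any LoRA patches already added to
    the wrapped control model.
    """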
    if base_data is None:
        base_data = comfy.utils.load_torch_file(base_path, safe_load=True)
    controlnet_data = base_data

    # first, check that base_data contains keys with lora_layer
    contains_lora_layers = any("lora_layer" in key for key in base_data)
    if not contains_lora_layers:
        raise Exception(f"File '{base_path}' is not a valid CtrLoRA base model; does not contain any lora_layer keys.")
    
    controlnet_config = None
    supported_inference_dtypes = None

    pth_key = 'control_model.zero_convs.0.0.weight'
    pth = False
    key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        key = pth_key
        prefix = "control_model."
    elif key in controlnet_data:
        prefix = ""
    else:
        raise Exception(f"File '{base_path}' is not a valid CtrLoRA base model; zero_convs keys not found.")

    if controlnet_config is None:
        model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
        supported_inference_dtypes = list(model_config.supported_inference_dtypes)
        controlnet_config = model_config.unet_config

    unet_dtype = model_options.get("dtype", None)
    if unet_dtype is None:
        weight_dtype = comfy.utils.weight_dtype(controlnet_data)

        if supported_inference_dtypes is None:
            supported_inference_dtypes = [comfy.model_management.unet_dtype()]

        if weight_dtype is not None:
            supported_inference_dtypes.append(weight_dtype)

        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)

    load_device = comfy.model_management.get_torch_device()

    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    operations = model_options.get("custom_operations", None)
    if operations is None:
        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)

    controlnet_config["operations"] = operations
    controlnet_config["dtype"] = unet_dtype
    controlnet_config["device"] = comfy.model_management.unet_offload_device()
    controlnet_config.pop("out_channels")
    controlnet_config["hint_channels"] = 3
    #controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
    control_model = ControlNetCtrLoRA(**controlnet_config)

    if pth:
        if 'difference' in controlnet_data:
            if model is not None:
                comfy.model_management.load_models_gpu([model])
                model_sd = model.model_state_dict()
                for x in controlnet_data:
                    c_m = "control_model."
                    if x.startswith(c_m):
                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
                        if sd_key in model_sd:
                            cd = controlnet_data[x]
                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
            else:
                logger.warning("Loaded a diff controlnet without a model. It will very likely not work.")

        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)

    if len(missing) > 0:
        logger.warning("missing controlnet keys: {}".format(missing))

    if len(unexpected) > 0:
        logger.debug("unexpected controlnet keys: {}".format(unexpected))

    global_average_pooling = model_options.get("global_average_pooling", False)
    control = CtrLoRAAdvanced(control_model, timestep_keyframe, global_average_pooling=global_average_pooling,
                                 load_device=load_device, manual_cast_dtype=manual_cast_dtype)
    # load lora data onto the controlnet
    if lora_path is not None or lora_data is not None:
        load_lora_data(control, lora_path, loaded_data=lora_data)

    return control


def load_lora_data(control: CtrLoRAAdvanced, lora_path: str, loaded_data: dict[str, Tensor]=None, lora_strength=1.0):
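    """Apply a CtrLoRA task file to a loaded CtrLoRA controlnet.

    Keys are split into full replacement ("set") tensors and low-rank
    ("lora") up/down pairs, converted into ComfyUI model patches, and added
    to the wrapped control model at the given strength.
    """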
    if loaded_data is None:
        loaded_data = comfy.utils.load_torch_file(lora_path, safe_load=True)
    # check that lora_data contains keys with lora_layer
    contains_lora_layers = any("lora_layer" in key for key in loaded_data)
    if not contains_lora_layers:
        raise Exception(f"File '{lora_path}' is not a valid CtrLoRA lora model; does not contain any lora_layer keys.")

    # now that we know we have a ctrlora file, separate keys into 'set' and 'lora' keys
    data_set: dict[str, Tensor] = {}
    data_lora: dict[str, Tensor] = {}

    for key in list(loaded_data.keys()):
        if 'lora_layer' in key:
            data_lora[key] = loaded_data.pop(key)
        else:
            data_set[key] = loaded_data.pop(key)
    # no keys should be left over
    if len(loaded_data) > 0:
        logger.warning("Not all keys from CtrlLoRA lora model's loaded data were parsed!")
    
    # turn set/lora data into corresponding patches;
    patches = {}
    # set will replace the values
    for key, value in data_set.items():
        # parse model key from key;
        # remove "control_model."
        model_key = key.replace("control_model.", "")
        patches[model_key] = ("set", (value,))
    # lora will do mm of up and down tensors
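    # (ComfyUI applies a 'lora' patch roughly as W += strength * mm(up, down),
    # reshaped to W's shape, so storing just the up/down pair suffices)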
    for down_key in data_lora:
        # only process lora down keys; we will process both up+down at the same time
        if ".up." in key:
            continue
        # get up version of down key
        up_key = down_key.replace(".down.", ".up.")
        # get key that will match up with model key;
        # remove "lora_layer.down." and "control_model."
        model_key = down_key.replace("lora_layer.down.", "").replace("control_model.", "")
        
        weight_down = data_lora[down_key]
        weight_up = data_lora[up_key]
        # currently, ComfyUI expects 6 elements in 'lora' type, but for future-proofing add a bunch more with None
        patches[model_key] = ("lora", (weight_up, weight_down, None, None, None, None,
                                       None, None, None, None, None, None, None, None))
    
    # now that patches are made, add them to model
    control.control_model_wrapped.add_patches(patches, strength_patch=lora_strength)