File size: 2,364 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

# Credit to huchenlei for this code, adapted from:
# https://github.com/huchenlei/ComfyUI-layerdiffuse/blob/151f7460bbc9d7437d4f0010f21f80178f7a84a6/layered_diffusion.py#L34-L96

import torch
import functools
from comfy.model_patcher import ModelPatcher
import comfy.model_management

def calculate_weight_adjust_channel(func):
    """Patches ComfyUI's LoRA weight application to accept multi-channel inputs.

    Wraps a ``calculate_weight(patches, weight, key, intermediate_dtype, ...)``
    callable. The wrapped function runs first. Afterwards, every ``"diff"``
    patch with a non-zero strength whose 4-D tensor shape differs from the
    resulting weight is merged as follows:

    * allocate a zero tensor whose size is the element-wise maximum of both
      shapes;
    * copy the current weight into its leading corner;
    * add ``alpha * diff`` into the diff's leading corner.

    Presumably this is used to widen a layer's channel count (the log
    message mentions a channel change). That reading is inferred from the
    message, not enforced here.

    Args:
        func: The original ``calculate_weight`` callable to wrap.

    Returns:
        A wrapper with the same call signature as ``func``. Any extra keyword
        arguments are forwarded to ``func`` unchanged.
    """

    @functools.wraps(func)
    def calculate_weight(
        patches,
        weight: torch.Tensor,
        key: str,
        intermediate_dtype=torch.float32,
        **kwargs,
    ) -> torch.Tensor:
        # Let the original implementation apply every patch it can handle.
        # NOTE(review): assumes ``func`` itself leaves shape-mismatched diffs
        # unmerged, so they are not applied twice - confirm against the
        # ComfyUI version in use.
        weight = func(patches, weight, key, intermediate_dtype, **kwargs)

        for p in patches:
            alpha = p[0]
            v = p[1]

            # The recursion call should be handled in the main func call.
            if isinstance(v, list):
                continue

            if len(v) == 1:
                patch_type = "diff"
            elif len(v) == 2:
                patch_type, v = v
            else:
                # Unrecognised patch layout: skip it rather than compare an
                # unset (or stale, from a previous patch) ``patch_type``.
                continue

            if patch_type != "diff" or alpha == 0.0:
                continue

            w1 = v[0]
            # Only 4-D tensors whose shapes actually disagree need expanding.
            if w1.shape == weight.shape or not (w1.ndim == weight.ndim == 4):
                continue

            new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
            print(
                f"IC-Light: Merged with {key} channel changed from {weight.shape} to {new_shape}"
            )
            new_diff = alpha * comfy.model_management.cast_to_device(
                w1, weight.device, weight.dtype
            )
            # One allocation, directly on weight's device and dtype (and
            # contiguous by construction). Both tensors are placed in the
            # leading corner of the enlarged buffer.
            new_weight = weight.new_zeros(new_shape)
            new_weight[tuple(slice(0, s) for s in weight.shape)] = weight
            new_weight[tuple(slice(0, s) for s in new_diff.shape)] += new_diff
            weight = new_weight
        return weight

    return calculate_weight