|
|
|
|
|
|
|
|
|
import torch |
|
import functools |
|
from comfy.model_patcher import ModelPatcher |
|
import comfy.model_management |
|
|
|
def calculate_weight_adjust_channel(func): |
|
"""Patches ComfyUI's LoRA weight application to accept multi-channel inputs.""" |
|
|
|
@functools.wraps(func) |
|
def calculate_weight(patches, weight: torch.Tensor, key: str, intermediate_dtype=torch.float32) -> torch.Tensor: |
|
weight = func(patches, weight, key, intermediate_dtype) |
|
|
|
for p in patches: |
|
alpha = p[0] |
|
v = p[1] |
|
|
|
|
|
if isinstance(v, list): |
|
continue |
|
|
|
if len(v) == 1: |
|
patch_type = "diff" |
|
elif len(v) == 2: |
|
patch_type = v[0] |
|
v = v[1] |
|
|
|
if patch_type == "diff": |
|
w1 = v[0] |
|
if all( |
|
( |
|
alpha != 0.0, |
|
w1.shape != weight.shape, |
|
w1.ndim == weight.ndim == 4, |
|
) |
|
): |
|
new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)] |
|
print( |
|
f"IC-Light: Merged with {key} channel changed from {weight.shape} to {new_shape}" |
|
) |
|
new_diff = alpha * comfy.model_management.cast_to_device( |
|
w1, weight.device, weight.dtype |
|
) |
|
new_weight = torch.zeros(size=new_shape).to(weight) |
|
new_weight[ |
|
: weight.shape[0], |
|
: weight.shape[1], |
|
: weight.shape[2], |
|
: weight.shape[3], |
|
] = weight |
|
new_weight[ |
|
: new_diff.shape[0], |
|
: new_diff.shape[1], |
|
: new_diff.shape[2], |
|
: new_diff.shape[3], |
|
] += new_diff |
|
new_weight = new_weight.contiguous().clone() |
|
weight = new_weight |
|
return weight |
|
|
|
return calculate_weight |
|
|