jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
#credit to huchenlei for this module
#from https://github.com/huchenlei/ComfyUI-layerdiffuse
import torch
import comfy.model_management
import comfy.lora
import copy
from typing import Optional
from enum import Enum
from comfy.utils import load_torch_file
from comfy.conds import CONDRegular
from comfy_extras.nodes_compositing import JoinImageWithAlpha
from .model import ModelPatcher, TransparentVAEDecoder, calculate_weight_adjust_channel
from .attension_sharing import AttentionSharingPatcher
from ..config import LAYER_DIFFUSION, LAYER_DIFFUSION_DIR, LAYER_DIFFUSION_VAE
from ..libs.utils import to_lora_patch_dict, get_local_filepath, get_sd_version
# Alias of ComfyUI's ``load_torch_file``; used below to load the layer-diffusion
# weights that are then applied to the model.
load_layer_model_state_dict = load_torch_file
class LayerMethod(Enum):
    """Layer-diffusion modes; each value is the user-facing method name.

    The value string is also used as the key into ``LAYER_DIFFUSION`` when
    resolving which weights to load for a method.
    """
    # Foreground-only generation (single frame on SD1.5).
    FG_ONLY_ATTN = "Attention Injection"
    FG_ONLY_CONV = "Conv Injection"
    # Replaced by FG_BLEND_TO_BG when a blend latent is supplied.
    FG_TO_BLEND = "Foreground"
    FG_BLEND_TO_BG = "Foreground to Background"
    # Replaced by BG_BLEND_TO_FG when a blend latent is supplied.
    BG_TO_BLEND = "Background"
    BG_BLEND_TO_FG = "Background to Foreground"
    # Three-frame joint generation (SD1.5 only in the visible code paths).
    EVERYTHING = "Everything"
class LayerDiffuse:
    """Apply LayerDiffuse patches to a model and decode transparent (RGBA) results.

    Adapted from huchenlei's ComfyUI-layerdiffuse. ``apply_layer_diffusion``
    prepares the sampling model and conditioning; ``layer_diffusion_decode``
    turns sampled images back into images carrying an alpha channel.
    """

    # True once ComfyUI's LoRA ``calculate_weight`` has been wrapped with
    # ``calculate_weight_adjust_channel``. The wrap replaces a process-global
    # function, so repeating it on every call would stack one extra wrapper
    # layer per run.
    _calculate_weight_patched = False

    def __init__(self) -> None:
        # Transparent VAE decoder, created lazily by ``layer_diffusion_decode``.
        self.vae_transparent_decoder = None
        # Number of consecutive images forming one group for the SD1.5
        # multi-frame methods (1, 2 or 3). Set by ``apply_layer_diffusion`` and
        # read by ``layer_diffusion_decode``.
        self.frames = 1

    def get_layer_diffusion_method(self, method, has_blend_latent):
        """Convert ``method`` to a ``LayerMethod``.

        When a blend latent is available, "Background" becomes
        "Background to Foreground" and "Foreground" becomes
        "Foreground to Background".
        """
        method = LayerMethod(method)
        if has_blend_latent:
            if method == LayerMethod.BG_TO_BLEND:
                return LayerMethod.BG_BLEND_TO_FG
            if method == LayerMethod.FG_TO_BLEND:
                return LayerMethod.FG_BLEND_TO_BG
        return method

    def apply_layer_c_concat(self, cond, uncond, c_concat):
        """Return copies of ``cond`` and ``uncond`` with ``c_concat`` attached.

        Each conditioning entry is ``[tensor, options]``. The options dict and
        its ``model_conds`` dict are both copied, so the caller's conditioning
        is never mutated.
        """
        def write_c_concat(conditioning):
            new_cond = []
            for t in conditioning:
                options = t[1].copy()
                # Copy the inner dict too: a shallow copy of ``options`` would
                # still share ``model_conds`` with the caller's conditioning.
                model_conds = dict(options.get("model_conds") or {})
                model_conds["c_concat"] = CONDRegular(c_concat)
                options["model_conds"] = model_conds
                new_cond.append([t[0], options])
            return new_cond

        return (write_c_concat(cond), write_c_concat(uncond))

    @staticmethod
    def _patch_calculate_weight():
        """Wrap ComfyUI's LoRA weight calculation once per process (best effort)."""
        if LayerDiffuse._calculate_weight_patched:
            return
        try:
            if hasattr(comfy.lora, "calculate_weight"):
                comfy.lora.calculate_weight = calculate_weight_adjust_channel(comfy.lora.calculate_weight)
            else:
                ModelPatcher.calculate_weight = calculate_weight_adjust_channel(ModelPatcher.calculate_weight)
        except Exception as err:
            # Kept best-effort as before, but no longer silent.
            import logging
            logging.getLogger(__name__).warning(
                "LayerDiffuse: could not patch calculate_weight: %s", err
            )
            return
        LayerDiffuse._calculate_weight_patched = True

    def apply_layer_diffusion(self, model: "ModelPatcher", method, weight, samples, blend_samples, positive, negative, image=None, additional_cond=(None, None, None)):
        """Patch a clone of ``model`` for the layer-diffusion ``method``.

        Args:
            model: model patcher. It is cloned and the patches go on the clone.
                ``model.model.latent_format`` is read for the SDXL c_concat input.
            method: a ``LayerMethod`` member. Its ``.value`` is the key into
                ``LAYER_DIFFUSION``.
            weight: patch strength passed to ``add_patches`` on the non-SD1.5 path.
            samples: latent dict with a ``"samples"`` tensor.
            blend_samples: latent dict. Its ``"samples"`` are concatenated on
                dim 1 for the SDXL blend methods.
            positive, negative: conditioning lists. For SDXL blend methods,
                copies with ``c_concat`` attached are returned.
            image: optional image tensor (presumably ComfyUI [B, H, W, C]),
                moved to channels-first. Used as the control image for the
                two-frame SD1.5 methods.
            additional_cond: conditionings whose first tensor (or None) is
                written to ``transformer_options["cond_overwrite"]`` on the
                SD1.5 path.

        Returns:
            ``(samp_model, positive, negative)``.

        Raises:
            Exception: the batch size is not a multiple of the SD1.5 frame
                count, or no model URL exists for this method and version.
        """
        control_img: Optional[torch.TensorType] = None
        sd_version = get_sd_version(model)
        model_url = LAYER_DIFFUSION[method.value][sd_version]["model_url"]
        if image is not None:
            image = image.movedim(-1, 1)

        self._patch_calculate_weight()

        # Reset every call so a frame count left over from an earlier SD1.5
        # run does not leak into SDXL / single-frame decoding.
        self.frames = 1
        if sd_version == 'sd1':
            if method in [LayerMethod.BG_TO_BLEND, LayerMethod.FG_TO_BLEND, LayerMethod.BG_BLEND_TO_FG, LayerMethod.FG_BLEND_TO_BG]:
                self.frames = 2
                batch_size = samples['samples'].shape[0]
                if batch_size % 2 != 0:
                    raise Exception("The batch size should be a multiple of 2. 批次大小需为2的倍数")
                control_img = image
            elif method == LayerMethod.EVERYTHING:
                self.frames = 3
                batch_size = samples['samples'].shape[0]
                if batch_size % 3 != 0:
                    raise Exception("The batch size should be a multiple of 3. 批次大小需为3的倍数")

        if model_url is None:
            raise Exception(f"{method.value} is not supported for {sd_version} model")
        model_path = get_local_filepath(model_url, LAYER_DIFFUSION_DIR)
        layer_lora_state_dict = load_layer_model_state_dict(model_path)

        work_model = model.clone()
        if sd_version == 'sd1':
            patcher = AttentionSharingPatcher(
                work_model, self.frames, use_control=control_img is not None
            )
            patcher.load_state_dict(layer_lora_state_dict, strict=True)
            if control_img is not None:
                patcher.set_control(control_img)
        else:
            layer_lora_patch_dict = to_lora_patch_dict(layer_lora_state_dict)
            work_model.add_patches(layer_lora_patch_dict, weight)

        samp_model = work_model
        # Extra conditioning: SDXL blend methods attach c_concat; SD1.5
        # multi-frame methods overwrite per-frame conditioning.
        if method in [LayerMethod.FG_ONLY_ATTN, LayerMethod.FG_ONLY_CONV]:
            pass
        elif sd_version == 'sdxl':
            if method in [LayerMethod.BG_TO_BLEND, LayerMethod.FG_TO_BLEND]:
                c_concat = model.model.latent_format.process_in(samples["samples"])
            else:
                c_concat = model.model.latent_format.process_in(torch.cat([samples["samples"], blend_samples["samples"]], dim=1))
            positive, negative = self.apply_layer_c_concat(positive, negative, c_concat)
        elif sd_version == 'sd1':
            if method in [LayerMethod.BG_TO_BLEND, LayerMethod.BG_BLEND_TO_FG]:
                additional_cond = (additional_cond[0], None)
            elif method in [LayerMethod.FG_TO_BLEND, LayerMethod.FG_BLEND_TO_BG]:
                additional_cond = (additional_cond[1], None)
            work_model.model_options.setdefault("transformer_options", {})
            work_model.model_options["transformer_options"]["cond_overwrite"] = [
                cond[0][0] if cond is not None else None
                for cond in additional_cond
            ]
        return samp_model, positive, negative

    def join_image_with_alpha(self, image, alpha):
        """Return ``image`` (channels-last) with channel 3 set to ``alpha``.

        A 3-channel image gets a fourth channel appended. Otherwise the
        existing channel 3 is overwritten on a copy; the input is never
        modified. ``alpha`` may be a single (H, W) map shared by the whole
        batch or a per-image (B, H, W) stack.
        """
        out = image.movedim(-1, 1)
        if out.shape[1] == 3:  # RGB: append an alpha channel
            out = torch.cat([out, torch.ones_like(out[:, :1, :, :])], dim=1)
        else:
            # movedim returns a view; clone so the caller's tensor is untouched.
            out = out.clone()
        out[:, 3, :, :] = alpha
        return out.movedim(1, -1)

    def image_to_alpha(self, image, latent):
        """Decode an alpha channel for ``image`` using the transparent VAE decoder.

        Requires ``self.vae_transparent_decoder`` to be set. Returns
        ``(images_with_alpha, alpha)`` where ``alpha`` is ``1 - decoded alpha``
        (presumably because ``JoinImageWithAlpha`` takes a mask — confirm).
        """
        pixel = image.movedim(-1, 1)  # [B, H, W, C] => [B, C, H, W]
        decoded = []
        sub_batch_size = 16
        for start_idx in range(0, latent.shape[0], sub_batch_size):
            decoded.append(
                self.vae_transparent_decoder.decode_pixel(
                    pixel[start_idx: start_idx + sub_batch_size],
                    latent[start_idx: start_idx + sub_batch_size],
                )
            )
        pixel_with_alpha = torch.cat(decoded, dim=0)
        # [B, C, H, W] => [B, H, W, C]
        pixel_with_alpha = pixel_with_alpha.movedim(1, -1)
        # Decoder output: channel 0 is alpha, the remaining channels are colour.
        image = pixel_with_alpha[..., 1:]
        alpha = pixel_with_alpha[..., 0]
        alpha = 1.0 - alpha
        new_images, = JoinImageWithAlpha().join_image_with_alpha(image, alpha)
        return new_images, alpha

    def make_3d_mask(self, mask):
        """Normalize ``mask`` to 3-D: drop a leading dim of a 4-D mask, add one to a 2-D mask."""
        if len(mask.shape) == 4:
            return mask.squeeze(0)
        elif len(mask.shape) == 2:
            return mask.unsqueeze(0)
        return mask

    def masks_to_list(self, masks):
        """Split ``masks`` along its first dim into a list of 3-D masks.

        ``None`` yields a one-element list holding an empty 64x64 mask, shaped
        like the other entries.
        """
        if masks is None:
            masks = [torch.zeros((64, 64), dtype=torch.float32, device="cpu")]
        return [self.make_3d_mask(mask) for mask in masks]

    def _decode_group_heads(self, samp_images, latent):
        """Decode alpha for the first image of each ``self.frames``-sized group.

        Other images pass through unchanged. Returns ``(images, alphas)``,
        where ``images`` is a per-image list and ``alphas`` holds one 3-D mask
        per group.
        """
        frames = self.frames
        new_images, alpha = [], []
        if len(samp_images) == 0:
            return new_images, alpha
        # Decode every group head in one pass instead of re-decoding the whole
        # strided tail for each head.
        head_images, head_alpha = self.image_to_alpha(samp_images[::frames], latent[::frames])
        for index in range(len(samp_images)):
            group, offset = divmod(index, frames)
            if offset == 0:
                alpha.append(self.make_3d_mask(head_alpha[group]))
                new_images.append(head_images[group])
            else:
                new_images.append(samp_images[index])
        return new_images, alpha

    def layer_diffusion_decode(self, layer_diffusion_method, latent, blend_samples, samp_images, model):
        """Add alpha to ``samp_images`` when the method/model combination supports it.

        Returns:
            ``(new_images, samp_images, alpha)``. When nothing is decoded,
            ``new_images`` is ``samp_images`` and ``alpha`` is ``[]``.

        Raises:
            Exception: the model is neither SDXL nor SD1.5, or no decoder URL
                is configured for this version.
        """
        alpha = []
        if layer_diffusion_method is None:
            return (samp_images, samp_images, alpha)

        sd_version = get_sd_version(model)
        if sd_version not in ['sdxl', 'sd1']:
            raise Exception("Only SDXL and SD1.5 model supported for Layer Diffusion")
        method = self.get_layer_diffusion_method(layer_diffusion_method, blend_samples is not None)
        sd15_allow = sd_version == 'sd1' and method in [LayerMethod.FG_ONLY_ATTN, LayerMethod.EVERYTHING, LayerMethod.BG_TO_BLEND, LayerMethod.BG_BLEND_TO_FG]
        sdxl_allow = sd_version == 'sdxl' and method in [LayerMethod.FG_ONLY_CONV, LayerMethod.FG_ONLY_ATTN, LayerMethod.BG_BLEND_TO_FG]
        if not (sdxl_allow or sd15_allow):
            return (samp_images, samp_images, alpha)

        if self.vae_transparent_decoder is None:
            model_url = LAYER_DIFFUSION_VAE['decode'][sd_version]["model_url"]
            if model_url is None:
                raise Exception(f"{method.value} is not supported for {sd_version} model")
            decoder_file = get_local_filepath(model_url, LAYER_DIFFUSION_DIR)
            self.vae_transparent_decoder = TransparentVAEDecoder(
                load_torch_file(decoder_file),
                device=comfy.model_management.get_torch_device(),
                dtype=(torch.float16 if comfy.model_management.should_use_fp16() else torch.float32),
            )

        if method in [LayerMethod.EVERYTHING, LayerMethod.BG_BLEND_TO_FG, LayerMethod.BG_TO_BLEND]:
            new_images, alpha = self._decode_group_heads(samp_images, latent)
        else:
            new_images, alpha = self.image_to_alpha(samp_images, latent)
        return (new_images, samp_images, alpha)