import torch import torchvision.transforms as transforms import folder_paths import os import types import numpy as np import torch.nn.functional as F from comfy.utils import load_torch_file from .utils.convert_unet import convert_iclight_unet from .utils.patches import calculate_weight_adjust_channel from .utils.image import generate_gradient_image, LightPosition from nodes import MAX_RESOLUTION from comfy.model_patcher import ModelPatcher from comfy import lora import model_management import logging class LoadAndApplyICLightUnet: @classmethod def INPUT_TYPES(s): return { "required": { "model": ("MODEL",), "model_path": (folder_paths.get_filename_list("unet"), ) } } RETURN_TYPES = ("MODEL",) FUNCTION = "load" CATEGORY = "IC-Light" DESCRIPTION = """ Loads and applies the diffusers SD1.5 IC-Light models available here: https://huggingface.co/lllyasviel/ic-light/tree/main Used with ICLightConditioning -node """ def load(self, model, model_path): type_str = str(type(model.model.model_config).__name__) if "SD15" not in type_str: raise Exception(f"Attempted to load {type_str} model, IC-Light is only compatible with SD 1.5 models.") print("LoadAndApplyICLightUnet: Checking IC-Light Unet path") model_full_path = folder_paths.get_full_path("unet", model_path) if not os.path.exists(model_full_path): raise Exception("Invalid model path") else: print("LoadAndApplyICLightUnet: Loading IC-Light Unet weights") model_clone = model.clone() iclight_state_dict = load_torch_file(model_full_path) print("LoadAndApplyICLightUnet: Attempting to add patches with IC-Light Unet weights") try: if 'conv_in.weight' in iclight_state_dict: iclight_state_dict = convert_iclight_unet(iclight_state_dict) in_channels = iclight_state_dict["diffusion_model.input_blocks.0.0.weight"].shape[1] for key in iclight_state_dict: model_clone.add_patches({key: (iclight_state_dict[key],)}, 1.0, 1.0) else: for key in iclight_state_dict: model_clone.add_patches({"diffusion_model." + key: (iclight_state_dict[key],)}, 1.0, 1.0) in_channels = iclight_state_dict["input_blocks.0.0.weight"].shape[1] except: raise Exception("Could not patch model") print("LoadAndApplyICLightUnet: Added LoadICLightUnet patches") #Patch ComfyUI's LoRA weight application to accept multi-channel inputs. Thanks @huchenlei try: if hasattr(lora, 'calculate_weight'): lora.calculate_weight = calculate_weight_adjust_channel(lora.calculate_weight) else: raise Exception("IC-Light: The 'calculate_weight' function does not exist in 'lora'") except Exception as e: raise Exception(f"IC-Light: Could not patch calculate_weight - {str(e)}") # Mimic the existing IP2P class to enable extra_conds def bound_extra_conds(self, **kwargs): return ICLight.extra_conds(self, **kwargs) new_extra_conds = types.MethodType(bound_extra_conds, model_clone.model) model_clone.add_object_patch("extra_conds", new_extra_conds) model_clone.model.model_config.unet_config["in_channels"] = in_channels return (model_clone, ) import comfy class ICLight: def extra_conds(self, **kwargs): out = {} image = kwargs.get("concat_latent_image", None) noise = kwargs.get("noise", None) device = kwargs["device"] model_in_channels = self.model_config.unet_config['in_channels'] input_channels = image.shape[1] + 4 if model_in_channels != input_channels: raise Exception(f"Input channels {input_channels} does not match model in_channels {model_in_channels}, 'opt_background' latent input should be used with the IC-Light 'fbc' model, and only with it") if image is None: image = torch.zeros_like(noise) if image.shape[1:] != noise.shape[1:]: image = comfy.utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") image = comfy.utils.resize_to_batch_size(image, noise.shape[0]) process_image_in = lambda image: image out['c_concat'] = comfy.conds.CONDNoiseShape(process_image_in(image)) adm = self.encode_adm(**kwargs) if adm is not None: out['y'] = comfy.conds.CONDRegular(adm) return out class ICLightConditioning: @classmethod def INPUT_TYPES(s): return {"required": {"positive": ("CONDITIONING", ), "negative": ("CONDITIONING", ), "vae": ("VAE", ), "foreground": ("LATENT", ), "multiplier": ("FLOAT", {"default": 0.18215, "min": 0.0, "max": 1.0, "step": 0.001}), }, "optional": { "opt_background": ("LATENT", ), }, } RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") RETURN_NAMES = ("positive", "negative", "empty_latent") FUNCTION = "encode" CATEGORY = "IC-Light" DESCRIPTION = """ Conditioning for the IC-Light model. To use the "opt_background" input, you also need to use the "fbc" version of the IC-Light models. """ def encode(self, positive, negative, vae, foreground, multiplier, opt_background=None): samples_1 = foreground["samples"] if opt_background is not None: samples_2 = opt_background["samples"] repeats_1 = samples_2.size(0) // samples_1.size(0) repeats_2 = samples_1.size(0) // samples_2.size(0) if samples_1.shape[1:] != samples_2.shape[1:]: samples_2 = comfy.utils.common_upscale(samples_2, samples_1.shape[-1], samples_1.shape[-2], "bilinear", "disabled") # Repeat the tensors to match the larger batch size if repeats_1 > 1: samples_1 = samples_1.repeat(repeats_1, 1, 1, 1) if repeats_2 > 1: samples_2 = samples_2.repeat(repeats_2, 1, 1, 1) concat_latent = torch.cat((samples_1, samples_2), dim=1) else: concat_latent = samples_1 out_latent = torch.zeros_like(samples_1) out = [] for conditioning in [positive, negative]: c = [] for t in conditioning: d = t[1].copy() d["concat_latent_image"] = concat_latent * multiplier n = [t[0], d] c.append(n) out.append(c) return (out[0], out[1], {"samples": out_latent}) class LightSource: @classmethod def INPUT_TYPES(s): return { "required": { "light_position": ([member.value for member in LightPosition],), "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.001}), "start_color": ("STRING", {"default": "#FFFFFF"}), "end_color": ("STRING", {"default": "#000000"}), "width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }), "height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, }), }, "optional": { "batch_size": ("INT", { "default": 1, "min": 1, "max": 4096, "step": 1, }), "prev_image": ("IMAGE",), } } RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("IMAGE",) FUNCTION = "execute" CATEGORY = "IC-Light" DESCRIPTION = """ Generates a gradient image that can be used as a simple light source. The color can be specified in RGB or hex format. """ def execute(self, light_position, multiplier, start_color, end_color, width, height, batch_size=1, prev_image=None): def toRgb(color): if color.startswith('#') and len(color) == 7: # e.g. "#RRGGBB" color_rgb =tuple(int(color[i:i+2], 16) for i in (1, 3, 5)) else: # e.g. "255,255,255" color_rgb = tuple(int(i) for i in color.split(',')) return color_rgb lightPosition = LightPosition(light_position) start_color_rgb = toRgb(start_color) end_color_rgb = toRgb(end_color) image = generate_gradient_image(width, height, start_color_rgb, end_color_rgb, multiplier, lightPosition) image = image.astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] image = image.repeat(batch_size, 1, 1, 1) if prev_image is not None: image = torch.cat((prev_image, image), dim=0) return (image,) class CalculateNormalsFromImages: @classmethod def INPUT_TYPES(s): return { "required": { "images": ("IMAGE",), "sigma": ("FLOAT", { "default": 10.0, "min": 0.01, "max": 100.0, "step": 0.01, }), "center_input_range": ("BOOLEAN", { "default": False, }), }, "optional": { "mask": ("MASK",), } } RETURN_TYPES = ("IMAGE", "IMAGE",) RETURN_NAMES = ("normal", "divided",) FUNCTION = "execute" CATEGORY = "IC-Light" DESCRIPTION = """ Calculates normal map from different directional exposures. Takes in 4 images as a batch: left, right, bottom, top """ def execute(self, images, sigma, center_input_range, mask=None): B, H, W, C = images.shape repetitions = B // 4 if center_input_range: images = images * 0.5 + 0.5 if mask is not None: if mask.shape[-2:] != images[0].shape[:-1]: mask = mask.unsqueeze(0) mask = F.interpolate(mask, size=(images.shape[1], images.shape[2]), mode="bilinear") mask = mask.squeeze(0) normal_list = [] divided_list = [] iteration_counter = 0 for i in range(0, B, 4): # Loop over every 4 images index = torch.arange(iteration_counter, B, repetitions) rearranged_images = images[index] images_np = rearranged_images.numpy().astype(np.float32) left = images_np[0] right = images_np[1] bottom = images_np[2] top = images_np[3] ambient = (left + right + bottom + top) / 4.0 def safe_divide(a, b): e = 1e-5 return ((a + e) / (b + e)) - 1.0 left = safe_divide(left, ambient) right = safe_divide(right, ambient) bottom = safe_divide(bottom, ambient) top = safe_divide(top, ambient) u = (right - left) * 0.5 v = (top - bottom) * 0.5 u = np.mean(u, axis=2) v = np.mean(v, axis=2) h = (1.0 - u ** 2.0 - v ** 2.0).clip(0, 1e5) ** (0.5 * sigma) z = np.zeros_like(h) normal = np.stack([u, v, h], axis=2) normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 if mask is not None: matting = mask[iteration_counter].unsqueeze(0).numpy().astype(np.float32) matting = matting[..., np.newaxis] normal = normal * matting + np.stack([z, z, 1 - z], axis=2) normal = torch.from_numpy(normal) #normal = normal.unsqueeze(0) else: normal = normal + np.stack([z, z, 1 - z], axis=2) normal = torch.from_numpy(normal).unsqueeze(0) iteration_counter += 1 normal = (normal - normal.min()) / ((normal.max() - normal.min())) normal_list.append(normal) divided = np.stack([left, right, bottom, top]) divided = torch.from_numpy(divided) divided = (divided - divided.min()) / ((divided.max() - divided.min())) divided = torch.max(divided, dim=3, keepdim=True)[0].repeat(1, 1, 1, 3) divided_list.append(divided) normal_out = torch.cat(normal_list, dim=0) divided_out = torch.cat(divided_list, dim=0) return (normal_out, divided_out, ) class LoadHDRImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": {"image": (sorted(files), {"image_upload": False}), "exposures": ("STRING", {"default": "-2,-1,0,1,2"}), }, } CATEGORY = "IC-Light" RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "loadhdrimage" DESCRIPTION = """ Loads a .hdr image from the input directory. Output is a batch of LDR images with the selected exposures. """ def loadhdrimage(self, image, exposures): import cv2 image_path = folder_paths.get_annotated_filepath(image) # Load the HDR image hdr_image = cv2.imread(image_path, cv2.IMREAD_ANYDEPTH) exposures = list(map(int, exposures.split(","))) if not isinstance(exposures, list): exposures = [exposures] # Example exposure values ldr_images_tensors = [] for exposure in exposures: # Scale pixel values to simulate different exposures ldr_image = np.clip(hdr_image * (2**exposure), 0, 1) # Convert to 8-bit image (LDR) by scaling to 255 ldr_image_8bit = np.uint8(ldr_image * 255) # Convert BGR to RGB ldr_image_8bit = cv2.cvtColor(ldr_image_8bit, cv2.COLOR_BGR2RGB) # Convert the LDR image to a torch tensor tensor_image = torch.from_numpy(ldr_image_8bit).float() # Normalize the tensor to the range [0, 1] tensor_image = tensor_image / 255.0 # Change the tensor shape to (C, H, W) tensor_image = tensor_image.permute(2, 0, 1) # Add the tensor to the list ldr_images_tensors.append(tensor_image) batch_tensors = torch.stack(ldr_images_tensors) batch_tensors = batch_tensors.permute(0, 2, 3, 1) return batch_tensors, class BackgroundScaler: @classmethod def INPUT_TYPES(s): return { "required": { "image": ("IMAGE",), "mask": ("MASK",), "scale": ("FLOAT", {"default": 0.5, "min": -10.0, "max": 10.0, "step": 0.001}), "invert": ("BOOLEAN", { "default": False, }), } } CATEGORY = "IC-Light" RETURN_TYPES = ("IMAGE",) FUNCTION = "apply" DESCRIPTION = """ Sets the masked area color in grayscale range. """ def apply(self, image: torch.Tensor, mask: torch.Tensor, scale: float, invert: bool): # Validate inputs if not isinstance(image, torch.Tensor) or not isinstance(mask, torch.Tensor): raise ValueError("image and mask must be torch.Tensor types.") if image.ndim != 4 or mask.ndim not in [3, 4]: raise ValueError("image must be a 4D tensor, and mask must be a 3D or 4D tensor.") # Adjust mask dimensions if necessary if mask.ndim == 3: # [B, H, W] => [B, H, W, C=1] mask = mask.unsqueeze(-1) if invert: mask = 1 - mask image_out = image * mask + (1 - mask) * scale image_out = torch.clamp(image_out, 0, 1).cpu().float() return (image_out,) class DetailTransfer: @classmethod def INPUT_TYPES(s): return { "required": { "target": ("IMAGE", ), "source": ("IMAGE", ), "mode": ([ "add", "multiply", "screen", "overlay", "soft_light", "hard_light", "color_dodge", "color_burn", "difference", "exclusion", "divide", ], {"default": "add"} ), "blur_sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 100.0, "step": 0.01}), "blend_factor": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.001, "round": 0.001}), }, "optional": { "mask": ("MASK", ), } } RETURN_TYPES = ("IMAGE",) FUNCTION = "process" CATEGORY = "IC-Light" def adjust_mask(self, mask, target_tensor): # Add a channel dimension and repeat to match the channel number of the target tensor if len(mask.shape) == 3: mask = mask.unsqueeze(1) # Add a channel dimension target_channels = target_tensor.shape[1] mask = mask.expand(-1, target_channels, -1, -1) # Expand the channel dimension to match the target tensor's channels return mask def process(self, target, source, mode, blur_sigma, blend_factor, mask=None): B, H, W, C = target.shape device = model_management.get_torch_device() target_tensor = target.permute(0, 3, 1, 2).clone().to(device) source_tensor = source.permute(0, 3, 1, 2).clone().to(device) if target.shape[1:] != source.shape[1:]: source_tensor = comfy.utils.common_upscale(source_tensor, W, H, "bilinear", "disabled") if source.shape[0] < B: source = source[0].unsqueeze(0).repeat(B, 1, 1, 1) kernel_size = int(6 * int(blur_sigma) + 1) gaussian_blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma)) blurred_target = gaussian_blur(target_tensor) blurred_source = gaussian_blur(source_tensor) if mode == "add": tensor_out = (source_tensor - blurred_source) + blurred_target elif mode == "multiply": tensor_out = source_tensor * blurred_target elif mode == "screen": tensor_out = 1 - (1 - source_tensor) * (1 - blurred_target) elif mode == "overlay": tensor_out = torch.where(blurred_target < 0.5, 2 * source_tensor * blurred_target, 1 - 2 * (1 - source_tensor) * (1 - blurred_target)) elif mode == "soft_light": tensor_out = (1 - 2 * blurred_target) * source_tensor**2 + 2 * blurred_target * source_tensor elif mode == "hard_light": tensor_out = torch.where(source_tensor < 0.5, 2 * source_tensor * blurred_target, 1 - 2 * (1 - source_tensor) * (1 - blurred_target)) elif mode == "difference": tensor_out = torch.abs(blurred_target - source_tensor) elif mode == "exclusion": tensor_out = 0.5 - 2 * (blurred_target - 0.5) * (source_tensor - 0.5) elif mode == "color_dodge": tensor_out = blurred_target / (1 - source_tensor) elif mode == "color_burn": tensor_out = 1 - (1 - blurred_target) / source_tensor elif mode == "divide": tensor_out = (source_tensor / blurred_source) * blurred_target else: tensor_out = source_tensor tensor_out = torch.lerp(target_tensor, tensor_out, blend_factor) if mask is not None: # Call the function and pass in mask and target_tensor mask = self.adjust_mask(mask, target_tensor) mask = mask.to(device) tensor_out = torch.lerp(target_tensor, tensor_out, mask) tensor_out = torch.clamp(tensor_out, 0, 1) tensor_out = tensor_out.permute(0, 2, 3, 1).cpu().float() return (tensor_out,) NODE_CLASS_MAPPINGS = { "LoadAndApplyICLightUnet": LoadAndApplyICLightUnet, "ICLightConditioning": ICLightConditioning, "LightSource": LightSource, "CalculateNormalsFromImages": CalculateNormalsFromImages, "LoadHDRImage": LoadHDRImage, "BackgroundScaler": BackgroundScaler, "DetailTransfer": DetailTransfer } NODE_DISPLAY_NAME_MAPPINGS = { "LoadAndApplyICLightUnet": "Load And Apply IC-Light", "ICLightConditioning": "IC-Light Conditioning", "LightSource": "Simple Light Source", "CalculateNormalsFromImages": "Calculate Normals From Images", "LoadHDRImage": "Load HDR Image", "BackgroundScaler": "Background Scaler", "DetailTransfer": "Detail Transfer" }