""" Custom nodes for SDXL in ComfyUI MIT License Copyright (c) 2023 Searge Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ import warnings import comfy from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel from comfy.model_management import batch_area_memory, get_torch_device, load_models_gpu import comfy.sample import comfy.samplers import comfy.utils import latent_preview import torch from torch.nn.functional import pad from ..nodes.interposer import VyroLatentInterposer convert_latent = VyroLatentInterposer().convert def next_multiple_of(value, factor): return int(int((value + factor - 1) // factor) * factor) def get_image_size(image): if image is None: return (None, None,) (_, height, width, _) = image.shape return (width, height,) def get_mask_size(mask): if mask is None: return (None, None,) (height, width) = mask.shape return (width, height,) def get_latent_size(latent): if latent is None or "samples" not in latent: return (None, None,) samples = latent["samples"] (_, _, height, width) = samples.shape return (width, height,) def get_latent_pixel_size(latent): (width, height) = get_latent_size(latent) if width is None or height is None: return (None, None,) return (width * 8, height * 8,) def slerp(factor, input1, input2): dims = input1.shape input1 = input1.reshape(dims[0], -1) input2 = input2.reshape(dims[0], -1) input1_norm = input1 / torch.norm(input1, dim=1, keepdim=True) input2_norm = input2 / torch.norm(input2, dim=1, keepdim=True) input1_norm[input1_norm != input1_norm] = 0.0 input2_norm[input2_norm != input2_norm] = 0.0 omega = torch.acos((input1_norm * input2_norm).sum(1)) sin_omega = torch.sin(omega) result = ((torch.sin((1.0 - factor) * omega) / sin_omega).unsqueeze(1) * input1 + (torch.sin(factor * omega) / sin_omega).unsqueeze(1) * input2) return result.reshape(dims) def slerp_latents(latent1, latent2, factor): result = slerp(factor, latent1.clone(), latent2.clone()) return result def bilateral_blur(inp, kernel_size, sigma_color, sigma_space, border_type='reflect', color_distance_type='l1'): if isinstance(sigma_color, torch.Tensor): sigma_color = sigma_color.to(device=inp.device, dtype=inp.dtype).view(-1, 1, 1, 1, 1) ky, kx = _unpack_2d_ks(kernel_size) pad_y, pad_x = (ky - 1) // 2, (kx - 1) // 2 padded_input = pad(inp, (pad_x, pad_x, pad_y, pad_y), mode=border_type) unfolded_input = padded_input.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2) # (B, C, H, W, Ky x Kx) diff = unfolded_input - inp.unsqueeze(-1) if color_distance_type == "l1": color_distance_sq = diff.abs().sum(1, keepdim=True).square() elif color_distance_type == "l2": color_distance_sq = diff.square().sum(1, keepdim=True) else: color_distance_sq = diff.abs().sum(1, keepdim=True).square() color_kernel = (-0.5 / sigma_color ** 2 * color_distance_sq).exp() # (B, 1, H, W, Ky x Kx) space_kernel = get_gaussian_kernel2d(kernel_size, sigma_space, inp.device, inp.dtype) space_kernel = space_kernel.view(-1, 1, 1, 1, kx * ky) kernel = space_kernel * color_kernel out = (unfolded_input * kernel).sum(-1) / kernel.sum(-1) return out def _unpack_2d_ks(kernel_size): if isinstance(kernel_size, int): ky = kx = kernel_size else: ky, kx = kernel_size return (int(ky), int(kx)) def get_gaussian_kernel2d(kernel_size, sigma, device, dtype): if isinstance(sigma, tuple): sigma = torch.tensor([sigma], device=device, dtype=dtype) else: sigma = torch.tensor([[sigma, sigma]], device=device, dtype=dtype) ksize_y, ksize_x = _unpack_2d_ks(kernel_size) sigma_y, sigma_x = sigma[:, 0, None], sigma[:, 1, None] kernel_y = get_gaussian_kernel1d(ksize_y, sigma_y, device, dtype)[..., None] kernel_x = get_gaussian_kernel1d(ksize_x, sigma_x, device, dtype)[..., None] return kernel_y * kernel_x.view(-1, 1, ksize_x) def get_gaussian_kernel1d(kernel_size, sigma, device, dtype): if isinstance(sigma, float): sigma = torch.tensor([[sigma]], device=device, dtype=dtype) batch_size = sigma.shape[0] x = (torch.arange(kernel_size, device=sigma.device, dtype=sigma.dtype) - kernel_size // 2).expand(batch_size, -1) if kernel_size % 2 == 0: x = x + 0.5 gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) return gauss / gauss.sum(-1, keepdim=True) # -------------------------------------------------------------------------------- class CfgMethods: INTERPOLATE = "interpolate" RESCALE = "rescale" TONEMAP = "tonemap" # -------------------------------------------------------------------------------- def unet_function(func, params): cond_or_uncond = params["cond_or_uncond"] input_x = params["input"] timestep = params["timestep"] c = params["c"] transformer_options = c["transformer_options"] transformer_options["uc_mask"] = torch.Tensor(cond_or_uncond).to(input_x).float()[:, None, None, None] # duplicate for each batch batch_size = input_x.shape[0] / 2 if batch_size > 1: transformer_options["uc_mask"] = transformer_options["uc_mask"].repeat_interleave(int(batch_size), dim=0) return func(input_x, timestep, **c) # -------------------------------------------------------------------------------- def new_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): x0 = old_unet_forward(self, x, timesteps, context, y, control, transformer_options, **kwargs) # do filtering here if "uc_mask" in transformer_options: uc_mask = transformer_options["uc_mask"] sharpness = 2.0 alpha = 1.0 - (timesteps / 999.0)[:, None, None, None].clone() alpha *= 0.001 * sharpness degraded_x0 = bilateral_blur(x0, (13, 13), 3.0, 3.0) * alpha + x0 * (1.0 - alpha) x0 = x0 * uc_mask + degraded_x0 * (1.0 - uc_mask) return x0 old_unet_forward = UNetModel.forward UNetModel.forward = new_unet_forward # -------------------------------------------------------------------------------- def sdxl_sample(base_model, refiner_model, noise, base_steps, refiner_steps, cfg, sampler_name, scheduler, base_positive, base_negative, refiner_positive, refiner_negative, latent_image, batch_inds, denoise=1.0, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, base_callback=None, refiner_callback=None, disable_pbar=False, seed=None, cfg_method=None, dynamic_base_cfg=0.0, dynamic_refiner_cfg=0.0, refiner_detail_boost=0.0, restart_wrapper=None): device = get_torch_device() if noise_mask is not None: noise_mask = comfy.sample.prepare_mask(noise_mask, noise.shape, device) steps = base_steps + refiner_steps def base_cfg_callback(args): (cond, uncond, cond_scale, timestep) = (args["cond"], args["uncond"], args["cond_scale"], args["timestep"]) dyn_cfg = dynamic_base_cfg if dyn_cfg < 0.0: dyn_cfg = -dyn_cfg ts = 1.0 - float(timestep) / 999.0 else: ts = float(timestep) / 999.0 if dyn_cfg > 0.0999: cond_scale = cond_scale * ts + (cond_scale * (1.0 - dyn_cfg) + dyn_cfg) * (1.0 - ts) return uncond + (cond - uncond) * cond_scale def base_rescale_cfg(args): multiplier = dynamic_base_cfg if dynamic_base_cfg >= 0.0 else -dynamic_base_cfg cond = args["cond"] uncond = args["uncond"] cond_scale = args["cond_scale"] x_cfg = uncond + cond_scale * (cond - uncond) ro_pos = torch.std(cond, dim=(1, 2, 3), keepdim=True) ro_cfg = torch.std(x_cfg, dim=(1, 2, 3), keepdim=True) x_rescaled = x_cfg * (ro_pos / ro_cfg) x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg return x_final def base_tonemap_reinhard(args): multiplier = dynamic_base_cfg if dynamic_base_cfg >= 0.0 else -dynamic_base_cfg cond = args["cond"] uncond = args["uncond"] cond_scale = args["cond_scale"] noise_pred = (cond - uncond) noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:, None] noise_pred /= noise_pred_vector_magnitude mean = torch.mean(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True) std = torch.std(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True) top = (std * 3 + mean) * multiplier noise_pred_vector_magnitude *= (1.0 / top) new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0) new_magnitude *= top return uncond + noise_pred * new_magnitude * cond_scale base_model = base_model.clone() base_model.set_model_unet_function_wrapper(unet_function) if cfg_method is not None: if cfg_method == CfgMethods.INTERPOLATE: base_model.set_model_sampler_cfg_function(base_cfg_callback) elif cfg_method == CfgMethods.RESCALE and dynamic_base_cfg > 0.0: base_model.set_model_sampler_cfg_function(base_rescale_cfg) elif cfg_method == CfgMethods.TONEMAP and dynamic_base_cfg > 0.0: base_model.set_model_sampler_cfg_function(base_tonemap_reinhard) # base_models = comfy.sample.get_additional_models(base_positive, base_negative) # load_models_gpu([base_model] + base_models, batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3])) base_models, inference_memory = comfy.sample.get_additional_models(base_positive, base_negative, base_model.model_dtype()) memory_required = batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory load_models_gpu([base_model] + base_models, memory_required) real_base_model = base_model.model original_latent = latent_image noise = noise.to(device) latent_image = latent_image.to(device) pos_base_copy = comfy.sample.broadcast_cond(base_positive, noise.shape[0], device) neg_base_copy = comfy.sample.broadcast_cond(base_negative, noise.shape[0], device) base_sampler = comfy.samplers.KSampler(real_base_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=base_model.model_options) base_samples = base_sampler.sample(noise, pos_base_copy, neg_base_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=base_steps, force_full_denoise=False, denoise_mask=noise_mask, sigmas=sigmas, callback=base_callback, disable_pbar=disable_pbar, seed=seed) comfy.sample.cleanup_additional_models(base_models) noise = torch.zeros(base_samples.size(), dtype=base_samples.dtype, layout=base_samples.layout, device=device) if refiner_steps < 1: return base_samples.cpu() if refiner_detail_boost > 0.0: new_noise = comfy.sample.prepare_noise(original_latent, seed + 1, batch_inds).to(device) new_noise /= real_base_model.latent_format.scale_factor factor = base_sampler.sigmas[-refiner_steps - 1] new_noise = new_noise * factor noised_samples = base_samples + new_noise base_samples = slerp_latents(base_samples, noised_samples, refiner_detail_boost) if noise_mask is not None: latent_from_base = base_samples * noise_mask + latent_image * (1.0 - noise_mask) else: latent_from_base = base_samples # latent_from_base = convert_latent(latent_from_base,'xl','v1') # latent_from_base.to(base_samples.device) def refiner_cfg_callback(args): (cond, uncond, cond_scale, timestep) = (args["cond"], args["uncond"], args["cond_scale"], args["timestep"]) dyn_cfg = dynamic_refiner_cfg if dyn_cfg < 0.0: dyn_cfg = -dyn_cfg ts = 1.0 - float(timestep) / 999.0 else: ts = float(timestep) / 999.0 if dyn_cfg > 0.0999: cond_scale = cond_scale * ts + (cond_scale * (1.0 - dyn_cfg) + dyn_cfg) * (1.0 - ts) return uncond + (cond - uncond) * cond_scale def refiner_rescale_cfg(args): multiplier = dynamic_refiner_cfg if dynamic_refiner_cfg >= 0.0 else -dynamic_refiner_cfg cond = args["cond"] uncond = args["uncond"] cond_scale = args["cond_scale"] x_cfg = uncond + cond_scale * (cond - uncond) ro_pos = torch.std(cond, dim=(1, 2, 3), keepdim=True) ro_cfg = torch.std(x_cfg, dim=(1, 2, 3), keepdim=True) x_rescaled = x_cfg * (ro_pos / ro_cfg) return multiplier * x_rescaled + (1.0 - multiplier) * x_cfg def refiner_tonemap_reinhard(args): multiplier = dynamic_refiner_cfg if dynamic_refiner_cfg >= 0.0 else -dynamic_refiner_cfg cond = args["cond"] uncond = args["uncond"] cond_scale = args["cond_scale"] noise_pred = (cond - uncond) noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:, None] noise_pred /= noise_pred_vector_magnitude mean = torch.mean(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True) std = torch.std(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True) top = (std * 3 + mean) * multiplier noise_pred_vector_magnitude *= (1.0 / top) new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0) new_magnitude *= top return uncond + noise_pred * new_magnitude * cond_scale refiner_model = refiner_model.clone() refiner_model.set_model_unet_function_wrapper(unet_function) if cfg_method is not None: if cfg_method == CfgMethods.INTERPOLATE: refiner_model.set_model_sampler_cfg_function(refiner_cfg_callback) elif cfg_method == CfgMethods.RESCALE and dynamic_refiner_cfg > 0.0: refiner_model.set_model_sampler_cfg_function(refiner_rescale_cfg) elif cfg_method == CfgMethods.TONEMAP and dynamic_refiner_cfg > 0.0: refiner_model.set_model_sampler_cfg_function(refiner_tonemap_reinhard) # refiner_models = comfy.sample.get_additional_models(base_positive, base_negative) # load_models_gpu([refiner_model] + refiner_models, batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3])) refiner_models, inference_memory = comfy.sample.get_additional_models(refiner_positive, refiner_negative, refiner_model.model_dtype()) memory_required = batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory load_models_gpu([refiner_model] + refiner_models, memory_required) real_refiner_model = refiner_model.model pos_refiner_copy = comfy.sample.broadcast_cond(refiner_positive, noise.shape[0], device) neg_refiner_copy = comfy.sample.broadcast_cond(refiner_negative, noise.shape[0], device) if restart_wrapper is not None: restart_wrapper.__class__.refiner_stage = True refiner_sampler = comfy.samplers.KSampler(real_refiner_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=refiner_model.model_options) refiner_samples = refiner_sampler.sample(noise, pos_refiner_copy, neg_refiner_copy, cfg=cfg, latent_image=latent_from_base, start_step=base_steps, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=refiner_callback, disable_pbar=disable_pbar, seed=seed) refiner_samples = refiner_samples.cpu() comfy.sample.cleanup_additional_models(refiner_models) return refiner_samples # -------------------------------------------------------------------------------- def sdxl_ksampler(base_model, refiner_model, seed, base_steps, refiner_steps, cfg, sampler_name, scheduler, base_positive, base_negative, refiner_positive, refiner_negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, cfg_method=None, dynamic_base_cfg=0.0, dynamic_refiner_cfg=0.0, refiner_detail_boost=0.0, restart_wrapper=None): # print(f"\n\nBase Model Type: {type(base_model.model.diffusion_model)}") device = get_torch_device() latent_image = latent["samples"] batch_inds = None if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: batch_inds = latent["batch_index"] if "batch_index" in latent else None noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) noise_mask = None if "noise_mask" in latent: noise_mask = latent["noise_mask"] preview_format = "JPEG" if preview_format not in ["JPEG", "PNG"]: preview_format = "JPEG" base_previewer = latent_preview.get_previewer(device, base_model.model.latent_format) refiner_previewer = None if refiner_model is not None: refiner_previewer = latent_preview.get_previewer(device, refiner_model.model.latent_format) steps = base_steps + refiner_steps pbar = comfy.utils.ProgressBar(steps) def base_callback(step, x0, x, total_steps): preview_bytes = None if base_previewer: preview_bytes = base_previewer.decode_latent_to_preview_image(preview_format, x0) pbar.update_absolute(step + 1, total_steps, preview_bytes) def refiner_callback(step, x0, x, total_steps): preview_bytes = None if refiner_previewer: preview_bytes = refiner_previewer.decode_latent_to_preview_image(preview_format, x0) pbar.update_absolute(step + 1, total_steps, preview_bytes) with warnings.catch_warnings(): warnings.simplefilter("ignore") samples = sdxl_sample(base_model, refiner_model, noise, base_steps, refiner_steps, cfg, sampler_name, scheduler, base_positive, base_negative, refiner_positive, refiner_negative, latent_image, batch_inds, denoise=denoise, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, noise_mask=noise_mask, base_callback=base_callback, refiner_callback=refiner_callback, seed=seed, dynamic_base_cfg=dynamic_base_cfg, dynamic_refiner_cfg=dynamic_refiner_cfg, cfg_method=cfg_method, refiner_detail_boost=refiner_detail_boost,restart_wrapper=restart_wrapper) out = latent.copy() out["samples"] = samples return (out,)