|
import os |
|
import torch |
|
from torch.nn import functional as F |
|
from omegaconf import OmegaConf |
|
import comfy.utils |
|
import comfy.model_management as mm |
|
import folder_paths |
|
from nodes import ImageScale, ImageScaleBy
|
from .sgm.util import instantiate_from_config |
|
from .SUPIR.util import convert_dtype, load_state_dict |
|
import open_clip |
|
from contextlib import contextmanager |
|
|
|
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    CLIPTextConfig,
)
|
script_directory = os.path.dirname(os.path.abspath(__file__)) |
|
|
|
def dummy_build_vision_tower(*args, **kwargs): |
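    # Stand-in for open_clip.model._build_vision_tower: skip building the
    # vision tower entirely, since only the text encoder is needed here.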
|
|
|
return None |
|
|
|
@contextmanager |
|
def patch_build_vision_tower(): |
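    # Temporarily monkey-patch open_clip so that CLIP() construction skips the
    # vision tower; the original builder is always restored on exit.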
|
original_build_vision_tower = open_clip.model._build_vision_tower |
|
open_clip.model._build_vision_tower = dummy_build_vision_tower |
|
|
|
try: |
|
yield |
|
finally: |
|
open_clip.model._build_vision_tower = original_build_vision_tower |
|
|
|
def build_text_model_from_openai_state_dict( |
|
state_dict: dict, |
|
cast_dtype=torch.float16, |
|
): |
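    # Derive the text-tower hyperparameters from tensor shapes in an
    # OpenAI-format CLIP state dict (head dim is fixed at 64 by convention).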
|
|
|
embed_dim = state_dict["text_projection"].shape[1] |
|
context_length = state_dict["positional_embedding"].shape[0] |
|
vocab_size = state_dict["token_embedding.weight"].shape[0] |
|
transformer_width = state_dict["ln_final.weight"].shape[0] |
|
transformer_heads = transformer_width // 64 |
|
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
|
|
vision_cfg = None |
|
text_cfg = open_clip.CLIPTextCfg( |
|
context_length=context_length, |
|
vocab_size=vocab_size, |
|
width=transformer_width, |
|
heads=transformer_heads, |
|
layers=transformer_layers, |
|
) |
|
|
|
with patch_build_vision_tower(): |
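        # With the vision tower stubbed out this builds a text-only CLIP;
        # strict=False below tolerates the missing visual weights.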
|
model = open_clip.CLIP( |
|
embed_dim, |
|
vision_cfg=vision_cfg, |
|
text_cfg=text_cfg, |
|
quick_gelu=True, |
|
cast_dtype=cast_dtype, |
|
) |
|
|
|
model.load_state_dict(state_dict, strict=False) |
|
model = model.eval() |
|
for param in model.parameters(): |
|
param.requires_grad = False |
|
return model |
|
|
|
class SUPIR_Upscale: |
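    """ComfyUI node that restores and upscales images with SUPIR, combining a
    SUPIR checkpoint with an SDXL base checkpoint."""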
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] |
|
|
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"supir_model": (folder_paths.get_filename_list("checkpoints"),), |
|
"sdxl_model": (folder_paths.get_filename_list("checkpoints"),), |
|
"image": ("IMAGE",), |
|
"seed": ("INT", {"default": 123, "min": 0, "max": 0xffffffffffffffff, "step": 1}), |
|
"resize_method": (s.upscale_methods, {"default": "lanczos"}), |
|
"scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 20.0, "step": 0.01}), |
|
"steps": ("INT", {"default": 45, "min": 3, "max": 4096, "step": 1}), |
|
"restoration_scale": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 6.0, "step": 1.0}), |
|
"cfg_scale": ("FLOAT", {"default": 4.0, "min": 0, "max": 100, "step": 0.01}), |
|
"a_prompt": ("STRING", {"multiline": True, "default": "high quality, detailed", }), |
|
"n_prompt": ("STRING", {"multiline": True, "default": "bad quality, blurry, messy", }), |
|
"s_churn": ("INT", {"default": 5, "min": 0, "max": 40, "step": 1}), |
|
"s_noise": ("FLOAT", {"default": 1.003, "min": 1.0, "max": 1.1, "step": 0.001}), |
|
"control_scale": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.05}), |
|
"cfg_scale_start": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 100.0, "step": 0.05}), |
|
"control_scale_start": ("FLOAT", {"default": 0.0, "min": 0, "max": 1.0, "step": 0.05}), |
|
"color_fix_type": ( |
|
[ |
|
'None', |
|
'AdaIn', |
|
'Wavelet', |
|
], { |
|
"default": 'Wavelet' |
|
}), |
|
"keep_model_loaded": ("BOOLEAN", {"default": True}), |
|
"use_tiled_vae": ("BOOLEAN", {"default": True}), |
|
"encoder_tile_size_pixels": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}), |
|
"decoder_tile_size_latent": ("INT", {"default": 64, "min": 32, "max": 8192, "step": 64}), |
|
}, |
|
"optional": { |
|
"captions": ("STRING", {"forceInput": True, "multiline": False, "default": "", }), |
|
"diffusion_dtype": ( |
|
[ |
|
'fp16', |
|
'bf16', |
|
'fp32', |
|
'auto' |
|
], { |
|
"default": 'auto' |
|
}), |
|
"encoder_dtype": ( |
|
[ |
|
'bf16', |
|
'fp32', |
|
'auto' |
|
], { |
|
"default": 'auto' |
|
}), |
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 128, "step": 1}), |
|
"use_tiled_sampling": ("BOOLEAN", {"default": False}), |
|
"sampler_tile_size": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 32}), |
|
"sampler_tile_stride": ("INT", {"default": 512, "min": 32, "max": 2048, "step": 32}), |
|
"fp8_unet": ("BOOLEAN", {"default": False}), |
|
"fp8_vae": ("BOOLEAN", {"default": False}), |
|
"sampler": ( |
|
[ |
|
'RestoreDPMPP2MSampler', |
|
'RestoreEDMSampler', |
|
], { |
|
"default": 'RestoreEDMSampler' |
|
}), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("IMAGE",) |
|
RETURN_NAMES = ("upscaled_image",) |
|
FUNCTION = "process" |
|
|
|
CATEGORY = "SUPIR" |
|
|
|
def process(self, steps, image, color_fix_type, seed, scale_by, cfg_scale, resize_method, s_churn, s_noise, |
|
encoder_tile_size_pixels, decoder_tile_size_latent, |
|
control_scale, cfg_scale_start, control_scale_start, restoration_scale, keep_model_loaded, |
|
a_prompt, n_prompt, sdxl_model, supir_model, use_tiled_vae, use_tiled_sampling=False, sampler_tile_size=128, sampler_tile_stride=64, captions="", diffusion_dtype="auto", |
|
encoder_dtype="auto", batch_size=1, fp8_unet=False, fp8_vae=False, sampler="RestoreEDMSampler"): |
|
device = mm.get_torch_device() |
|
mm.unload_all_models() |
|
|
|
SUPIR_MODEL_PATH = folder_paths.get_full_path("checkpoints", supir_model) |
|
SDXL_MODEL_PATH = folder_paths.get_full_path("checkpoints", sdxl_model) |
|
|
|
config_path = os.path.join(script_directory, "options/SUPIR_v0.yaml") |
|
config_path_tiled = os.path.join(script_directory, "options/SUPIR_v0_tiled.yaml") |
|
clip_config_path = os.path.join(script_directory, "configs/clip_vit_config.json") |
|
tokenizer_path = os.path.join(script_directory, "configs/tokenizer") |
|
|
|
custom_config = { |
|
'sdxl_model': sdxl_model, |
|
'diffusion_dtype': diffusion_dtype, |
|
'encoder_dtype': encoder_dtype, |
|
'use_tiled_vae': use_tiled_vae, |
|
'supir_model': supir_model, |
|
'use_tiled_sampling': use_tiled_sampling, |
|
'fp8_unet': fp8_unet, |
|
'fp8_vae': fp8_vae, |
|
'sampler': sampler |
|
} |
|
|
|
        if diffusion_dtype == 'auto':
            try:
                if mm.should_use_fp16():
                    print("Diffusion using fp16")
                    dtype = torch.float16
                    model_dtype = 'fp16'
                elif mm.should_use_bf16():
                    print("Diffusion using bf16")
                    dtype = torch.bfloat16
                    model_dtype = 'bf16'
                else:
                    print("Diffusion using fp32")
                    dtype = torch.float32
                    model_dtype = 'fp32'
            except AttributeError:
                raise AttributeError("ComfyUI is too old to autodetect dtypes. Set them manually.")
        else:
            print(f"Diffusion using {diffusion_dtype}")
            dtype = convert_dtype(diffusion_dtype)
            model_dtype = diffusion_dtype
|
|
|
        if encoder_dtype == 'auto':
            try:
                if mm.should_use_bf16():
                    print("Encoder using bf16")
                    vae_dtype = 'bf16'
                else:
                    print("Encoder using fp32")
                    vae_dtype = 'fp32'
            except AttributeError:
                raise AttributeError("ComfyUI is too old to autodetect dtypes. Set them manually.")
        else:
            vae_dtype = encoder_dtype
            print(f"Encoder using {vae_dtype}")
|
|
|
if not hasattr(self, "model") or self.model is None or self.current_config != custom_config: |
|
self.current_config = custom_config |
|
self.model = None |
|
|
|
mm.soft_empty_cache() |
|
|
|
if use_tiled_sampling: |
|
config = OmegaConf.load(config_path_tiled) |
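                # Tile size/stride are given in pixels but applied in latent
                # space, hence the division by the 8x VAE downscale factor.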
|
config.model.params.sampler_config.params.tile_size = sampler_tile_size // 8 |
|
config.model.params.sampler_config.params.tile_stride = sampler_tile_stride // 8 |
|
config.model.params.sampler_config.target = f".sgm.modules.diffusionmodules.sampling.Tiled{sampler}" |
|
print("Using tiled sampling") |
|
else: |
|
config = OmegaConf.load(config_path) |
|
config.model.params.sampler_config.target = f".sgm.modules.diffusionmodules.sampling.{sampler}" |
|
print("Using non-tiled sampling") |
|
|
|
if mm.XFORMERS_IS_AVAILABLE: |
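                # Route attention through xformers kernels when available.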
|
config.model.params.control_stage_config.params.spatial_transformer_attn_type = "softmax-xformers" |
|
config.model.params.network_config.params.spatial_transformer_attn_type = "softmax-xformers" |
|
config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla-xformers" |
|
|
|
config.model.params.ae_dtype = vae_dtype |
|
config.model.params.diffusion_dtype = model_dtype |
|
|
|
self.model = instantiate_from_config(config.model).cpu() |
|
|
|
            try:
                print(f"Attempting to load SUPIR model: [{SUPIR_MODEL_PATH}]")
                supir_state_dict = load_state_dict(SUPIR_MODEL_PATH)
            except Exception as e:
                raise Exception(f"Failed to load SUPIR model: {e}") from e
            try:
                print(f"Attempting to load SDXL model: [{SDXL_MODEL_PATH}]")
                sdxl_state_dict = load_state_dict(SDXL_MODEL_PATH)
            except Exception as e:
                raise Exception(f"Failed to load SDXL model: {e}") from e
|
self.model.load_state_dict(supir_state_dict, strict=False) |
|
self.model.load_state_dict(sdxl_state_dict, strict=False) |
|
|
|
del supir_state_dict |
|
|
|
|
|
try: |
|
print("Loading first clip model from SDXL checkpoint") |
|
|
|
replace_prefix = {} |
|
replace_prefix["conditioner.embedders.0.transformer."] = "" |
|
|
|
sd = comfy.utils.state_dict_prefix_replace(sdxl_state_dict, replace_prefix, filter_keys=False) |
|
clip_text_config = CLIPTextConfig.from_pretrained(clip_config_path) |
|
self.model.conditioner.embedders[0].tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) |
|
self.model.conditioner.embedders[0].transformer = CLIPTextModel(clip_text_config) |
|
self.model.conditioner.embedders[0].transformer.load_state_dict(sd, strict=False) |
|
self.model.conditioner.embedders[0].eval() |
|
for param in self.model.conditioner.embedders[0].parameters(): |
|
param.requires_grad = False |
|
            except Exception as e:
                raise Exception(f"Failed to load first clip model from SDXL checkpoint: {e}") from e
|
|
|
del sdxl_state_dict |
|
|
|
|
|
try: |
|
print("Loading second clip model from SDXL checkpoint") |
|
replace_prefix2 = {} |
|
replace_prefix2["conditioner.embedders.1.model."] = "" |
|
sd = comfy.utils.state_dict_prefix_replace(sd, replace_prefix2, filter_keys=True) |
|
clip_g = build_text_model_from_openai_state_dict(sd, cast_dtype=dtype) |
|
self.model.conditioner.embedders[1].model = clip_g |
|
            except Exception as e:
                raise Exception(f"Failed to load second clip model from SDXL checkpoint: {e}") from e
|
|
|
del sd, clip_g |
|
mm.soft_empty_cache() |
|
|
|
self.model.to(dtype) |
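            # The optional fp8 (e4m3) casts below roughly halve UNet/VAE weight
            # VRAM at some cost to numerical precision.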
|
|
|
|
|
if fp8_unet: |
|
self.model.model.to(torch.float8_e4m3fn) |
|
if fp8_vae: |
|
self.model.first_stage_model.to(torch.float8_e4m3fn) |
|
|
|
if use_tiled_vae: |
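                # Note the units: encoder tiles are in pixels, decoder tiles in latent cells.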
|
self.model.init_tile_vae(encoder_tile_size=encoder_tile_size_pixels, decoder_tile_size=decoder_tile_size_latent) |
|
|
|
upscaled_image, = ImageScaleBy.upscale(self, image, resize_method, scale_by) |
|
B, H, W, C = upscaled_image.shape |
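        # SUPIR expects dimensions divisible by 64; round H/W up, resize, and
        # scale back to the original (W, H) after sampling.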
|
new_height = H if H % 64 == 0 else ((H // 64) + 1) * 64 |
|
new_width = W if W % 64 == 0 else ((W // 64) + 1) * 64 |
|
upscaled_image = upscaled_image.permute(0, 3, 1, 2) |
|
resized_image = F.interpolate(upscaled_image, size=(new_height, new_width), mode='bicubic', align_corners=False) |
|
resized_image = resized_image.to(device) |
|
|
|
        captions_list = [captions]
        print("captions: ", captions_list)
|
|
|
use_linear_CFG = cfg_scale_start > 0 |
|
use_linear_control_scale = control_scale_start > 0 |
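        # A start value above zero enables a linear schedule between the start
        # value and the final cfg_scale / control_scale.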
|
out = [] |
|
pbar = comfy.utils.ProgressBar(B) |
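        # Broadcast the single caption to every image, then chunk images and
        # captions into matching batches.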
|
|
|
batched_images = [resized_image[i:i + batch_size] for i in |
|
range(0, len(resized_image), batch_size)] |
|
captions_list = captions_list * resized_image.shape[0] |
|
batched_captions = [captions_list[i:i + batch_size] for i in range(0, len(captions_list), batch_size)] |
|
|
|
mm.soft_empty_cache() |
|
        sampled_count = 0
|
for imgs, caps in zip(batched_images, batched_captions): |
|
try: |
|
samples = self.model.batchify_sample(imgs, caps, num_steps=steps, |
|
restoration_scale=restoration_scale, s_churn=s_churn, |
|
s_noise=s_noise, cfg_scale=cfg_scale, control_scale=control_scale, |
|
seed=seed, |
|
num_samples=1, p_p=a_prompt, n_p=n_prompt, |
|
color_fix_type=color_fix_type, |
|
use_linear_CFG=use_linear_CFG, |
|
use_linear_control_scale=use_linear_control_scale, |
|
cfg_scale_start=cfg_scale_start, |
|
control_scale_start=control_scale_start) |
|
except torch.cuda.OutOfMemoryError as e: |
|
mm.free_memory(mm.get_total_memory(mm.get_torch_device()), mm.get_torch_device()) |
|
self.model = None |
|
mm.soft_empty_cache() |
|
print("It's likely that too large of an image or batch_size for SUPIR was used," |
|
" and it has devoured all of the memory it had reserved, you may need to restart ComfyUI. Make sure you are using tiled_vae, " |
|
" you can also try using fp8 for reduced memory usage if your system supports it.") |
|
raise e |
|
|
|
            out.append(samples.squeeze(0).cpu())
            sampled_count += len(imgs)
            print(f"Sampled {sampled_count} out of {B}")
            pbar.update(len(imgs))
|
if not keep_model_loaded: |
|
self.model = None |
|
mm.soft_empty_cache() |
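        # batch_size > 1 leaves 4D [B, C, H, W] chunks in `out` (squeeze(0) was
        # a no-op), so concatenate; otherwise stack the squeezed 3D samples.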
|
|
|
if len(out[0].shape) == 4: |
|
out_stacked = torch.cat(out, dim=0).cpu().to(torch.float32).permute(0, 2, 3, 1) |
|
else: |
|
out_stacked = torch.stack(out, dim=0).cpu().to(torch.float32).permute(0, 2, 3, 1) |
|
|
|
final_image, = ImageScale.upscale(self, out_stacked, resize_method, W, H, crop="disabled") |
|
|
|
return (final_image,) |
|
|
|
NODE_CLASS_MAPPINGS = { |
|
"SUPIR_Upscale": SUPIR_Upscale |
|
} |
|
NODE_DISPLAY_NAME_MAPPINGS = { |
|
"SUPIR_Upscale": "SUPIR_Upscale" |
|
} |