|
import os |
|
import torch |
|
from omegaconf import OmegaConf |
|
import comfy.utils |
|
import comfy.model_management as mm |
|
import folder_paths |
|
import torch.cuda |
|
import torch.nn.functional as F |
|
from .sgm.util import instantiate_from_config |
|
from .SUPIR.util import convert_dtype, load_state_dict |
|
from .sgm.modules.distributions.distributions import DiagonalGaussianDistribution |
|
import open_clip |
|
from contextlib import contextmanager, nullcontext |
|
import gc |
|
|
|
try:
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device
    is_accelerate_available = True
except ImportError:
    is_accelerate_available = False
|
|
|
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    CLIPTextConfig,
)
|
script_directory = os.path.dirname(os.path.abspath(__file__)) |
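# open_clip normally builds both a vision tower and a text tower; the helper and
# context manager below stub out the vision-tower builder so that open_clip.CLIP
# can be instantiated as a text-only model from checkpoint weights that contain
# no vision tower.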
|
|
|
def dummy_build_vision_tower(*args, **kwargs): |
|
|
|
return None |
|
|
|
@contextmanager |
|
def patch_build_vision_tower(): |
|
original_build_vision_tower = open_clip.model._build_vision_tower |
|
open_clip.model._build_vision_tower = dummy_build_vision_tower |
|
|
|
try: |
|
yield |
|
finally: |
|
open_clip.model._build_vision_tower = original_build_vision_tower |
|
|
|
def build_text_model_from_openai_state_dict( |
|
state_dict: dict, |
|
device, |
|
cast_dtype=torch.float16, |
|
): |
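    # Derive the text-transformer configuration from tensor shapes in the
    # OpenAI-format state dict, the same convention open_clip uses for its own
    # checkpoint loading.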
|
|
|
embed_dim = state_dict["text_projection"].shape[1] |
|
context_length = state_dict["positional_embedding"].shape[0] |
|
vocab_size = state_dict["token_embedding.weight"].shape[0] |
|
transformer_width = state_dict["ln_final.weight"].shape[0] |
|
transformer_heads = transformer_width // 64 |
|
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
|
|
vision_cfg = None |
|
text_cfg = open_clip.CLIPTextCfg( |
|
context_length=context_length, |
|
vocab_size=vocab_size, |
|
width=transformer_width, |
|
heads=transformer_heads, |
|
layers=transformer_layers, |
|
) |
|
|
|
with patch_build_vision_tower(): |
|
with (init_empty_weights() if is_accelerate_available else nullcontext()): |
|
model = open_clip.CLIP( |
|
embed_dim, |
|
vision_cfg=vision_cfg, |
|
text_cfg=text_cfg, |
|
quick_gelu=True, |
|
cast_dtype=cast_dtype, |
|
) |
|
if is_accelerate_available: |
|
for key in state_dict: |
|
set_module_tensor_to_device(model, key, device=device, value=state_dict[key]) |
|
else: |
|
model.load_state_dict(state_dict, strict=False) |
|
model = model.eval() |
|
for param in model.parameters(): |
|
param.requires_grad = False |
|
return model |
|
|
|
class SUPIR_encode: |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"SUPIR_VAE": ("SUPIRVAE",), |
|
"image": ("IMAGE",), |
|
"use_tiled_vae": ("BOOLEAN", {"default": True}), |
|
"encoder_tile_size": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}), |
|
"encoder_dtype": ( |
|
[ |
|
'bf16', |
|
'fp32', |
|
'auto' |
|
], { |
|
"default": 'auto' |
|
}), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("LATENT",) |
|
RETURN_NAMES = ("latent",) |
|
FUNCTION = "encode" |
|
CATEGORY = "SUPIR" |
|
|
|
def encode(self, SUPIR_VAE, image, encoder_dtype, use_tiled_vae, encoder_tile_size): |
|
device = mm.get_torch_device() |
|
mm.unload_all_models() |
|
if encoder_dtype == 'auto': |
|
try: |
|
if mm.should_use_bf16(): |
|
print("Encoder using bf16") |
|
vae_dtype = 'bf16' |
|
else: |
|
print("Encoder using fp32") |
|
vae_dtype = 'fp32' |
|
except: |
|
raise AttributeError("ComfyUI version too old, can't autodetect properly. Set your dtypes manually.") |
|
else: |
|
vae_dtype = encoder_dtype |
|
print(f"Encoder using {vae_dtype}") |
|
|
|
dtype = convert_dtype(vae_dtype) |
|
|
|
image = image.permute(0, 3, 1, 2) |
|
B, C, H, W = image.shape |
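        # Width and height are reduced to the nearest multiple of 32 (the model's
        # downscale ratio); the original size is returned alongside the latent so
        # the decoder can restore it later.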
|
downscale_ratio = 32 |
|
orig_H, orig_W = H, W |
|
if W % downscale_ratio != 0: |
|
W = W - (W % downscale_ratio) |
|
if H % downscale_ratio != 0: |
|
H = H - (H % downscale_ratio) |
|
if orig_H % downscale_ratio != 0 or orig_W % downscale_ratio != 0: |
|
image = F.interpolate(image, size=(H, W), mode="bicubic") |
|
resized_image = image.to(device) |
|
|
|
if use_tiled_vae: |
|
from .SUPIR.utils.tilevae import VAEHook |
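            # Wrap the encoder's forward with the tiled VAE hook, keeping the
            # original forward around so it can be restored when tiling is disabled.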
|
|
|
if not hasattr(SUPIR_VAE.encoder, 'original_forward'): |
|
SUPIR_VAE.encoder.original_forward = SUPIR_VAE.encoder.forward |
|
SUPIR_VAE.encoder.forward = VAEHook( |
|
SUPIR_VAE.encoder, encoder_tile_size, is_decoder=False, fast_decoder=False, |
|
fast_encoder=False, color_fix=False, to_gpu=True) |
|
else: |
|
|
|
if hasattr(SUPIR_VAE.encoder, 'original_forward'): |
|
SUPIR_VAE.encoder.forward = SUPIR_VAE.encoder.original_forward |
|
|
|
pbar = comfy.utils.ProgressBar(B) |
|
out = [] |
|
for img in resized_image: |
|
|
|
SUPIR_VAE.to(dtype).to(device) |
|
|
|
autocast_condition = (dtype != torch.float32) and not comfy.model_management.is_device_mps(device) |
|
with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext(): |
|
|
|
z = SUPIR_VAE.encode(img.unsqueeze(0)) |
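                # Scale into SDXL's latent space (scale factor 0.13025).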
|
z = z * 0.13025 |
|
out.append(z) |
|
pbar.update(1) |
|
|
|
if len(out[0].shape) == 4: |
|
samples_out_stacked = torch.cat(out, dim=0) |
|
else: |
|
samples_out_stacked = torch.stack(out, dim=0) |
|
return ({"samples":samples_out_stacked, "original_size": [orig_H, orig_W]},) |
|
|
|
class SUPIR_decode: |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"SUPIR_VAE": ("SUPIRVAE",), |
|
"latents": ("LATENT",), |
|
"use_tiled_vae": ("BOOLEAN", {"default": True}), |
|
"decoder_tile_size": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("IMAGE",) |
|
RETURN_NAMES = ("image",) |
|
FUNCTION = "decode" |
|
CATEGORY = "SUPIR" |
|
|
|
def decode(self, SUPIR_VAE, latents, use_tiled_vae, decoder_tile_size): |
|
device = mm.get_torch_device() |
|
mm.unload_all_models() |
|
samples = latents["samples"] |
|
|
|
        B, C, H, W = samples.shape
|
|
|
pbar = comfy.utils.ProgressBar(B) |
|
|
|
        if mm.should_use_bf16():
            dtype = torch.bfloat16
        else:
            dtype = torch.float32
        print("SUPIR decoder using", dtype)
|
|
|
SUPIR_VAE.to(dtype).to(device) |
|
samples = samples.to(device) |
|
|
|
if use_tiled_vae: |
|
from .SUPIR.utils.tilevae import VAEHook |
|
|
|
if not hasattr(SUPIR_VAE.decoder, 'original_forward'): |
|
SUPIR_VAE.decoder.original_forward = SUPIR_VAE.decoder.forward |
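                # The decoder hook works on latents, so the pixel tile size is
                # divided by 8 to account for the VAE's spatial compression.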
|
SUPIR_VAE.decoder.forward = VAEHook( |
|
SUPIR_VAE.decoder, decoder_tile_size // 8, is_decoder=True, fast_decoder=False, |
|
fast_encoder=False, color_fix=False, to_gpu=True) |
|
else: |
|
|
|
if hasattr(SUPIR_VAE.decoder, 'original_forward'): |
|
SUPIR_VAE.decoder.forward = SUPIR_VAE.decoder.original_forward |
|
|
|
out = [] |
|
for sample in samples: |
|
autocast_condition = (dtype != torch.float32) and not comfy.model_management.is_device_mps(device) |
|
with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext(): |
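                # Undo the SDXL latent scaling before decoding back to pixels.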
|
sample = 1.0 / 0.13025 * sample |
|
decoded_image = SUPIR_VAE.decode(sample.unsqueeze(0)) |
|
out.append(decoded_image) |
|
pbar.update(1) |
|
|
|
decoded_out= torch.cat(out, dim=0).float() |
|
|
|
if "original_size" in latents and latents["original_size"] is not None: |
|
orig_H, orig_W = latents["original_size"] |
|
if decoded_out.shape[2] != orig_H or decoded_out.shape[3] != orig_W: |
|
print("Restoring original dimensions: ", orig_W,"x",orig_H) |
|
decoded_out = F.interpolate(decoded_out, size=(orig_H, orig_W), mode="bicubic") |
|
|
|
decoded_out = torch.clip(decoded_out, 0, 1) |
|
decoded_out = decoded_out.cpu().to(torch.float32).permute(0, 2, 3, 1) |
|
|
|
|
|
return (decoded_out,) |
|
|
|
class SUPIR_first_stage: |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"SUPIR_VAE": ("SUPIRVAE",), |
|
"image": ("IMAGE",), |
|
"use_tiled_vae": ("BOOLEAN", {"default": True}), |
|
"encoder_tile_size": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}), |
|
"decoder_tile_size": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}), |
|
"encoder_dtype": ( |
|
[ |
|
'bf16', |
|
'fp32', |
|
'auto' |
|
], { |
|
"default": 'auto' |
|
}), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("SUPIRVAE", "IMAGE", "LATENT",) |
|
RETURN_NAMES = ("SUPIR_VAE", "denoised_image", "denoised_latents",) |
|
FUNCTION = "process" |
|
CATEGORY = "SUPIR" |
|
DESCRIPTION = """ |
|
SUPIR "first stage" processing. |
|
Encodes and decodes the image using SUPIR's "denoise_encoder", purpose |
|
is to fix compression artifacts and such, ends up blurring the image often |
|
which is expected. Can be replaced with any other denoiser/blur or not used at all. |
|
""" |
|
|
|
def process(self, SUPIR_VAE, image, encoder_dtype, use_tiled_vae, encoder_tile_size, decoder_tile_size): |
|
device = mm.get_torch_device() |
|
mm.unload_all_models() |
|
if encoder_dtype == 'auto': |
|
try: |
|
|
|
if mm.should_use_bf16(): |
|
print("Encoder using bf16") |
|
vae_dtype = 'bf16' |
|
else: |
|
print("Encoder using fp32") |
|
vae_dtype = 'fp32' |
|
except: |
|
raise AttributeError("ComfyUI version too old, can't autodetect properly. Set your dtypes manually.") |
|
else: |
|
vae_dtype = encoder_dtype |
|
print(f"Encoder using {vae_dtype}") |
|
|
|
dtype = convert_dtype(vae_dtype) |
|
|
|
if use_tiled_vae: |
|
from .SUPIR.utils.tilevae import VAEHook |
|
|
|
            if not hasattr(SUPIR_VAE.denoise_encoder, 'original_forward'):
|
SUPIR_VAE.denoise_encoder.original_forward = SUPIR_VAE.denoise_encoder.forward |
|
SUPIR_VAE.decoder.original_forward = SUPIR_VAE.decoder.forward |
|
|
|
SUPIR_VAE.denoise_encoder.forward = VAEHook( |
|
SUPIR_VAE.denoise_encoder, encoder_tile_size, is_decoder=False, fast_decoder=False, |
|
fast_encoder=False, color_fix=False, to_gpu=True) |
|
|
|
SUPIR_VAE.decoder.forward = VAEHook( |
|
SUPIR_VAE.decoder, decoder_tile_size // 8, is_decoder=True, fast_decoder=False, |
|
fast_encoder=False, color_fix=False, to_gpu=True) |
|
else: |
|
|
|
if hasattr(SUPIR_VAE.denoise_encoder, 'original_forward'): |
|
SUPIR_VAE.denoise_encoder.forward = SUPIR_VAE.denoise_encoder.original_forward |
|
SUPIR_VAE.decoder.forward = SUPIR_VAE.decoder.original_forward |
|
|
|
image = image.permute(0, 3, 1, 2) |
|
B, C, H, W = image.shape |
|
downscale_ratio = 32 |
|
orig_H, orig_W = H, W |
|
if W % downscale_ratio != 0: |
|
W = W - (W % downscale_ratio) |
|
if H % downscale_ratio != 0: |
|
H = H - (H % downscale_ratio) |
|
if orig_H % downscale_ratio != 0 or orig_W % downscale_ratio != 0: |
|
image = F.interpolate(image, size=(H, W), mode="bicubic") |
|
resized_image = image.to(device) |
|
|
|
pbar = comfy.utils.ProgressBar(B) |
|
out = [] |
|
out_samples = [] |
|
for img in resized_image: |
|
|
|
SUPIR_VAE.to(dtype).to(device) |
|
|
|
autocast_condition = (dtype != torch.float32) and not comfy.model_management.is_device_mps(device) |
|
with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext(): |
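                # Encode with SUPIR's denoise encoder, sample a latent from the
                # resulting DiagonalGaussianDistribution and decode it straight back
                # to pixels; this removes compression artifacts but softens the image.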
|
|
|
h = SUPIR_VAE.denoise_encoder(img.unsqueeze(0)) |
|
moments = SUPIR_VAE.quant_conv(h) |
|
posterior = DiagonalGaussianDistribution(moments) |
|
sample = posterior.sample() |
|
decoded_images = SUPIR_VAE.decode(sample).float() |
|
|
|
out.append(decoded_images.cpu()) |
|
out_samples.append(sample.cpu() * 0.13025) |
|
pbar.update(1) |
|
|
|
|
|
out_stacked = torch.cat(out, dim=0).to(torch.float32).permute(0, 2, 3, 1) |
|
out_samples_stacked = torch.cat(out_samples, dim=0) |
|
original_size = [orig_H, orig_W] |
|
return (SUPIR_VAE, out_stacked, {"samples": out_samples_stacked, "original_size": original_size},) |
|
|
|
class SUPIR_sample: |
|
|
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"SUPIR_model": ("SUPIRMODEL",), |
|
"latents": ("LATENT",), |
|
"positive": ("SUPIR_cond_pos",), |
|
"negative": ("SUPIR_cond_neg",), |
|
"seed": ("INT", {"default": 123, "min": 0, "max": 0xffffffffffffffff, "step": 1}), |
|
"steps": ("INT", {"default": 45, "min": 3, "max": 4096, "step": 1}), |
|
"cfg_scale_start": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 100.0, "step": 0.01}), |
|
"cfg_scale_end": ("FLOAT", {"default": 4.0, "min": 0, "max": 100.0, "step": 0.01}), |
|
"EDM_s_churn": ("INT", {"default": 5, "min": 0, "max": 40, "step": 1}), |
|
"s_noise": ("FLOAT", {"default": 1.003, "min": 1.0, "max": 1.1, "step": 0.001}), |
|
"DPMPP_eta": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.01}), |
|
"control_scale_start": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.01}), |
|
"control_scale_end": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.01}), |
|
"restore_cfg": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20.0, "step": 0.01}), |
|
"keep_model_loaded": ("BOOLEAN", {"default": False}), |
|
"sampler": ( |
|
[ |
|
'RestoreDPMPP2MSampler', |
|
'RestoreEDMSampler', |
|
'TiledRestoreDPMPP2MSampler', |
|
'TiledRestoreEDMSampler', |
|
], { |
|
"default": 'RestoreEDMSampler' |
|
}), |
|
}, |
|
"optional": { |
|
"sampler_tile_size": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 32}), |
|
"sampler_tile_stride": ("INT", {"default": 512, "min": 32, "max": 2048, "step": 32}), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("LATENT",) |
|
RETURN_NAMES = ("latent",) |
|
FUNCTION = "sample" |
|
CATEGORY = "SUPIR" |
|
DESCRIPTION = """ |
|
- **latent:** |
|
Latent to sample from. When using a SUPIR latent this only provides the noise shape
and is otherwise unused here, identical to feeding in an empty ComfyUI latent.
If anything else is fed in, it is used as it is.
|
- **cfg:** |
|
Linearly scaled CFG is always used: the first step uses the cfg_scale_start value,
which is interpolated to the cfg_scale_end value by the last step.
To disable the scaling, set both values to the same number.
|
- **EDM_s_churn:** |
|
controls the rate of adaptation of the diffusion process to changes in noise levels |
|
over time. Has no effect with DPMPP samplers. |
|
- **s_noise:** |
|
This parameter directly controls the amount of noise added to the image at each |
|
step of the diffusion process. |
|
- **DPMPP_eta:** |
|
Scaling factor that influences the diffusion process by adjusting how the denoising |
|
process adapts to changes in noise levels over time. |
|
No effect with EDM samplers. |
|
- **control_scale:** |
|
The strength of the SUPIR control model, scaled linearly from start to end.
|
Lower values allow more freedom from the input image. |
|
- **restore_cfg:** |
|
Controls the degree of restoration towards the original image during the diffusion |
|
process, and allows for some fine-tuning of the result.
|
- **samplers:** |
|
EDM samplers need lots of steps but generally have better quality. |
|
DPMPP samplers work well with lower steps, good for lightning models. |
|
Tiled samplers enable a tiled diffusion process; this is much slower but saves
VRAM, which allows higher resolutions. Choose the tile size so the image is
tiled evenly; the tile stride controls how much the tiles overlap. Check the
SUPIR Tiles node for a preview of how the image is tiled.
|
|
|
""" |
|
|
|
def sample(self, SUPIR_model, latents, steps, seed, cfg_scale_end, EDM_s_churn, s_noise, positive, negative, |
|
cfg_scale_start, control_scale_start, control_scale_end, restore_cfg, keep_model_loaded, DPMPP_eta, |
|
sampler, sampler_tile_size=1024, sampler_tile_stride=512): |
|
|
|
torch.manual_seed(seed) |
|
device = mm.get_torch_device() |
|
mm.unload_all_models() |
|
mm.soft_empty_cache() |
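        # Build the sampler configuration for sgm's instantiate_from_config.
        # LinearCFG interpolates the guidance scale from cfg_scale_start at the
        # first step down to cfg_scale_end at the last step.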
|
|
|
self.sampler_config = { |
|
'target': f'.sgm.modules.diffusionmodules.sampling.{sampler}', |
|
'params': { |
|
'num_steps': steps, |
|
'restore_cfg': restore_cfg, |
|
's_churn': EDM_s_churn, |
|
's_noise': s_noise, |
|
'discretization_config': { |
|
'target': '.sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization' |
|
}, |
|
'guider_config': { |
|
'target': '.sgm.modules.diffusionmodules.guiders.LinearCFG', |
|
'params': { |
|
'scale': cfg_scale_start, |
|
'scale_min': cfg_scale_end |
|
} |
|
} |
|
} |
|
} |
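        # Tiled samplers take their tile size and stride in latent space (hence //8);
        # the DPMPP restore samplers use eta and keep restore_cfg at its -1 default.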
|
if 'Tiled' in sampler: |
|
self.sampler_config['params']['tile_size'] = sampler_tile_size // 8 |
|
self.sampler_config['params']['tile_stride'] = sampler_tile_stride // 8 |
|
if 'DPMPP' in sampler: |
|
self.sampler_config['params']['eta'] = DPMPP_eta |
|
self.sampler_config['params']['restore_cfg'] = -1 |
|
        if not hasattr(self, 'sampler') or self.sampler_config != self.current_sampler_config:
|
self.sampler = instantiate_from_config(self.sampler_config) |
|
self.current_sampler_config = self.sampler_config |
|
|
|
print("sampler_config: ", self.sampler_config) |
|
|
|
SUPIR_model.denoiser.to(device) |
|
SUPIR_model.model.diffusion_model.to(device) |
|
SUPIR_model.model.control_model.to(device) |
|
|
|
use_linear_control_scale = control_scale_start != control_scale_end |
|
|
|
denoiser = lambda input, sigma, c, control_scale: SUPIR_model.denoiser(SUPIR_model.model, input, sigma, c, control_scale) |
|
|
|
original_size = positive['original_size'] |
|
positive = positive['cond'] |
|
negative = negative['uncond'] |
|
samples = latents["samples"] |
|
samples = samples.to(device) |
|
|
|
|
|
out = [] |
|
pbar = comfy.utils.ProgressBar(samples.shape[0]) |
|
for i, sample in enumerate(samples): |
|
try: |
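                # Latents produced by the SUPIR encoder carry "original_size"; for
                # those only the shape is used and fresh noise is sampled. Any other
                # latent is combined with the sampled noise.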
|
if 'original_size' in latents: |
|
print("Using random noise") |
|
noised_z = torch.randn_like(sample.unsqueeze(0), device=samples.device) |
|
else: |
|
print("Using latent from input") |
|
noised_z = torch.randn_like(sample.unsqueeze(0), device=samples.device) |
|
noised_z += sample.unsqueeze(0) |
|
if len(positive) != len(samples): |
|
print("Tiled sampling") |
|
_samples = self.sampler(denoiser, noised_z, cond=positive, uc=negative, x_center=sample.unsqueeze(0), control_scale=control_scale_end, |
|
use_linear_control_scale=use_linear_control_scale, control_scale_start=control_scale_start) |
|
else: |
|
|
|
|
|
_samples = self.sampler(denoiser, noised_z, cond=positive[i], uc=negative[i], x_center=sample.unsqueeze(0), control_scale=control_scale_end, |
|
use_linear_control_scale=use_linear_control_scale, control_scale_start=control_scale_start) |
|
|
|
|
|
except torch.cuda.OutOfMemoryError as e: |
|
mm.free_memory(mm.get_total_memory(mm.get_torch_device()), mm.get_torch_device()) |
|
SUPIR_model = None |
|
mm.soft_empty_cache() |
|
print("It's likely that too large of an image or batch_size for SUPIR was used," |
|
" and it has devoured all of the memory it had reserved, you may need to restart ComfyUI. Make sure you are using tiled_vae, " |
|
" you can also try using fp8 for reduced memory usage if your system supports it.") |
|
raise e |
|
out.append(_samples) |
|
print("Sampled ", i+1, " of ", samples.shape[0]) |
|
pbar.update(1) |
|
|
|
if not keep_model_loaded: |
|
SUPIR_model.denoiser.to('cpu') |
|
SUPIR_model.model.diffusion_model.to('cpu') |
|
SUPIR_model.model.control_model.to('cpu') |
|
mm.soft_empty_cache() |
|
|
|
if len(out[0].shape) == 4: |
|
samples_out_stacked = torch.cat(out, dim=0) |
|
else: |
|
samples_out_stacked = torch.stack(out, dim=0) |
|
|
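        # Without an original_size the latent is assumed to be headed for a standard
        # ComfyUI VAE decode, so the SDXL scaling factor is removed here.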
|
if original_size is None: |
|
samples_out_stacked = samples_out_stacked / 0.13025 |
|
|
|
return ({"samples":samples_out_stacked, "original_size": original_size},) |
|
|
|
class SUPIR_conditioner: |
|
|
|
|
|
|
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"SUPIR_model": ("SUPIRMODEL",), |
|
"latents": ("LATENT",), |
|
"positive_prompt": ("STRING", {"multiline": True, "default": "high quality, detailed", }), |
|
"negative_prompt": ("STRING", {"multiline": True, "default": "bad quality, blurry, messy", }), |
|
}, |
|
"optional": { |
|
"captions": ("STRING", {"forceInput": True, "multiline": False, "default": "", }), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("SUPIR_cond_pos", "SUPIR_cond_neg",) |
|
RETURN_NAMES = ("positive", "negative",) |
|
FUNCTION = "condition" |
|
CATEGORY = "SUPIR" |
|
DESCRIPTION = """ |
|
Creates the conditioning for the sampler. |
|
Caption input is optional, when it receives a single caption, it's added to the positive prompt. |
|
|
|
If a list of captions is given for a single input image, the number of captions needs to match the number of tiles;
refer to the SUPIR Tiles node.

If a list of captions is given and it matches the incoming image batch, each image uses its corresponding caption.
|
""" |
|
|
|
def condition(self, SUPIR_model, latents, positive_prompt, negative_prompt, captions=""): |
|
|
|
device = mm.get_torch_device() |
|
mm.soft_empty_cache() |
|
|
|
if "original_size" in latents: |
|
original_size = latents["original_size"] |
|
samples = latents["samples"] |
|
else: |
|
original_size = None |
|
samples = latents["samples"] * 0.13025 |
|
|
|
        N, C, H, W = samples.shape
|
import copy |
|
|
|
if not isinstance(captions, list): |
|
captions_list = [] |
|
captions_list.append([captions]) |
|
captions_list = captions_list * N |
|
else: |
|
captions_list = captions |
|
|
|
print("captions: ", captions_list) |
|
|
|
SUPIR_model.conditioner.to(device) |
|
samples = samples.to(device) |
|
|
|
uc = [] |
|
pbar = comfy.utils.ProgressBar(N) |
|
autocast_condition = (SUPIR_model.model.dtype != torch.float32) and not comfy.model_management.is_device_mps(device) |
|
with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=SUPIR_model.model.dtype) if autocast_condition else nullcontext(): |
|
if N != len(captions_list): |
|
print("Tiled captioning") |
|
c = [] |
|
uc = [] |
|
for i, caption in enumerate(captions_list): |
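                    # SDXL-style conditioning dict plus SUPIR's 'control' latent; the
                    # size, crop and aesthetic-score entries are fixed values. In the
                    # tiled path every caption shares the first latent as control and
                    # the unconditional embedding is computed once, on the first tile.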
|
cond = {} |
|
cond['original_size_as_tuple'] = torch.tensor([[1024, 1024]]).to(device) |
|
cond['crop_coords_top_left'] = torch.tensor([[0, 0]]).to(device) |
|
cond['target_size_as_tuple'] = torch.tensor([[1024, 1024]]).to(device) |
|
cond['aesthetic_score'] = torch.tensor([[9.0]]).to(device) |
|
cond['control'] = samples[0].unsqueeze(0) |
|
|
|
uncond = copy.deepcopy(cond) |
|
uncond['txt'] = [negative_prompt] |
|
|
|
cond['txt'] = [''.join([caption[0], positive_prompt])] |
|
if i == 0: |
|
_c, uc = SUPIR_model.conditioner.get_unconditional_conditioning(cond, uncond) |
|
else: |
|
_c, _ = SUPIR_model.conditioner.get_unconditional_conditioning(cond, None) |
|
|
|
c.append(_c) |
|
pbar.update(1) |
|
else: |
|
print("Batch captioning") |
|
c = [] |
|
uc = [] |
|
for i, sample in enumerate(samples): |
|
|
|
cond = {} |
|
cond['original_size_as_tuple'] = torch.tensor([[1024, 1024]]).to(device) |
|
cond['crop_coords_top_left'] = torch.tensor([[0, 0]]).to(device) |
|
cond['target_size_as_tuple'] = torch.tensor([[1024, 1024]]).to(device) |
|
cond['aesthetic_score'] = torch.tensor([[9.0]]).to(device) |
|
cond['control'] = sample.unsqueeze(0) |
|
|
|
uncond = copy.deepcopy(cond) |
|
uncond['txt'] = [negative_prompt] |
|
cond['txt'] = [''.join([captions_list[i][0], positive_prompt])] |
|
_c, _uc = SUPIR_model.conditioner.get_unconditional_conditioning(cond, uncond) |
|
c.append(_c) |
|
uc.append(_uc) |
|
|
|
pbar.update(1) |
|
|
|
|
|
SUPIR_model.conditioner.to('cpu') |
|
|
|
if "original_size" in latents: |
|
original_size = latents["original_size"] |
|
else: |
|
original_size = None |
|
|
|
return ({"cond": c, "original_size":original_size}, {"uncond": uc},) |
|
|
|
class SUPIR_model_loader: |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"supir_model": (folder_paths.get_filename_list("checkpoints"),), |
|
"sdxl_model": (folder_paths.get_filename_list("checkpoints"),), |
|
"fp8_unet": ("BOOLEAN", {"default": False}), |
|
"diffusion_dtype": ( |
|
[ |
|
'fp16', |
|
'bf16', |
|
'fp32', |
|
'auto' |
|
], { |
|
"default": 'auto' |
|
}), |
|
}, |
|
} |
|
|
|
RETURN_TYPES = ("SUPIRMODEL", "SUPIRVAE") |
|
RETURN_NAMES = ("SUPIR_model","SUPIR_VAE",) |
|
FUNCTION = "process" |
|
CATEGORY = "SUPIR" |
|
DESCRIPTION = """ |
|
Old loader, not recommended for use.
|
Loads the SUPIR model and the selected SDXL model and merges them. |
|
""" |
|
|
|
def process(self, supir_model, sdxl_model, diffusion_dtype, fp8_unet): |
|
device = mm.get_torch_device() |
|
mm.unload_all_models() |
|
|
|
SUPIR_MODEL_PATH = folder_paths.get_full_path("checkpoints", supir_model) |
|
SDXL_MODEL_PATH = folder_paths.get_full_path("checkpoints", sdxl_model) |
|
|
|
config_path = os.path.join(script_directory, "options/SUPIR_v0.yaml") |
|
clip_config_path = os.path.join(script_directory, "configs/clip_vit_config.json") |
|
tokenizer_path = os.path.join(script_directory, "configs/tokenizer") |
|
|
|
custom_config = { |
|
'sdxl_model': sdxl_model, |
|
'diffusion_dtype': diffusion_dtype, |
|
'supir_model': supir_model, |
|
'fp8_unet': fp8_unet, |
|
} |
|
|
|
if diffusion_dtype == 'auto': |
|
try: |
|
if mm.should_use_fp16(): |
|
print("Diffusion using fp16") |
|
dtype = torch.float16 |
|
model_dtype = 'fp16' |
|
elif mm.should_use_bf16(): |
|
print("Diffusion using bf16") |
|
dtype = torch.bfloat16 |
|
model_dtype = 'bf16' |
|
else: |
|
print("Diffusion using fp32") |
|
dtype = torch.float32 |
|
model_dtype = 'fp32' |
|
except: |
|
raise AttributeError("ComfyUI version too old, can't autodetect properly. Set your dtypes manually.") |
|
else: |
|
print(f"Diffusion using {diffusion_dtype}") |
|
dtype = convert_dtype(diffusion_dtype) |
|
model_dtype = diffusion_dtype |
|
|
|
if not hasattr(self, "model") or self.model is None or self.current_config != custom_config: |
|
self.current_config = custom_config |
|
self.model = None |
|
|
|
mm.soft_empty_cache() |
|
|
|
config = OmegaConf.load(config_path) |
|
|
|
if mm.XFORMERS_IS_AVAILABLE: |
|
print("Using XFORMERS") |
|
config.model.params.control_stage_config.params.spatial_transformer_attn_type = "softmax-xformers" |
|
config.model.params.network_config.params.spatial_transformer_attn_type = "softmax-xformers" |
|
config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla-xformers" |
|
|
|
config.model.params.diffusion_dtype = model_dtype |
|
config.model.target = ".SUPIR.models.SUPIR_model_v2.SUPIRModel" |
|
pbar = comfy.utils.ProgressBar(5) |
|
|
|
self.model = instantiate_from_config(config.model).cpu() |
|
self.model.model.dtype = dtype |
|
pbar.update(1) |
|
try: |
|
print(f"Attempting to load SDXL model: [{SDXL_MODEL_PATH}]") |
|
sdxl_state_dict = load_state_dict(SDXL_MODEL_PATH) |
|
self.model.load_state_dict(sdxl_state_dict, strict=False) |
|
if fp8_unet: |
|
self.model.model.to(torch.float8_e4m3fn) |
|
else: |
|
self.model.model.to(dtype) |
|
pbar.update(1) |
|
except: |
|
raise Exception("Failed to load SDXL model") |
|
|
|
|
|
try: |
|
print("Loading first clip model from SDXL checkpoint") |
|
|
|
replace_prefix = {} |
|
replace_prefix["conditioner.embedders.0.transformer."] = "" |
|
|
|
sd = comfy.utils.state_dict_prefix_replace(sdxl_state_dict, replace_prefix, filter_keys=False) |
|
clip_text_config = CLIPTextConfig.from_pretrained(clip_config_path) |
|
self.model.conditioner.embedders[0].tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) |
|
self.model.conditioner.embedders[0].transformer = CLIPTextModel(clip_text_config) |
|
self.model.conditioner.embedders[0].transformer.load_state_dict(sd, strict=False) |
|
self.model.conditioner.embedders[0].eval() |
|
self.model.conditioner.embedders[0].to(dtype) |
|
for param in self.model.conditioner.embedders[0].parameters(): |
|
param.requires_grad = False |
|
pbar.update(1) |
|
except: |
|
raise Exception("Failed to load first clip model from SDXL checkpoint") |
|
|
|
del sdxl_state_dict |
|
|
|
|
|
try: |
|
print("Loading second clip model from SDXL checkpoint") |
|
replace_prefix2 = {} |
|
replace_prefix2["conditioner.embedders.1.model."] = "" |
|
sd = comfy.utils.state_dict_prefix_replace(sd, replace_prefix2, filter_keys=True) |
|
clip_g = build_text_model_from_openai_state_dict(sd, device, cast_dtype=dtype) |
|
self.model.conditioner.embedders[1].model = clip_g |
|
self.model.conditioner.embedders[1].to(dtype) |
|
pbar.update(1) |
|
except: |
|
raise Exception("Failed to load second clip model from SDXL checkpoint") |
|
|
|
del sd, clip_g |
|
|
|
try: |
|
print(f'Attempting to load SUPIR model: [{SUPIR_MODEL_PATH}]') |
|
supir_state_dict = load_state_dict(SUPIR_MODEL_PATH) |
|
self.model.load_state_dict(supir_state_dict, strict=False) |
|
if fp8_unet: |
|
self.model.model.to(torch.float8_e4m3fn) |
|
else: |
|
self.model.model.to(dtype) |
|
del supir_state_dict |
|
pbar.update(1) |
|
except: |
|
raise Exception("Failed to load SUPIR model") |
|
mm.soft_empty_cache() |
|
|
|
return (self.model, self.model.first_stage_model,) |
|
|
|
class SUPIR_model_loader_v2: |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"model" :("MODEL",), |
|
"clip": ("CLIP",), |
|
"vae": ("VAE",), |
|
"supir_model": (folder_paths.get_filename_list("checkpoints"),), |
|
"fp8_unet": ("BOOLEAN", {"default": False}), |
|
"diffusion_dtype": ( |
|
[ |
|
'fp16', |
|
'bf16', |
|
'fp32', |
|
'auto' |
|
], { |
|
"default": 'auto' |
|
}), |
|
}, |
|
"optional": { |
|
"high_vram": ("BOOLEAN", {"default": False}), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("SUPIRMODEL", "SUPIRVAE") |
|
RETURN_NAMES = ("SUPIR_model","SUPIR_VAE",) |
|
FUNCTION = "process" |
|
CATEGORY = "SUPIR" |
|
DESCRIPTION = """ |
|
Loads the SUPIR model and merges it with the SDXL model. |
|
|
|
diffusion_dtype should be kept on auto unless you have issues loading the model.
fp8_unet casts the unet weights to torch.float8_e4m3fn, which saves a lot of VRAM but has a slight quality impact.
high_vram: uses Accelerate to load the weights to the GPU, for slightly faster model loading.
|
""" |
|
|
|
def process(self, supir_model, diffusion_dtype, fp8_unet, model, clip, vae, high_vram=False): |
|
if high_vram: |
|
device = mm.get_torch_device() |
|
else: |
|
device = mm.unet_offload_device() |
|
print("Loading weights to: ", device) |
|
mm.unload_all_models() |
|
|
|
SUPIR_MODEL_PATH = folder_paths.get_full_path("checkpoints", supir_model) |
|
|
|
config_path = os.path.join(script_directory, "options/SUPIR_v0.yaml") |
|
clip_config_path = os.path.join(script_directory, "configs/clip_vit_config.json") |
|
tokenizer_path = os.path.join(script_directory, "configs/tokenizer") |
|
|
|
custom_config = { |
|
'diffusion_dtype': diffusion_dtype, |
|
'supir_model': supir_model, |
|
'fp8_unet': fp8_unet, |
|
'model': model, |
|
"clip": clip, |
|
"vae": vae |
|
} |
|
|
|
if diffusion_dtype == 'auto': |
|
try: |
|
if mm.should_use_fp16(): |
|
print("Diffusion using fp16") |
|
dtype = torch.float16 |
|
elif mm.should_use_bf16(): |
|
print("Diffusion using bf16") |
|
dtype = torch.bfloat16 |
|
else: |
|
print("Diffusion using fp32") |
|
dtype = torch.float32 |
|
except: |
|
raise AttributeError("ComfyUI version too old, can't autodecet properly. Set your dtypes manually.") |
|
else: |
|
print(f"Diffusion using {diffusion_dtype}") |
|
dtype = convert_dtype(diffusion_dtype) |
|
|
|
if not hasattr(self, "model") or self.model is None or self.current_config != custom_config: |
|
self.current_config = custom_config |
|
self.model = None |
|
|
|
mm.soft_empty_cache() |
|
|
|
config = OmegaConf.load(config_path) |
|
if mm.XFORMERS_IS_AVAILABLE: |
|
print("Using XFORMERS") |
|
config.model.params.control_stage_config.params.spatial_transformer_attn_type = "softmax-xformers" |
|
config.model.params.network_config.params.spatial_transformer_attn_type = "softmax-xformers" |
|
config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla-xformers" |
|
|
|
config.model.target = ".SUPIR.models.SUPIR_model_v2.SUPIRModel" |
|
pbar = comfy.utils.ProgressBar(5) |
|
|
|
|
|
self.model = instantiate_from_config(config.model).cpu() |
|
self.model.model.dtype = dtype |
|
pbar.update(1) |
|
try: |
|
print(f"Attempting to load SDXL model from node inputs") |
|
mm.load_model_gpu(model) |
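            # Serialize the ComfyUI diffusion model and VAE back into an SDXL-style
            # state dict and copy those weights into the freshly built SUPIR model.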
|
sdxl_state_dict = model.model.state_dict_for_saving(None, vae.get_sd(), None) |
|
if is_accelerate_available: |
|
for key in sdxl_state_dict: |
|
set_module_tensor_to_device(self.model, key, device=device, dtype=dtype, value=sdxl_state_dict[key]) |
|
else: |
|
self.model.load_state_dict(sdxl_state_dict, strict=False) |
|
if fp8_unet: |
|
self.model.model.to(torch.float8_e4m3fn) |
|
else: |
|
self.model.model.to(dtype) |
|
del sdxl_state_dict |
|
pbar.update(1) |
|
except: |
|
raise Exception("Failed to load SDXL model") |
|
gc.collect() |
|
mm.soft_empty_cache() |
|
|
|
try: |
|
print("Loading first clip model from SDXL checkpoint") |
|
clip_sd = None |
|
clip_model = clip.load_model() |
|
mm.load_model_gpu(clip_model) |
|
clip_sd = clip.get_sd() |
|
clip_sd = model.model.model_config.process_clip_state_dict_for_saving(clip_sd) |
|
|
|
replace_prefix = {} |
|
replace_prefix["conditioner.embedders.0.transformer."] = "" |
|
|
|
clip_l_sd = comfy.utils.state_dict_prefix_replace(clip_sd, replace_prefix, filter_keys=True) |
|
clip_text_config = CLIPTextConfig.from_pretrained(clip_config_path) |
|
self.model.conditioner.embedders[0].tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) |
|
with (init_empty_weights() if is_accelerate_available else nullcontext()): |
|
self.model.conditioner.embedders[0].transformer = CLIPTextModel(clip_text_config) |
|
if is_accelerate_available: |
|
for key in clip_l_sd: |
|
set_module_tensor_to_device(self.model.conditioner.embedders[0].transformer, key, device=device, dtype=dtype, value=clip_l_sd[key]) |
|
else: |
|
self.model.conditioner.embedders[0].transformer.load_state_dict(clip_l_sd, strict=False) |
|
self.model.conditioner.embedders[0].eval() |
|
for param in self.model.conditioner.embedders[0].parameters(): |
|
param.requires_grad = False |
|
self.model.conditioner.embedders[0].to(dtype) |
|
del clip_l_sd |
|
pbar.update(1) |
|
except: |
|
raise Exception("Failed to load first clip model from SDXL checkpoint") |
|
gc.collect() |
|
mm.soft_empty_cache() |
|
|
|
try: |
|
print("Loading second clip model from SDXL checkpoint") |
|
replace_prefix2 = {} |
|
replace_prefix2["conditioner.embedders.1.model."] = "" |
|
clip_g_sd = comfy.utils.state_dict_prefix_replace(clip_sd, replace_prefix2, filter_keys=True) |
|
clip_g = build_text_model_from_openai_state_dict(clip_g_sd, device, cast_dtype=dtype) |
|
self.model.conditioner.embedders[1].model = clip_g |
|
self.model.conditioner.embedders[1].model.to(dtype) |
|
del clip_g_sd |
|
pbar.update(1) |
|
except: |
|
raise Exception("Failed to load second clip model from SDXL checkpoint") |
|
|
|
try: |
|
print(f'Attempting to load SUPIR model: [{SUPIR_MODEL_PATH}]') |
|
supir_state_dict = load_state_dict(SUPIR_MODEL_PATH) |
|
if "Q" not in supir_model or not is_accelerate_available: |
|
for key in supir_state_dict: |
|
set_module_tensor_to_device(self.model, key, device=device, dtype=dtype, value=supir_state_dict[key]) |
|
else: |
|
self.model.load_state_dict(supir_state_dict, strict=False) |
|
if fp8_unet: |
|
self.model.model.to(torch.float8_e4m3fn) |
|
else: |
|
self.model.model.to(dtype) |
|
del supir_state_dict |
|
pbar.update(1) |
|
except: |
|
raise Exception("Failed to load SUPIR model") |
|
mm.soft_empty_cache() |
|
|
|
return (self.model, self.model.first_stage_model,) |
|
|
|
class SUPIR_model_loader_v2_clip: |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"model" :("MODEL",), |
|
"clip_l": ("CLIP",), |
|
"clip_g": ("CLIP",), |
|
"vae": ("VAE",), |
|
"supir_model": (folder_paths.get_filename_list("checkpoints"),), |
|
"fp8_unet": ("BOOLEAN", {"default": False}), |
|
"diffusion_dtype": ( |
|
[ |
|
'fp16', |
|
'bf16', |
|
'fp32', |
|
'auto' |
|
], { |
|
"default": 'auto' |
|
}), |
|
}, |
|
"optional": { |
|
"high_vram": ("BOOLEAN", {"default": False}), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("SUPIRMODEL", "SUPIRVAE") |
|
RETURN_NAMES = ("SUPIR_model","SUPIR_VAE",) |
|
FUNCTION = "process" |
|
CATEGORY = "SUPIR" |
|
DESCRIPTION = """ |
|
Loads the SUPIR model and merges it with the SDXL model, using separately supplied CLIP-L and CLIP-G models.

diffusion_dtype should be kept on auto unless you have issues loading the model.
fp8_unet casts the unet weights to torch.float8_e4m3fn, which saves a lot of VRAM but has a slight quality impact.
high_vram: uses Accelerate to load the weights to the GPU, for slightly faster model loading.
|
""" |
|
|
|
def process(self, supir_model, diffusion_dtype, fp8_unet, model, clip_l, clip_g, vae, high_vram=False): |
|
if high_vram: |
|
device = mm.get_torch_device() |
|
else: |
|
device = mm.unet_offload_device() |
|
print("Loading weights to: ", device) |
|
mm.unload_all_models() |
|
|
|
SUPIR_MODEL_PATH = folder_paths.get_full_path("checkpoints", supir_model) |
|
|
|
config_path = os.path.join(script_directory, "options/SUPIR_v0.yaml") |
|
clip_config_path = os.path.join(script_directory, "configs/clip_vit_config.json") |
|
tokenizer_path = os.path.join(script_directory, "configs/tokenizer") |
|
|
|
custom_config = { |
|
'diffusion_dtype': diffusion_dtype, |
|
'supir_model': supir_model, |
|
'fp8_unet': fp8_unet, |
|
'model': model, |
|
"clip": clip_l, |
|
"clip_g": clip_g, |
|
"vae": vae |
|
} |
|
|
|
if diffusion_dtype == 'auto': |
|
try: |
|
if mm.should_use_fp16(): |
|
print("Diffusion using fp16") |
|
dtype = torch.float16 |
|
elif mm.should_use_bf16(): |
|
print("Diffusion using bf16") |
|
dtype = torch.bfloat16 |
|
else: |
|
print("Diffusion using fp32") |
|
dtype = torch.float32 |
|
except: |
|
raise AttributeError("ComfyUI version too old, can't autodecet properly. Set your dtypes manually.") |
|
else: |
|
print(f"Diffusion using {diffusion_dtype}") |
|
dtype = convert_dtype(diffusion_dtype) |
|
|
|
if not hasattr(self, "model") or self.model is None or self.current_config != custom_config: |
|
self.current_config = custom_config |
|
self.model = None |
|
|
|
mm.soft_empty_cache() |
|
|
|
config = OmegaConf.load(config_path) |
|
if mm.XFORMERS_IS_AVAILABLE: |
|
print("Using XFORMERS") |
|
config.model.params.control_stage_config.params.spatial_transformer_attn_type = "softmax-xformers" |
|
config.model.params.network_config.params.spatial_transformer_attn_type = "softmax-xformers" |
|
config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla-xformers" |
|
|
|
config.model.target = ".SUPIR.models.SUPIR_model_v2.SUPIRModel" |
|
pbar = comfy.utils.ProgressBar(5) |
|
|
|
|
|
self.model = instantiate_from_config(config.model).cpu() |
|
self.model.model.dtype = dtype |
|
pbar.update(1) |
|
try: |
|
print(f"Attempting to load SDXL model from node inputs") |
|
mm.load_model_gpu(model) |
|
sdxl_state_dict = model.model.state_dict_for_saving(None, vae.get_sd(), None) |
|
if is_accelerate_available: |
|
for key in sdxl_state_dict: |
|
set_module_tensor_to_device(self.model, key, device=device, dtype=dtype, value=sdxl_state_dict[key]) |
|
else: |
|
self.model.load_state_dict(sdxl_state_dict, strict=False) |
|
if fp8_unet: |
|
self.model.model.to(torch.float8_e4m3fn) |
|
else: |
|
self.model.model.to(dtype) |
|
del sdxl_state_dict |
|
pbar.update(1) |
|
except: |
|
raise Exception("Failed to load SDXL model") |
|
gc.collect() |
|
mm.soft_empty_cache() |
|
|
|
try: |
|
print("Loading first clip model from SDXL checkpoint") |
|
clip_l_sd = None |
|
clip_l_model = clip_l.load_model() |
|
mm.load_model_gpu(clip_l_model) |
|
clip_l_sd = clip_l.get_sd() |
|
clip_l_sd = model.model.model_config.process_clip_state_dict_for_saving(clip_l_sd) |
|
|
|
replace_prefix = {} |
|
replace_prefix["conditioner.embedders.0.transformer."] = "" |
|
|
|
clip_l_sd = comfy.utils.state_dict_prefix_replace(clip_l_sd, replace_prefix, filter_keys=True) |
|
clip_text_config = CLIPTextConfig.from_pretrained(clip_config_path) |
|
self.model.conditioner.embedders[0].tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) |
|
with (init_empty_weights() if is_accelerate_available else nullcontext()): |
|
self.model.conditioner.embedders[0].transformer = CLIPTextModel(clip_text_config) |
|
if is_accelerate_available: |
|
for key in clip_l_sd: |
|
set_module_tensor_to_device(self.model.conditioner.embedders[0].transformer, key, device=device, dtype=dtype, value=clip_l_sd[key]) |
|
else: |
|
self.model.conditioner.embedders[0].transformer.load_state_dict(clip_l_sd, strict=False) |
|
self.model.conditioner.embedders[0].eval() |
|
for param in self.model.conditioner.embedders[0].parameters(): |
|
param.requires_grad = False |
|
self.model.conditioner.embedders[0].to(dtype) |
|
del clip_l_sd |
|
pbar.update(1) |
|
except: |
|
raise Exception("Failed to load first clip model from SDXL checkpoint") |
|
gc.collect() |
|
mm.soft_empty_cache() |
|
|
|
try: |
|
print("Loading second clip model from SDXL checkpoint") |
|
clip_g_sd = None |
|
clip_g_model = clip_g.load_model() |
|
mm.load_model_gpu(clip_g_model) |
|
clip_g_sd = clip_g.get_sd() |
|
clip_g_sd = model.model.model_config.process_clip_state_dict_for_saving(clip_g_sd) |
|
|
|
replace_prefix2 = {} |
|
replace_prefix2["conditioner.embedders.1.model."] = "" |
|
clip_g_sd = comfy.utils.state_dict_prefix_replace(clip_g_sd, replace_prefix2, filter_keys=True) |
|
clip_g = build_text_model_from_openai_state_dict(clip_g_sd, device, cast_dtype=dtype) |
|
self.model.conditioner.embedders[1].model = clip_g |
|
self.model.conditioner.embedders[1].model.to(dtype) |
|
del clip_g_sd |
|
pbar.update(1) |
|
except: |
|
raise Exception("Failed to load second clip model from SDXL checkpoint") |
|
|
|
try: |
|
print(f'Attempting to load SUPIR model: [{SUPIR_MODEL_PATH}]') |
|
supir_state_dict = load_state_dict(SUPIR_MODEL_PATH) |
|
if "Q" not in supir_model or not is_accelerate_available: |
|
for key in supir_state_dict: |
|
set_module_tensor_to_device(self.model, key, device=device, dtype=dtype, value=supir_state_dict[key]) |
|
else: |
|
self.model.load_state_dict(supir_state_dict, strict=False) |
|
if fp8_unet: |
|
self.model.model.to(torch.float8_e4m3fn) |
|
else: |
|
self.model.model.to(dtype) |
|
del supir_state_dict |
|
pbar.update(1) |
|
except: |
|
raise Exception("Failed to load SUPIR model") |
|
mm.soft_empty_cache() |
|
|
|
return (self.model, self.model.first_stage_model,) |
|
|
|
class SUPIR_tiles: |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"image": ("IMAGE",), |
|
"tile_size": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}), |
|
"tile_stride": ("INT", {"default": 256, "min": 64, "max": 8192, "step": 64}), |
|
|
|
} |
|
} |
|
|
|
RETURN_TYPES = ("IMAGE", "INT", "INT",) |
|
RETURN_NAMES = ("image_tiles", "tile_size", "tile_stride",) |
|
FUNCTION = "tile" |
|
CATEGORY = "SUPIR" |
|
DESCRIPTION = """ |
|
Tiles the image with the same function the tiled samplers use.
Useful for previewing the tiling and for generating captions per tile (WIP feature).
|
""" |
|
|
|
def tile(self, image, tile_size, tile_stride): |
|
|
|
def _sliding_windows(h: int, w: int, tile_size: int, tile_stride: int): |
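            # Tile start positions step by tile_stride, with an extra window flush
            # against the edge when the stride doesn't divide evenly.
            # E.g. h=1000, tile_size=512, tile_stride=256 -> rows start at [0, 256, 488].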
|
hi_list = list(range(0, h - tile_size + 1, tile_stride)) |
|
if (h - tile_size) % tile_stride != 0: |
|
hi_list.append(h - tile_size) |
|
|
|
wi_list = list(range(0, w - tile_size + 1, tile_stride)) |
|
if (w - tile_size) % tile_stride != 0: |
|
wi_list.append(w - tile_size) |
|
|
|
coords = [] |
|
for hi in hi_list: |
|
for wi in wi_list: |
|
coords.append((hi, hi + tile_size, wi, wi + tile_size)) |
|
return coords |
|
|
|
image = image.permute(0, 3, 1, 2) |
|
_, _, h, w = image.shape |
|
|
|
tiles_iterator = _sliding_windows(h, w, tile_size, tile_stride) |
|
|
|
tiles = [] |
|
for hi, hi_end, wi, wi_end in tiles_iterator: |
|
tile = image[:, :, hi:hi_end, wi:wi_end] |
|
|
|
tiles.append(tile) |
|
out = torch.cat(tiles, dim=0).to(torch.float32).permute(0, 2, 3, 1) |
|
print(out.shape) |
|
print("len(tiles): ", len(tiles)) |
|
|
|
return (out, tile_size, tile_stride,) |
|
|