# Adapted from https://github.com/Limitex/ComfyUI-Diffusers/blob/main/utils.py
import io
import os

import torch
import requests
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from torchvision.transforms import ToTensor
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    assign_to_checkpoint,
    conv_attn_to_linear,
    create_vae_diffusers_config,
    renew_vae_attention_paths,
    renew_vae_resnet_paths,
)
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DDPMScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    UniPCMultistepScheduler,
    LCMScheduler,
    StableDiffusionXLPipeline,
)

from .mvadapter.pipelines.pipeline_mvadapter_t2mv_sdxl import MVAdapterT2MVSDXLPipeline
from .mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
from .mvadapter.pipelines.pipeline_mvadapter_i2mv_sd import MVAdapterI2MVSDPipeline
from .mvadapter.pipelines.pipeline_mvadapter_t2mv_sd import MVAdapterT2MVSDPipeline
from .mvadapter.utils import (
    get_orthogonal_camera,
    get_plucker_embeds_from_cameras_ortho,
    make_image_grid,
)

NODE_CACHE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache")

PIPELINES = {
    "StableDiffusionXLPipeline": StableDiffusionXLPipeline,
    "MVAdapterT2MVSDXLPipeline": MVAdapterT2MVSDXLPipeline,
    "MVAdapterI2MVSDXLPipeline": MVAdapterI2MVSDXLPipeline,
    "MVAdapterI2MVSDPipeline": MVAdapterI2MVSDPipeline,
    "MVAdapterT2MVSDPipeline": MVAdapterT2MVSDPipeline,
}

SCHEDULERS = {
    "DDIM": DDIMScheduler,
    "DDPM": DDPMScheduler,
    "DEISMultistep": DEISMultistepScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "DPMSolverSinglestep": DPMSolverSinglestepScheduler,
    "EulerAncestralDiscrete": EulerAncestralDiscreteScheduler,
    "EulerDiscrete": EulerDiscreteScheduler,
    "HeunDiscrete": HeunDiscreteScheduler,
    "KDPM2AncestralDiscrete": KDPM2AncestralDiscreteScheduler,
    "KDPM2Discrete": KDPM2DiscreteScheduler,
    "UniPCMultistep": UniPCMultistepScheduler,
    "LCM": LCMScheduler,
}

MVADAPTERS = [
    "mvadapter_t2mv_sdxl.safetensors",
    "mvadapter_i2mv_sdxl.safetensors",
    "mvadapter_i2mv_sdxl_beta.safetensors",
    "mvadapter_t2mv_sd21.safetensors",
    "mvadapter_i2mv_sd21.safetensors",
]
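
# Note (illustrative): PIPELINES and SCHEDULERS map the option strings exposed
# in the node UI to their diffusers classes, and MVADAPTERS lists the adapter
# weight files this node pack expects. A typical lookup by the node code might
# look like the following (names and paths are placeholders, not part of this
# module):
#     pipe = PIPELINES["MVAdapterT2MVSDXLPipeline"].from_pretrained(base_model_path)
#     pipe.scheduler = SCHEDULERS["DDIM"].from_config(pipe.scheduler.config)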
"decoder.norm_out.bias" ] new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] # Retrieves the keys for the encoder down blocks only num_down_blocks = len( { ".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer } ) down_blocks = { layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) } # Retrieves the keys for the decoder up blocks only num_up_blocks = len( { ".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer } ) up_blocks = { layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) } for i in range(num_down_blocks): resnets = [ key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key ] if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = ( vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") ) new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = ( vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") ) paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} assign_to_checkpoint( paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config, ) mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] num_mid_res_blocks = 2 for i in range(1, num_mid_res_blocks + 1): resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} assign_to_checkpoint( paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config, ) mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} assign_to_checkpoint( paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config, ) conv_attn_to_linear(new_checkpoint) for i in range(num_up_blocks): block_id = num_up_blocks - 1 - i resnets = [ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key ] if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = ( vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] ) new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = ( vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] ) paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} assign_to_checkpoint( paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config, ) mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] num_mid_res_blocks = 2 for i in range(1, num_mid_res_blocks + 1): resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} assign_to_checkpoint( paths, new_checkpoint, vae_state_dict, 

# Reference from : https://github.com/huggingface/diffusers/blob/main/scripts/convert_vae_pt_to_diffusers.py
def vae_pt_to_vae_diffuser(checkpoint_path: str, force_upcast: bool = True):
    # Load the SD v1 inference config from the local cache, falling back to
    # downloading it from the CompVis repository if it is not cached yet.
    try:
        config_path = os.path.join(
            NODE_CACHE_PATH, "stable-diffusion-v1-inference.yaml"
        )
        original_config = OmegaConf.load(config_path)
    except FileNotFoundError as e:
        print(f"Warning: {e}")
        r = requests.get(
            "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
        )
        io_obj = io.BytesIO(r.content)
        original_config = OmegaConf.load(io_obj)

    image_size = 512
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if checkpoint_path.endswith("safetensors"):
        from safetensors import safe_open

        checkpoint = {}
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"]

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
    vae_config.update({"force_upcast": force_upcast})
    converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)
    return vae


def convert_images_to_tensors(images: list[Image.Image]):
    # PIL images -> single float tensor of shape [N, H, W, C] in [0, 1]
    # (the layout ComfyUI uses for IMAGE data).
    return torch.stack([np.transpose(ToTensor()(image), (1, 2, 0)) for image in images])


def convert_tensors_to_images(images: torch.Tensor):
    # [N, H, W, C] float tensor in [0, 1] -> list of PIL images.
    return [
        Image.fromarray(np.clip(255.0 * image.cpu().numpy(), 0, 255).astype(np.uint8))
        for image in images
    ]


def resize_images(images: list[Image.Image], size: tuple[int, int]):
    return [image.resize(size) for image in images]


def prepare_camera_embed(num_views, size, device, azimuth_degrees=None):
    # Build one orthographic camera per view and encode the cameras as
    # Plucker-ray images normalized to [0, 1], returned as control images.
    cameras = get_orthogonal_camera(
        elevation_deg=[0] * num_views,
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=[x - 90 for x in azimuth_degrees],
        device=device,
    )

    plucker_embeds = get_plucker_embeds_from_cameras_ortho(
        cameras.c2w, [1.1] * num_views, size
    )
    control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)

    return control_images


def preprocess_image(image: Image.Image, height, width):
    # Crop an RGBA image to its alpha bounding box, rescale it to ~90% of the
    # target size, and composite it centered on a neutral gray background.
    image = np.array(image)
    alpha = image[..., 3] > 0
    H, W = alpha.shape
    # get the bounding box of alpha
    y, x = np.where(alpha)
    y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
    x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
    image_center = image[y0:y1, x0:x1]
    # resize the longer side to H * 0.9
    H, W, _ = image_center.shape
    if H > W:
        W = int(W * (height * 0.9) / H)
        H = int(height * 0.9)
    else:
        H = int(H * (width * 0.9) / W)
        W = int(width * 0.9)
    image_center = np.array(Image.fromarray(image_center).resize((W, H)))
    # pad to H, W
    start_h = (height - H) // 2
    start_w = (width - W) // 2
    image = np.zeros((height, width, 4), dtype=np.uint8)
    image[start_h : start_h + H, start_w : start_w + W] = image_center
    image = image.astype(np.float32) / 255.0
    image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
    image = (image * 255).clip(0, 255).astype(np.uint8)
    image = Image.fromarray(image)
    return image
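

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not executed by ComfyUI; the node graph wires
# these helpers together instead of calling them directly). File paths, image
# sizes, and the azimuth list below are placeholders/assumptions.
if __name__ == "__main__":
    # Convert a standalone VAE checkpoint to a diffusers AutoencoderKL.
    vae = vae_pt_to_vae_diffuser("/path/to/vae.safetensors")  # placeholder path

    # Prepare an RGBA reference image and orthographic camera embeddings for
    # an example 6-view multiview run.
    ref = Image.open("/path/to/reference.png").convert("RGBA")  # placeholder path
    ref = preprocess_image(ref, height=768, width=768)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    control_images = prepare_camera_embed(
        num_views=6,
        size=768,
        device=device,
        azimuth_degrees=[0, 45, 90, 180, 270, 315],  # example azimuths only
    )
    print(control_images.shape)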