import os

import torch
from diffusers import UniPCMultistepScheduler, AutoencoderKL
from safetensors.torch import load_file

from pipeline.pipeline_controlnext import StableDiffusionXLControlNeXtPipeline
from models.unet import UNet2DConditionModel, UNET_CONFIG
from models.controlnet import ControlNetModel
from . import utils


def get_pipeline(
    pretrained_model_name_or_path,
    unet_model_name_or_path,
    controlnet_model_name_or_path,
    vae_model_name_or_path=None,
    lora_path=None,
    load_weight_increasement=False,
    enable_xformers_memory_efficient_attention=False,
    revision=None,
    variant=None,
    hf_cache_dir=None,
    use_safetensors=True,
    device=None,
):
    pipeline_init_kwargs = {}

    if controlnet_model_name_or_path is not None:
        print(f"loading controlnet from {controlnet_model_name_or_path}")
        controlnet = ControlNetModel()
        utils.load_safetensors(controlnet, controlnet_model_name_or_path)
        controlnet.to(device, dtype=torch.float32)
        pipeline_init_kwargs["controlnet"] = controlnet
        utils.log_model_info(controlnet, "controlnext")
    else:
        print("no controlnet")

    print(f"loading unet from {pretrained_model_name_or_path}")
    if os.path.isfile(pretrained_model_name_or_path):
        # load the unet state dict from a local single-file checkpoint
        unet_file = pretrained_model_name_or_path
    else:
        # download the unet weights from the Hugging Face Hub
        from huggingface_hub import hf_hub_download

        filename = "diffusion_pytorch_model"
        if variant == "fp16":
            filename += ".fp16"
        filename += ".safetensors" if use_safetensors else ".pt"
        unet_file = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="unet/" + filename,
            cache_dir=hf_cache_dir,
        )
    unet_sd = load_file(unet_file) if unet_file.endswith(".safetensors") else torch.load(unet_file)
    unet_sd = utils.extract_unet_state_dict(unet_sd)
    unet_sd = utils.convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
    unet = UNet2DConditionModel.from_config(UNET_CONFIG)
    unet.load_state_dict(unet_sd, strict=True)
    unet = unet.to(dtype=torch.float16)
    utils.log_model_info(unet, "unet")

    if unet_model_name_or_path is not None:
        print(f"loading controlnext unet from {unet_model_name_or_path}")
        controlnext_unet_sd = load_file(unet_model_name_or_path)
        controlnext_unet_sd = utils.convert_to_controlnext_unet_state_dict(controlnext_unet_sd)
        unet_sd = unet.state_dict()
        assert all(k in unet_sd for k in controlnext_unet_sd), (
            "controlnext unet state dict is not compatible with the unet state dict, "
            f"unexpected keys: {set(controlnext_unet_sd.keys()) - set(unet_sd.keys())}"
        )
        if load_weight_increasement:
            # the checkpoint stores weight deltas, so add them onto the base unet weights
            print("loading weight increasement")
            for k in controlnext_unet_sd.keys():
                controlnext_unet_sd[k] = controlnext_unet_sd[k] + unet_sd[k]
        unet.load_state_dict(controlnext_unet_sd, strict=False)
        utils.log_model_info(controlnext_unet_sd, "controlnext unet")
    pipeline_init_kwargs["unet"] = unet

    if vae_model_name_or_path is not None:
        print(f"loading vae from {vae_model_name_or_path}")
        vae = AutoencoderKL.from_pretrained(
            vae_model_name_or_path, cache_dir=hf_cache_dir, torch_dtype=torch.float16
        ).to(device)
        pipeline_init_kwargs["vae"] = vae
print(f"loading pipeline from {pretrained_model_name_or_path}") if os.path.isfile(pretrained_model_name_or_path): pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_single_file( pretrained_model_name_or_path, use_safetensors=pretrained_model_name_or_path.endswith(".safetensors"), local_files_only=True, cache_dir=hf_cache_dir, **pipeline_init_kwargs, ) else: pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_pretrained( pretrained_model_name_or_path, revision=revision, variant=variant, use_safetensors=use_safetensors, cache_dir=hf_cache_dir, **pipeline_init_kwargs, ) pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) pipeline.set_progress_bar_config() pipeline = pipeline.to(device, dtype=torch.float16) if lora_path is not None: pipeline.load_lora_weights(lora_path) if enable_xformers_memory_efficient_attention: pipeline.enable_xformers_memory_efficient_attention() return pipeline def get_scheduler( scheduler_name, scheduler_config, ): if scheduler_name == 'Euler A': from diffusers.schedulers import EulerAncestralDiscreteScheduler scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config) elif scheduler_name == 'UniPC': from diffusers.schedulers import UniPCMultistepScheduler scheduler = UniPCMultistepScheduler.from_config(scheduler_config) elif scheduler_name == 'Euler': from diffusers.schedulers import EulerDiscreteScheduler scheduler = EulerDiscreteScheduler.from_config(scheduler_config) elif scheduler_name == 'DDIM': from diffusers.schedulers import DDIMScheduler scheduler = DDIMScheduler.from_config(scheduler_config) elif scheduler_name == 'DDPM': from diffusers.schedulers import DDPMScheduler scheduler = DDPMScheduler.from_config(scheduler_config) else: raise ValueError(f"Unknown scheduler: {scheduler_name}") return scheduler