Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,334 Bytes
9892334 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import torch
from torch import nn
from diffusers import UniPCMultistepScheduler, AutoencoderKL
from safetensors.torch import load_file
from pipeline.pipeline_controlnext import StableDiffusionXLControlNeXtPipeline
from models.unet import UNet2DConditionModel, UNET_CONFIG
from models.controlnet import ControlNetModel
from . import utils
def _read_state_dict(checkpoint_path):
    """Load a raw state dict from ``checkpoint_path`` onto the CPU.

    Files ending in ``.safetensors`` are read with ``load_file``; anything
    else is handed to ``torch.load``.
    """
    if checkpoint_path.endswith(".safetensors"):
        return load_file(checkpoint_path)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point this at checkpoints from a trusted source.
    # map_location="cpu" mirrors load_file's default so a checkpoint saved on
    # a GPU still loads on a host without one.
    return torch.load(checkpoint_path, map_location="cpu")


def _build_base_unet(checkpoint_path):
    """Create a UNet from ``UNET_CONFIG`` and fill it from ``checkpoint_path``."""
    unet_sd = _read_state_dict(checkpoint_path)
    unet_sd = utils.extract_unet_state_dict(unet_sd)
    unet_sd = utils.convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
    unet = UNet2DConditionModel.from_config(UNET_CONFIG)
    unet.load_state_dict(unet_sd, strict=True)
    return unet


def get_pipeline(
    pretrained_model_name_or_path,
    unet_model_name_or_path,
    controlnet_model_name_or_path,
    vae_model_name_or_path=None,
    lora_path=None,
    load_weight_increasement=False,
    enable_xformers_memory_efficient_attention=False,
    revision=None,
    variant=None,
    hf_cache_dir=None,
    use_safetensors=True,
    device=None,
):
    """Assemble a ``StableDiffusionXLControlNeXtPipeline`` from its parts.

    Args:
        pretrained_model_name_or_path: Either a local single-file checkpoint
            (detected with ``os.path.isfile``) or a Hub repo id. Used for both
            the base UNet weights and the rest of the pipeline.
        unet_model_name_or_path: Optional safetensors file with ControlNeXt
            UNet weights that are loaded on top of the base UNet.
        controlnet_model_name_or_path: Optional safetensors file with the
            ControlNeXt controlnet weights. When ``None`` no controlnet is
            passed to the pipeline.
        vae_model_name_or_path: Optional VAE to load with
            ``AutoencoderKL.from_pretrained`` instead of the pipeline default.
        lora_path: Optional LoRA weights passed to ``load_lora_weights``.
        load_weight_increasement: When true, the ControlNeXt UNet tensors are
            added to the matching base UNet tensors before being loaded.
        enable_xformers_memory_efficient_attention: When true, calls the
            pipeline method of the same name.
        revision: Hub revision, used for both the UNet download and
            ``from_pretrained``.
        variant: Hub weight variant; ``"fp16"`` selects the ``.fp16`` UNet file.
        hf_cache_dir: Cache directory forwarded to every Hub download.
        use_safetensors: Whether to fetch safetensors weights from the Hub.
        device: Target device for the controlnet, VAE and final pipeline.

    Returns:
        The pipeline, moved to ``device`` in float16 with a UniPC scheduler.

    Raises:
        AssertionError: If the ControlNeXt UNet state dict contains keys that
            the base UNet does not have.
    """
    pipeline_init_kwargs = {}
    is_local_checkpoint = os.path.isfile(pretrained_model_name_or_path)

    # --- controlnet -------------------------------------------------------
    if controlnet_model_name_or_path is not None:
        print(f"loading controlnet from {controlnet_model_name_or_path}")
        controlnet = ControlNetModel()
        utils.load_safetensors(controlnet, controlnet_model_name_or_path)
        controlnet.to(device, dtype=torch.float32)
        pipeline_init_kwargs["controlnet"] = controlnet
        utils.log_model_info(controlnet, "controlnext")
    else:
        print("no controlnet")

    # --- base unet --------------------------------------------------------
    print(f"loading unet from {pretrained_model_name_or_path}")
    if is_local_checkpoint:
        # load unet from local checkpoint
        unet_checkpoint = pretrained_model_name_or_path
    else:
        from huggingface_hub import hf_hub_download

        filename = "diffusion_pytorch_model"
        if variant == "fp16":
            filename += ".fp16"
        # NOTE(review): diffusers-format repos usually ship ".bin" for
        # non-safetensors weights; confirm that ".pt" is what is intended.
        filename += ".safetensors" if use_safetensors else ".pt"
        unet_checkpoint = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="unet" + '/' + filename,
            # Keep the UNet on the same revision as the rest of the pipeline.
            revision=revision,
            cache_dir=hf_cache_dir,
        )
    # The downloaded file (not the repo id) is what must be read here.
    unet = _build_base_unet(unet_checkpoint)
    unet = unet.to(dtype=torch.float16)
    utils.log_model_info(unet, "unet")

    # --- controlnext unet weights ----------------------------------------
    if unet_model_name_or_path is not None:
        print(f"loading controlnext unet from {unet_model_name_or_path}")
        controlnext_unet_sd = load_file(unet_model_name_or_path)
        controlnext_unet_sd = utils.convert_to_controlnext_unet_state_dict(controlnext_unet_sd)
        unet_sd = unet.state_dict()
        assert all(
            k in unet_sd for k in controlnext_unet_sd), \
            f"controlnext unet state dict is not compatible with unet state dict, missing keys: {set(controlnext_unet_sd.keys()) - set(unet_sd.keys())}, extra keys: {set(unet_sd.keys()) - set(controlnext_unet_sd.keys())}"
        if load_weight_increasement:
            print("loading weight increasement")
            # The file holds deltas: add each one to the base weight.
            controlnext_unet_sd = {
                key: delta + unet_sd[key]
                for key, delta in controlnext_unet_sd.items()
            }
        # strict=False because only a subset of the UNet keys is provided.
        unet.load_state_dict(controlnext_unet_sd, strict=False)
        utils.log_model_info(controlnext_unet_sd, "controlnext unet")

    pipeline_init_kwargs["unet"] = unet

    # --- vae --------------------------------------------------------------
    if vae_model_name_or_path is not None:
        print(f"loading vae from {vae_model_name_or_path}")
        vae = AutoencoderKL.from_pretrained(
            vae_model_name_or_path,
            cache_dir=hf_cache_dir,
            torch_dtype=torch.float16,
        ).to(device)
        pipeline_init_kwargs["vae"] = vae

    # --- pipeline ---------------------------------------------------------
    print(f"loading pipeline from {pretrained_model_name_or_path}")
    if is_local_checkpoint:
        pipeline = StableDiffusionXLControlNeXtPipeline.from_single_file(
            pretrained_model_name_or_path,
            use_safetensors=pretrained_model_name_or_path.endswith(".safetensors"),
            local_files_only=True,
            cache_dir=hf_cache_dir,
            **pipeline_init_kwargs,
        )
    else:
        pipeline = StableDiffusionXLControlNeXtPipeline.from_pretrained(
            pretrained_model_name_or_path,
            revision=revision,
            variant=variant,
            use_safetensors=use_safetensors,
            cache_dir=hf_cache_dir,
            **pipeline_init_kwargs,
        )

    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.set_progress_bar_config()
    pipeline = pipeline.to(device, dtype=torch.float16)

    if lora_path is not None:
        pipeline.load_lora_weights(lora_path)
    if enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    return pipeline
def get_scheduler(
    scheduler_name,
    scheduler_config,
):
    """Instantiate the diffusers scheduler selected by its display name.

    Args:
        scheduler_name: One of 'Euler A', 'UniPC', 'Euler', 'DDIM' or 'DDPM'.
        scheduler_config: Config handed unchanged to the scheduler class's
            ``from_config``.

    Returns:
        The object produced by ``from_config`` of the selected class.

    Raises:
        ValueError: If ``scheduler_name`` is not one of the names above.
    """
    # Display name -> class name inside ``diffusers.schedulers``.
    known_schedulers = (
        ('Euler A', 'EulerAncestralDiscreteScheduler'),
        ('UniPC', 'UniPCMultistepScheduler'),
        ('Euler', 'EulerDiscreteScheduler'),
        ('DDIM', 'DDIMScheduler'),
        ('DDPM', 'DDPMScheduler'),
    )
    class_name = next(
        (cls for label, cls in known_schedulers if scheduler_name == label),
        None,
    )
    if class_name is None:
        raise ValueError(f"Unknown scheduler: {scheduler_name}")

    # Imported lazily, and only once the name is known to be valid, so an
    # unknown name fails without touching diffusers.
    import diffusers.schedulers as scheduler_module

    scheduler_cls = getattr(scheduler_module, class_name)
    return scheduler_cls.from_config(scheduler_config)
|