from typing import List, NamedTuple, Any

import numpy as np
import cv2
import torch
from safetensors.torch import load_file
from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

import library.model_util as model_util


class ControlNetInfo(NamedTuple):
    unet: Any
    net: Any
    prep: Any
    weight: float
    ratio: float
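
# Field note (descriptive, based on how these fields are used below):
#   unet   - the ControlNet's own copy of the U-Net (created in load_control_net)
#   net    - the extra ControlNet modules (zero convs, middle conv, hint encoder)
#   prep   - optional preprocessor for the hint image (see load_preprocess), or None
#   weight - scale applied to the ControlNet residuals in call_unet_and_control_net
#   ratio  - the ControlNet is skipped once current_ratio exceeds this value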


class ControlNet(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        # make control model
        self.control_model = torch.nn.Module()

        dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
        zero_convs = torch.nn.ModuleList()
        for i, dim in enumerate(dims):
            sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
            zero_convs.append(sub_list)
        self.control_model.add_module("zero_convs", zero_convs)

        middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
        self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))

        dims = [16, 16, 32, 32, 96, 96, 256, 320]
        strides = [1, 1, 2, 1, 2, 1, 2, 1]
        prev_dim = 3
        input_hint_block = torch.nn.Sequential()
        for i, (dim, stride) in enumerate(zip(dims, strides)):
            input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
            if i < len(dims) - 1:
                input_hint_block.append(torch.nn.SiLU())
            prev_dim = dim
        self.control_model.add_module("input_hint_block", input_hint_block)


def load_control_net(v2, unet, model):
    device = unet.device

    # Convert keys from the control SD checkpoint, extract only the parts that
    # correspond to the U-Net, and load them into the Diffusers U-Net.
    # Load the state dict.
    print(f"ControlNet: loading control SD model : {model}")

    if model_util.is_safetensors(model):
        ctrl_sd_sd = load_file(model)
    else:
        ctrl_sd_sd = torch.load(model, map_location='cpu')
        ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)

    # Make the weights loadable into the U-Net. The ControlNet checkpoint is an
    # SD-format state dict, so load it as such.
    is_difference = "difference" in ctrl_sd_sd
    print("ControlNet: loading difference:", is_difference)

    # Some keys do not exist in the ControlNet checkpoint, so first build the full set
    # of SD-format keys from the current U-Net. These also serve as the base weights
    # for Transfer Control.
    ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())

    # Copy so the original U-Net is not affected, and add the missing prefix.
    for key in list(ctrl_unet_sd_sd.keys()):
        ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()

    zero_conv_sd = {}
    for key in list(ctrl_sd_sd.keys()):
        if key.startswith("control_"):
            unet_key = "model.diffusion_" + key[len("control_"):]
            if unet_key not in ctrl_unet_sd_sd:  # zero conv
                zero_conv_sd[key] = ctrl_sd_sd[key]
                continue
            if is_difference:  # Transfer Control
                ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
            else:
                ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)

    unet_config = model_util.create_unet_diffusers_config(v2)
    ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config)  # state dict of the Diffusers-format ControlNet

    # Create the ControlNet's U-Net.
    ctrl_unet = UNet2DConditionModel(**unet_config)
    info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
    print("ControlNet: loading Control U-Net:", info)

    # Create the non-U-Net parts of the ControlNet.
    # TODO support middle only
    ctrl_net = ControlNet()
    info = ctrl_net.load_state_dict(zero_conv_sd)
    print("ControlNet: loading ControlNet:", info)

    ctrl_unet.to(unet.device, dtype=unet.dtype)
    ctrl_net.to(unet.device, dtype=unet.dtype)
    return ctrl_unet, ctrl_net
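
# Typical usage (a sketch; `unet` is assumed to be the Diffusers U-Net already loaded
# for inference, and the .safetensors path is a hypothetical example):
#   ctrl_unet, ctrl_net = load_control_net(v2=False, unet=unet, model="control_sd15_canny.safetensors")
#   prep = load_preprocess("canny_63_191")
#   cnet = ControlNetInfo(unet=ctrl_unet, net=ctrl_net, prep=prep, weight=1.0, ratio=1.0)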


def load_preprocess(prep_type: str):
    if prep_type is None or prep_type.lower() == "none":
        return None

    if prep_type.startswith("canny"):
        args = prep_type.split("_")
        th1 = int(args[1]) if len(args) >= 2 else 63
        th2 = int(args[2]) if len(args) >= 3 else 191

        def canny(img):
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            return cv2.Canny(img, th1, th2)

        return canny

    print("Unsupported prep type:", prep_type)
    return None
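
# Accepted prep_type strings (derived from the parsing above):
#   load_preprocess("canny")         -> Canny edges with the default thresholds 63 / 191
#   load_preprocess("canny_100_200") -> Canny edges with thresholds 100 / 200
#   load_preprocess("none") / None   -> no preprocessing; the hint image is used as-is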


def preprocess_ctrl_net_hint_image(image):
    image = np.array(image).astype(np.float32) / 255.0
    # The ControlNet examples use cv2, but here the image is loaded via Gradio,
    # so it is actually already RGB.
    # image = image[:, :, ::-1].copy()  # rgb to bgr
    if image.ndim == 2:
        # single-channel hints (e.g. from the Canny preprocessor above) are replicated
        # to 3 channels, since the hint encoder expects a 3-channel input
        image = np.stack([image] * 3, axis=-1)
    image = image[None].transpose(0, 3, 1, 2)  # nchw
    image = torch.from_numpy(image)
    return image  # 0 to 1
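
# Shape example: a 512x512 RGB hint (PIL image or uint8 HWC array) becomes a
# (1, 3, 512, 512) float32 tensor with values in [0, 1].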


def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
    guided_hints = []
    for i, cnet_info in enumerate(control_nets):
        # hints must be ordered as: cnet1 for image 1, cnet2 for image 1, cnet3 for image 1,
        # cnet1 for image 2, cnet2 for image 2, ...
        b_hints = []
        if len(hints) == 1:  # use the same image as the hint for everything
            hint = hints[0]
            if cnet_info.prep is not None:
                hint = cnet_info.prep(hint)
            hint = preprocess_ctrl_net_hint_image(hint)
            b_hints = [hint for _ in range(b_size)]
        else:
            for bi in range(b_size):
                hint = hints[(bi * len(control_nets) + i) % len(hints)]
                if cnet_info.prep is not None:
                    hint = cnet_info.prep(hint)
                hint = preprocess_ctrl_net_hint_image(hint)
                b_hints.append(hint)
        b_hints = torch.cat(b_hints, dim=0)
        b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)

        guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
        guided_hints.append(guided_hint)
    return guided_hints
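
# Ordering example (restating the comment above): with 2 ControlNets and b_size=2,
#   hints = [img1_hint_for_cnet1, img1_hint_for_cnet2, img2_hint_for_cnet1, img2_hint_for_cnet2]
# Each entry of the returned list has shape (b_size, 320, H/8, W/8); it is repeated
# num_latent_input times in call_unet_and_control_net to match the expanded latent batch
# (e.g. for classifier-free guidance).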


def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states):
    # ControlNet
    # With multiple ControlNets, apply them alternately per step instead of merging their outputs.
    cnet_cnt = len(control_nets)
    cnet_idx = step % cnet_cnt
    cnet_info = control_nets[cnet_idx]

    # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
    if cnet_info.ratio < current_ratio:
        return original_unet(sample, timestep, encoder_hidden_states)

    guided_hint = guided_hints[cnet_idx]
    guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
    outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
    outs = [o * cnet_info.weight for o in outs]

    # U-Net
    return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
""" | |
# これはmergeのバージョン | |
# ControlNet | |
cnet_outs_list = [] | |
for i, cnet_info in enumerate(control_nets): | |
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) | |
if cnet_info.ratio < current_ratio: | |
continue | |
guided_hint = guided_hints[i] | |
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states) | |
for i in range(len(outs)): | |
outs[i] *= cnet_info.weight | |
cnet_outs_list.append(outs) | |
count = len(cnet_outs_list) | |
if count == 0: | |
return original_unet(sample, timestep, encoder_hidden_states) | |
# sum of controlnets | |
for i in range(1, count): | |
cnet_outs_list[0] += cnet_outs_list[i] | |
# U-Net | |
return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states) | |
""" | |


def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states):
    # copy from UNet2DConditionModel
    default_overall_up_factor = 2**unet.num_upsamplers

    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        print("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # 0. center input if necessary
    if unet.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = unet.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=unet.dtype)
    emb = unet.time_embedding(t_emb)

    outs = []  # output of ControlNet
    zc_idx = 0

    # 2. pre-process
    sample = unet.conv_in(sample)
    if is_control_net:
        sample += guided_hint
        outs.append(control_net.control_model.zero_convs[zc_idx][0](sample))  # , emb, encoder_hidden_states))
        zc_idx += 1

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in unet.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
        if is_control_net:
            for rs in res_samples:
                outs.append(control_net.control_model.zero_convs[zc_idx][0](rs))  # , emb, encoder_hidden_states))
                zc_idx += 1

        down_block_res_samples += res_samples

    # 4. mid
    sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
    if is_control_net:
        outs.append(control_net.control_model.middle_block_out[0](sample))
        return outs

    if not is_control_net:
        sample += ctrl_outs.pop()

    # 5. up
    for i, upsample_block in enumerate(unet.up_blocks):
        is_final_block = i == len(unet.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets):]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        if not is_control_net and len(ctrl_outs) > 0:
            res_samples = list(res_samples)
            apply_ctrl_outs = ctrl_outs[-len(res_samples):]
            ctrl_outs = ctrl_outs[:-len(res_samples)]
            for j in range(len(res_samples)):
                res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
            res_samples = tuple(res_samples)

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                upsample_size=upsample_size,
            )
        else:
            sample = upsample_block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
            )

    # 6. post-process
    sample = unet.conv_norm_out(sample)
    sample = unet.conv_act(sample)
    sample = unet.conv_out(sample)

    return UNet2DConditionOutput(sample=sample)