diff --git a/diffsynth/__init__.py b/diffsynth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f2351777ad6d3d5734a6bbf99d8c7685158ffd5 --- /dev/null +++ b/diffsynth/__init__.py @@ -0,0 +1,6 @@ +from .data import * +from .models import * +from .prompts import * +from .schedulers import * +from .pipelines import * +from .controlnets import * diff --git a/diffsynth/controlnets/__init__.py b/diffsynth/controlnets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08ba4cbdac2e97f79285a5d5c772dae3b0e25ed --- /dev/null +++ b/diffsynth/controlnets/__init__.py @@ -0,0 +1,2 @@ +from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager +from .processors import Annotator diff --git a/diffsynth/controlnets/controlnet_unit.py b/diffsynth/controlnets/controlnet_unit.py new file mode 100644 index 0000000000000000000000000000000000000000..9754baed90b55dc3518c1cf33ea5863303a49db6 --- /dev/null +++ b/diffsynth/controlnets/controlnet_unit.py @@ -0,0 +1,53 @@ +import torch +import numpy as np +from .processors import Processor_id + + +class ControlNetConfigUnit: + def __init__(self, processor_id: Processor_id, model_path, scale=1.0): + self.processor_id = processor_id + self.model_path = model_path + self.scale = scale + + +class ControlNetUnit: + def __init__(self, processor, model, scale=1.0): + self.processor = processor + self.model = model + self.scale = scale + + +class MultiControlNetManager: + def __init__(self, controlnet_units=[]): + self.processors = [unit.processor for unit in controlnet_units] + self.models = [unit.model for unit in controlnet_units] + self.scales = [unit.scale for unit in controlnet_units] + + def process_image(self, image, processor_id=None): + if processor_id is None: + processed_image = [processor(image) for processor in self.processors] + else: + processed_image = [self.processors[processor_id](image)] + processed_image = torch.concat([ + torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0) + for image_ in processed_image + ], dim=0) + return processed_image + + def __call__( + self, + sample, timestep, encoder_hidden_states, conditionings, + tiled=False, tile_size=64, tile_stride=32 + ): + res_stack = None + for conditioning, model, scale in zip(conditionings, self.models, self.scales): + res_stack_ = model( + sample, timestep, encoder_hidden_states, conditioning, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + ) + res_stack_ = [res * scale for res in res_stack_] + if res_stack is None: + res_stack = res_stack_ + else: + res_stack = [i + j for i, j in zip(res_stack, res_stack_)] + return res_stack diff --git a/diffsynth/controlnets/processors.py b/diffsynth/controlnets/processors.py new file mode 100644 index 0000000000000000000000000000000000000000..a37884215d896ef28723949ae0b9d2a99bba2a45 --- /dev/null +++ b/diffsynth/controlnets/processors.py @@ -0,0 +1,51 @@ +from typing_extensions import Literal, TypeAlias +import warnings +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from controlnet_aux.processor import ( + CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector + ) + + +Processor_id: TypeAlias = Literal[ + "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile" +] + +class Annotator: + def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None): + if processor_id == "canny": + self.processor = 
CannyDetector() + elif processor_id == "depth": + self.processor = MidasDetector.from_pretrained(model_path).to("cuda") + elif processor_id == "softedge": + self.processor = HEDdetector.from_pretrained(model_path).to("cuda") + elif processor_id == "lineart": + self.processor = LineartDetector.from_pretrained(model_path).to("cuda") + elif processor_id == "lineart_anime": + self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda") + elif processor_id == "openpose": + self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda") + elif processor_id == "tile": + self.processor = None + else: + raise ValueError(f"Unsupported processor_id: {processor_id}") + + self.processor_id = processor_id + self.detect_resolution = detect_resolution + + def __call__(self, image): + width, height = image.size + if self.processor_id == "openpose": + kwargs = { + "include_body": True, + "include_hand": True, + "include_face": True + } + else: + kwargs = {} + if self.processor is not None: + detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height) + image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs) + image = image.resize((width, height)) + return image + diff --git a/diffsynth/data/__init__.py b/diffsynth/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..de09a29905d289673e40e53278a0a3181232640d --- /dev/null +++ b/diffsynth/data/__init__.py @@ -0,0 +1 @@ +from .video import VideoData, save_video, save_frames diff --git a/diffsynth/data/video.py b/diffsynth/data/video.py new file mode 100644 index 0000000000000000000000000000000000000000..16e19186ac3b803dd74a1a56b3006b06427bbc51 --- /dev/null +++ b/diffsynth/data/video.py @@ -0,0 +1,148 @@ +import imageio, os +import numpy as np +from PIL import Image +from tqdm import tqdm + + +class LowMemoryVideo: + def __init__(self, file_name): + self.reader = imageio.get_reader(file_name) + + def __len__(self): + return self.reader.count_frames() + + def __getitem__(self, item): + return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB") + + def __del__(self): + self.reader.close() + + +def split_file_name(file_name): + result = [] + number = -1 + for i in file_name: + if ord(i)>=ord("0") and ord(i)<=ord("9"): + if number == -1: + number = 0 + number = number*10 + ord(i) - ord("0") + else: + if number != -1: + result.append(number) + number = -1 + result.append(i) + if number != -1: + result.append(number) + result = tuple(result) + return result + + +def search_for_images(folder): + file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")] + file_list = [(split_file_name(file_name), file_name) for file_name in file_list] + file_list = [i[1] for i in sorted(file_list)] + file_list = [os.path.join(folder, i) for i in file_list] + return file_list + + +class LowMemoryImageFolder: + def __init__(self, folder, file_list=None): + if file_list is None: + self.file_list = search_for_images(folder) + else: + self.file_list = [os.path.join(folder, file_name) for file_name in file_list] + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, item): + return Image.open(self.file_list[item]).convert("RGB") + + def __del__(self): + pass + + +def crop_and_resize(image, height, width): + image = np.array(image) + image_height, image_width, _ = image.shape + if image_height / image_width < height / width: + croped_width = int(image_height / 
height * width) + left = (image_width - croped_width) // 2 + image = image[:, left: left+croped_width] + image = Image.fromarray(image).resize((width, height)) + else: + croped_height = int(image_width / width * height) + left = (image_height - croped_height) // 2 + image = image[left: left+croped_height, :] + image = Image.fromarray(image).resize((width, height)) + return image + + +class VideoData: + def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs): + if video_file is not None: + self.data_type = "video" + self.data = LowMemoryVideo(video_file, **kwargs) + elif image_folder is not None: + self.data_type = "images" + self.data = LowMemoryImageFolder(image_folder, **kwargs) + else: + raise ValueError("Cannot open video or image folder") + self.length = None + self.set_shape(height, width) + + def raw_data(self): + frames = [] + for i in range(self.__len__()): + frames.append(self.__getitem__(i)) + return frames + + def set_length(self, length): + self.length = length + + def set_shape(self, height, width): + self.height = height + self.width = width + + def __len__(self): + if self.length is None: + return len(self.data) + else: + return self.length + + def shape(self): + if self.height is not None and self.width is not None: + return self.height, self.width + else: + height, width, _ = self.__getitem__(0).shape + return height, width + + def __getitem__(self, item): + frame = self.data.__getitem__(item) + width, height = frame.size + if self.height is not None and self.width is not None: + if self.height != height or self.width != width: + frame = crop_and_resize(frame, self.height, self.width) + return frame + + def __del__(self): + pass + + def save_images(self, folder): + os.makedirs(folder, exist_ok=True) + for i in tqdm(range(self.__len__()), desc="Saving images"): + frame = self.__getitem__(i) + frame.save(os.path.join(folder, f"{i}.png")) + + +def save_video(frames, save_path, fps, quality=9): + writer = imageio.get_writer(save_path, fps=fps, quality=quality) + for frame in tqdm(frames, desc="Saving video"): + frame = np.array(frame) + writer.append_data(frame) + writer.close() + +def save_frames(frames, save_path): + os.makedirs(save_path, exist_ok=True) + for i, frame in enumerate(tqdm(frames, desc="Saving images")): + frame.save(os.path.join(save_path, f"{i}.png")) diff --git a/diffsynth/extensions/ESRGAN/__init__.py b/diffsynth/extensions/ESRGAN/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e71cd3f24f1f81c7fc224421aea08aa944707493 --- /dev/null +++ b/diffsynth/extensions/ESRGAN/__init__.py @@ -0,0 +1,118 @@ +import torch +from einops import repeat +from PIL import Image +import numpy as np + + +class ResidualDenseBlock(torch.nn.Module): + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 
= self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(torch.nn.Module): + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + return out * 0.2 + x + + +class RRDBNet(torch.nn.Module): + + def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32): + super(RRDBNet, self).__init__() + self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)]) + self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = repeat(feat, "B C H W -> B C (H 2) (W 2)") + feat = self.lrelu(self.conv_up1(feat)) + feat = repeat(feat, "B C H W -> B C (H 2) (W 2)") + feat = self.lrelu(self.conv_up2(feat)) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out + + +class ESRGAN(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + @staticmethod + def from_pretrained(model_path): + model = RRDBNet() + state_dict = torch.load(model_path, map_location="cpu")["params_ema"] + model.load_state_dict(state_dict) + model.eval() + return ESRGAN(model) + + def process_image(self, image): + image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1) + return image + + def process_images(self, images): + images = [self.process_image(image) for image in images] + images = torch.stack(images) + return images + + def decode_images(self, images): + images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8) + images = [Image.fromarray(image) for image in images] + return images + + @torch.no_grad() + def upscale(self, images, batch_size=4, progress_bar=lambda x:x): + # Preprocess + input_tensor = self.process_images(images) + + # Interpolate + output_tensor = [] + for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)): + batch_id_ = min(batch_id + batch_size, input_tensor.shape[0]) + batch_input_tensor = input_tensor[batch_id: batch_id_] + batch_input_tensor = batch_input_tensor.to( + device=self.model.conv_first.weight.device, + dtype=self.model.conv_first.weight.dtype) + batch_output_tensor = self.model(batch_input_tensor) + output_tensor.append(batch_output_tensor.cpu()) + + # Output + output_tensor = torch.concat(output_tensor, dim=0) + + # To images + output_images = self.decode_images(output_tensor) + return output_images diff --git a/diffsynth/extensions/FastBlend/__init__.py b/diffsynth/extensions/FastBlend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bf812c2085082bfa82658dd249ebca89e9fb465 --- /dev/null +++ b/diffsynth/extensions/FastBlend/__init__.py @@ -0,0 +1,63 @@ +from .runners.fast import TableManager, 
PyramidPatchMatcher +from PIL import Image +import numpy as np +import cupy as cp + + +class FastBlendSmoother: + def __init__(self): + self.batch_size = 8 + self.window_size = 64 + self.ebsynth_config = { + "minimum_patch_size": 5, + "threads_per_block": 8, + "num_iter": 5, + "gpu_id": 0, + "guide_weight": 10.0, + "initialize": "identity", + "tracking_window_size": 0, + } + + @staticmethod + def from_model_manager(model_manager): + # TODO: fetch GPU ID from model_manager + return FastBlendSmoother() + + def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config): + frames_guide = [np.array(frame) for frame in frames_guide] + frames_style = [np.array(frame) for frame in frames_style] + table_manager = TableManager() + patch_match_engine = PyramidPatchMatcher( + image_height=frames_style[0].shape[0], + image_width=frames_style[0].shape[1], + channel=3, + **ebsynth_config + ) + # left part + table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4") + table_l = table_manager.remapping_table_to_blending_table(table_l) + table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4") + # right part + table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4") + table_r = table_manager.remapping_table_to_blending_table(table_r) + table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1] + # merge + frames = [] + for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r): + weight_m = -1 + weight = weight_l + weight_m + weight_r + frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight) + frames.append(frame) + frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames] + return frames + + def __call__(self, rendered_frames, original_frames=None, **kwargs): + frames = self.run( + original_frames, rendered_frames, + self.batch_size, self.window_size, self.ebsynth_config + ) + mempool = cp.get_default_memory_pool() + pinned_mempool = cp.get_default_pinned_memory_pool() + mempool.free_all_blocks() + pinned_mempool.free_all_blocks() + return frames \ No newline at end of file diff --git a/diffsynth/extensions/FastBlend/api.py b/diffsynth/extensions/FastBlend/api.py new file mode 100644 index 0000000000000000000000000000000000000000..2db24330e375ed62065af54613b6ab956c9c64cf --- /dev/null +++ b/diffsynth/extensions/FastBlend/api.py @@ -0,0 +1,397 @@ +from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner +from .data import VideoData, get_video_fps, save_video, search_for_images +import os +import gradio as gr + + +def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder): + frames_guide = VideoData(video_guide, video_guide_folder) + frames_style = VideoData(video_style, video_style_folder) + message = "" + if len(frames_guide) < len(frames_style): + message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n" + frames_style.set_length(len(frames_guide)) + elif len(frames_guide) > len(frames_style): + message += f"The number of frames mismatches. 
Only the first {len(frames_style)} frames of guide video will be used.\n" + frames_guide.set_length(len(frames_style)) + height_guide, width_guide = frames_guide.shape() + height_style, width_style = frames_style.shape() + if height_guide != height_style or width_guide != width_style: + message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n" + frames_style.set_shape(height_guide, width_guide) + return frames_guide, frames_style, message + + +def smooth_video( + video_guide, + video_guide_folder, + video_style, + video_style_folder, + mode, + window_size, + batch_size, + tracking_window_size, + output_path, + fps, + minimum_patch_size, + num_iter, + guide_weight, + initialize, + progress = None, +): + # input + frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder) + if len(message) > 0: + print(message) + # output + if output_path == "": + if video_style is None: + output_path = os.path.join(video_style_folder, "output") + else: + output_path = os.path.join(os.path.split(video_style)[0], "output") + os.makedirs(output_path, exist_ok=True) + print("No valid output_path. Your video will be saved here:", output_path) + elif not os.path.exists(output_path): + os.makedirs(output_path, exist_ok=True) + print("Your video will be saved here:", output_path) + frames_path = os.path.join(output_path, "frames") + video_path = os.path.join(output_path, "video.mp4") + os.makedirs(frames_path, exist_ok=True) + # process + if mode == "Fast" or mode == "Balanced": + tracking_window_size = 0 + ebsynth_config = { + "minimum_patch_size": minimum_patch_size, + "threads_per_block": 8, + "num_iter": num_iter, + "gpu_id": 0, + "guide_weight": guide_weight, + "initialize": initialize, + "tracking_window_size": tracking_window_size, + } + if mode == "Fast": + FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path) + elif mode == "Balanced": + BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path) + elif mode == "Accurate": + AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path) + # output + try: + fps = int(fps) + except: + fps = get_video_fps(video_style) if video_style is not None else 30 + print("Fps:", fps) + print("Saving video...") + video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps) + print("Success!") + print("Your frames are here:", frames_path) + print("Your video is here:", video_path) + return output_path, fps, video_path + + +class KeyFrameMatcher: + def __init__(self): + pass + + def extract_number_from_filename(self, file_name): + result = [] + number = -1 + for i in file_name: + if ord(i)>=ord("0") and ord(i)<=ord("9"): + if number == -1: + number = 0 + number = number*10 + ord(i) - ord("0") + else: + if number != -1: + result.append(number) + number = -1 + if number != -1: + result.append(number) + result = tuple(result) + return result + + def extract_number_from_filenames(self, file_names): + numbers = [self.extract_number_from_filename(file_name) for file_name in file_names] + min_length = min(len(i) for i in numbers) + for i in range(min_length-1, -1, -1): + if len(set(number[i] for number in 
numbers))==len(file_names): + return [number[i] for number in numbers] + return list(range(len(file_names))) + + def match_using_filename(self, file_names_a, file_names_b): + file_names_b_set = set(file_names_b) + matched_file_name = [] + for file_name in file_names_a: + if file_name not in file_names_b_set: + matched_file_name.append(None) + else: + matched_file_name.append(file_name) + return matched_file_name + + def match_using_numbers(self, file_names_a, file_names_b): + numbers_a = self.extract_number_from_filenames(file_names_a) + numbers_b = self.extract_number_from_filenames(file_names_b) + numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)} + matched_file_name = [] + for number in numbers_a: + if number in numbers_b_dict: + matched_file_name.append(numbers_b_dict[number]) + else: + matched_file_name.append(None) + return matched_file_name + + def match_filenames(self, file_names_a, file_names_b): + matched_file_name = self.match_using_filename(file_names_a, file_names_b) + if sum([i is not None for i in matched_file_name]) > 0: + return matched_file_name + matched_file_name = self.match_using_numbers(file_names_a, file_names_b) + return matched_file_name + + +def detect_frames(frames_path, keyframes_path): + if not os.path.exists(frames_path) and not os.path.exists(keyframes_path): + return "Please input the directory of guide video and rendered frames" + elif not os.path.exists(frames_path): + return "Please input the directory of guide video" + elif not os.path.exists(keyframes_path): + return "Please input the directory of rendered frames" + frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)] + keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)] + if len(frames)==0: + return f"No images detected in {frames_path}" + if len(keyframes)==0: + return f"No images detected in {keyframes_path}" + matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes) + max_filename_length = max([len(i) for i in frames]) + if sum([i is not None for i in matched_keyframes])==0: + message = "" + for frame, matched_keyframe in zip(frames, matched_keyframes): + message += frame + " " * (max_filename_length - len(frame) + 1) + message += "--> No matched keyframes\n" + else: + message = "" + for frame, matched_keyframe in zip(frames, matched_keyframes): + message += frame + " " * (max_filename_length - len(frame) + 1) + if matched_keyframe is None: + message += "--> [to be rendered]\n" + else: + message += f"--> {matched_keyframe}\n" + return message + + +def check_input_for_interpolating(frames_path, keyframes_path): + # search for images + frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)] + keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)] + # match frames + matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes) + file_list = [file_name for file_name in matched_keyframes if file_name is not None] + index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None] + frames_guide = VideoData(None, frames_path) + frames_style = VideoData(None, keyframes_path, file_list=file_list) + # match shape + message = "" + height_guide, width_guide = frames_guide.shape() + height_style, width_style = frames_style.shape() + if height_guide != height_style or width_guide != width_style: + message += f"The shape of frames mismatches. 
The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n" + frames_style.set_shape(height_guide, width_guide) + return frames_guide, frames_style, index_style, message + + +def interpolate_video( + frames_path, + keyframes_path, + output_path, + fps, + batch_size, + tracking_window_size, + minimum_patch_size, + num_iter, + guide_weight, + initialize, + progress = None, +): + # input + frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path) + if len(message) > 0: + print(message) + # output + if output_path == "": + output_path = os.path.join(keyframes_path, "output") + os.makedirs(output_path, exist_ok=True) + print("No valid output_path. Your video will be saved here:", output_path) + elif not os.path.exists(output_path): + os.makedirs(output_path, exist_ok=True) + print("Your video will be saved here:", output_path) + output_frames_path = os.path.join(output_path, "frames") + output_video_path = os.path.join(output_path, "video.mp4") + os.makedirs(output_frames_path, exist_ok=True) + # process + ebsynth_config = { + "minimum_patch_size": minimum_patch_size, + "threads_per_block": 8, + "num_iter": num_iter, + "gpu_id": 0, + "guide_weight": guide_weight, + "initialize": initialize, + "tracking_window_size": tracking_window_size + } + if len(index_style)==1: + InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path) + else: + InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path) + try: + fps = int(fps) + except: + fps = 30 + print("Fps:", fps) + print("Saving video...") + video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps) + print("Success!") + print("Your frames are here:", output_frames_path) + print("Your video is here:", video_path) + return output_path, fps, video_path + + +def on_ui_tabs(): + with gr.Blocks(analytics_enabled=False) as ui_component: + with gr.Tab("Blend"): + gr.Markdown(""" +# Blend + +Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution. 
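+
+The same blending pipeline can also be driven programmatically through the `smooth_video` function defined above. The snippet below is only a sketch: the two video paths are placeholders, and the remaining arguments simply mirror the defaults exposed in this UI.
+
+```python
+output_path, fps, video_path = smooth_video(
+    video_guide="guide.mp4", video_guide_folder="",
+    video_style="style.mp4", video_style_folder="",
+    mode="Fast", window_size=15, batch_size=8, tracking_window_size=0,
+    output_path="", fps="", minimum_patch_size=5, num_iter=5,
+    guide_weight=10.0, initialize="identity",
+)
+```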
+ """) + with gr.Row(): + with gr.Column(): + with gr.Tab("Guide video"): + video_guide = gr.Video(label="Guide video") + with gr.Tab("Guide video (images format)"): + video_guide_folder = gr.Textbox(label="Guide video (images format)", value="") + with gr.Column(): + with gr.Tab("Style video"): + video_style = gr.Video(label="Style video") + with gr.Tab("Style video (images format)"): + video_style_folder = gr.Textbox(label="Style video (images format)", value="") + with gr.Column(): + output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video") + fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps") + video_output = gr.Video(label="Output video", interactive=False, show_share_button=True) + btn = gr.Button(value="Blend") + with gr.Row(): + with gr.Column(): + gr.Markdown("# Settings") + mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True) + window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True) + batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True) + tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True) + gr.Markdown("## Advanced Settings") + minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True) + num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True) + guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True) + initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True) + with gr.Column(): + gr.Markdown(""" +# Reference + +* Output directory: the directory to save the video. +* Inference mode + +|Mode|Time|Memory|Quality|Frame by frame output|Description| +|-|-|-|-|-|-| +|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.| +|Balanced|■■|■|■■|Yes|Blend the frames naively.| +|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.| + +* Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy. +* Batch size: a larger batch size makes the program faster but requires more VRAM. +* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough. +* Advanced settings + * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5) + * Number of iterations: the number of iterations of patch matching. (Default: 5) + * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10) + * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). 
(Default: identity) + """) + btn.click( + smooth_video, + inputs=[ + video_guide, + video_guide_folder, + video_style, + video_style_folder, + mode, + window_size, + batch_size, + tracking_window_size, + output_path, + fps, + minimum_patch_size, + num_iter, + guide_weight, + initialize + ], + outputs=[output_path, fps, video_output] + ) + with gr.Tab("Interpolate"): + gr.Markdown(""" +# Interpolate + +Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution. + """) + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(): + video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="") + with gr.Column(): + rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="") + with gr.Row(): + detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False) + video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames) + rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames) + with gr.Column(): + output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes") + fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps") + video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True) + btn_ = gr.Button(value="Interpolate") + with gr.Row(): + with gr.Column(): + gr.Markdown("# Settings") + batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True) + tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True) + gr.Markdown("## Advanced Settings") + minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True) + num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True) + guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True) + initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True) + with gr.Column(): + gr.Markdown(""" +# Reference + +* Output directory: the directory to save the video. +* Batch size: a larger batch size makes the program faster but requires more VRAM. +* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough. +* Advanced settings + * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)** + * Number of iterations: the number of iterations of patch matching. (Default: 5) + * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10) + * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). 
(Default: identity) + """) + btn_.click( + interpolate_video, + inputs=[ + video_guide_folder_, + rendered_keyframes_, + output_path_, + fps_, + batch_size_, + tracking_window_size_, + minimum_patch_size_, + num_iter_, + guide_weight_, + initialize_, + ], + outputs=[output_path_, fps_, video_output_] + ) + + return [(ui_component, "FastBlend", "FastBlend_ui")] diff --git a/diffsynth/extensions/FastBlend/cupy_kernels.py b/diffsynth/extensions/FastBlend/cupy_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..70e2790a2c67a2dd537f4188b38ebfc785f1fb34 --- /dev/null +++ b/diffsynth/extensions/FastBlend/cupy_kernels.py @@ -0,0 +1,119 @@ +import cupy as cp + +remapping_kernel = cp.RawKernel(r''' +extern "C" __global__ +void remap( + const int height, + const int width, + const int channel, + const int patch_size, + const int pad_size, + const float* source_style, + const int* nnf, + float* target_style +) { + const int r = (patch_size - 1) / 2; + const int x = blockDim.x * blockIdx.x + threadIdx.x; + const int y = blockDim.y * blockIdx.y + threadIdx.y; + if (x >= height or y >= width) return; + const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel; + const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size); + const int min_px = x < r ? -x : -r; + const int max_px = x + r > height - 1 ? height - 1 - x : r; + const int min_py = y < r ? -y : -r; + const int max_py = y + r > width - 1 ? width - 1 - y : r; + int num = 0; + for (int px = min_px; px <= max_px; px++){ + for (int py = min_py; py <= max_py; py++){ + const int nid = (x + px) * width + y + py; + const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px; + const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py; + if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue; + const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size); + num++; + for (int c = 0; c < channel; c++){ + target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c]; + } + } + } + for (int c = 0; c < channel; c++){ + target_style[z + pid * channel + c] /= num; + } +} +''', 'remap') + + +patch_error_kernel = cp.RawKernel(r''' +extern "C" __global__ +void patch_error( + const int height, + const int width, + const int channel, + const int patch_size, + const int pad_size, + const float* source, + const int* nnf, + const float* target, + float* error +) { + const int r = (patch_size - 1) / 2; + const int x = blockDim.x * blockIdx.x + threadIdx.x; + const int y = blockDim.y * blockIdx.y + threadIdx.y; + const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel; + if (x >= height or y >= width) return; + const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0]; + const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1]; + float e = 0; + for (int px = -r; px <= r; px++){ + for (int py = -r; py <= r; py++){ + const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py; + const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py; + for (int c = 0; c < channel; c++){ + const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c]; + e += diff * diff; + } + } + } + error[blockIdx.z * height * width + x * width + y] = e; +} +''', 'patch_error') + + +pairwise_patch_error_kernel = cp.RawKernel(r''' +extern "C" __global__ +void pairwise_patch_error( + const int height, + const int width, + const int channel, + const int 
patch_size, + const int pad_size, + const float* source_a, + const int* nnf_a, + const float* source_b, + const int* nnf_b, + float* error +) { + const int r = (patch_size - 1) / 2; + const int x = blockDim.x * blockIdx.x + threadIdx.x; + const int y = blockDim.y * blockIdx.y + threadIdx.y; + const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel; + if (x >= height or y >= width) return; + const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2; + const int x_a = nnf_a[z_nnf + 0]; + const int y_a = nnf_a[z_nnf + 1]; + const int x_b = nnf_b[z_nnf + 0]; + const int y_b = nnf_b[z_nnf + 1]; + float e = 0; + for (int px = -r; px <= r; px++){ + for (int py = -r; py <= r; py++){ + const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py; + const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py; + for (int c = 0; c < channel; c++){ + const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c]; + e += diff * diff; + } + } + } + error[blockIdx.z * height * width + x * width + y] = e; +} +''', 'pairwise_patch_error') diff --git a/diffsynth/extensions/FastBlend/data.py b/diffsynth/extensions/FastBlend/data.py new file mode 100644 index 0000000000000000000000000000000000000000..dcaddd77de9eaf208cd083dd522e5eaa6b58f783 --- /dev/null +++ b/diffsynth/extensions/FastBlend/data.py @@ -0,0 +1,146 @@ +import imageio, os +import numpy as np +from PIL import Image + + +def read_video(file_name): + reader = imageio.get_reader(file_name) + video = [] + for frame in reader: + frame = np.array(frame) + video.append(frame) + reader.close() + return video + + +def get_video_fps(file_name): + reader = imageio.get_reader(file_name) + fps = reader.get_meta_data()["fps"] + reader.close() + return fps + + +def save_video(frames_path, video_path, num_frames, fps): + writer = imageio.get_writer(video_path, fps=fps, quality=9) + for i in range(num_frames): + frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i))) + writer.append_data(frame) + writer.close() + return video_path + + +class LowMemoryVideo: + def __init__(self, file_name): + self.reader = imageio.get_reader(file_name) + + def __len__(self): + return self.reader.count_frames() + + def __getitem__(self, item): + return np.array(self.reader.get_data(item)) + + def __del__(self): + self.reader.close() + + +def split_file_name(file_name): + result = [] + number = -1 + for i in file_name: + if ord(i)>=ord("0") and ord(i)<=ord("9"): + if number == -1: + number = 0 + number = number*10 + ord(i) - ord("0") + else: + if number != -1: + result.append(number) + number = -1 + result.append(i) + if number != -1: + result.append(number) + result = tuple(result) + return result + + +def search_for_images(folder): + file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")] + file_list = [(split_file_name(file_name), file_name) for file_name in file_list] + file_list = [i[1] for i in sorted(file_list)] + file_list = [os.path.join(folder, i) for i in file_list] + return file_list + + +def read_images(folder): + file_list = search_for_images(folder) + frames = [np.array(Image.open(i)) for i in file_list] + return frames + + +class LowMemoryImageFolder: + def __init__(self, folder, file_list=None): + if file_list is None: + self.file_list = search_for_images(folder) + else: + self.file_list = [os.path.join(folder, file_name) for file_name in file_list] + + def __len__(self): + return 
len(self.file_list) + + def __getitem__(self, item): + return np.array(Image.open(self.file_list[item])) + + def __del__(self): + pass + + +class VideoData: + def __init__(self, video_file, image_folder, **kwargs): + if video_file is not None: + self.data_type = "video" + self.data = LowMemoryVideo(video_file, **kwargs) + elif image_folder is not None: + self.data_type = "images" + self.data = LowMemoryImageFolder(image_folder, **kwargs) + else: + raise ValueError("Cannot open video or image folder") + self.length = None + self.height = None + self.width = None + + def raw_data(self): + frames = [] + for i in range(self.__len__()): + frames.append(self.__getitem__(i)) + return frames + + def set_length(self, length): + self.length = length + + def set_shape(self, height, width): + self.height = height + self.width = width + + def __len__(self): + if self.length is None: + return len(self.data) + else: + return self.length + + def shape(self): + if self.height is not None and self.width is not None: + return self.height, self.width + else: + height, width, _ = self.__getitem__(0).shape + return height, width + + def __getitem__(self, item): + frame = self.data.__getitem__(item) + height, width, _ = frame.shape + if self.height is not None and self.width is not None: + if self.height != height or self.width != width: + frame = Image.fromarray(frame).resize((self.width, self.height)) + frame = np.array(frame) + return frame + + def __del__(self): + pass diff --git a/diffsynth/extensions/FastBlend/patch_match.py b/diffsynth/extensions/FastBlend/patch_match.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb1f7f9e31b4b2ad77ec58ba8d32361315e0390 --- /dev/null +++ b/diffsynth/extensions/FastBlend/patch_match.py @@ -0,0 +1,298 @@ +from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel +import numpy as np +import cupy as cp +import cv2 + + +class PatchMatcher: + def __init__( + self, height, width, channel, minimum_patch_size, + threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, + random_search_steps=3, random_search_range=4, + use_mean_target_style=False, use_pairwise_patch_error=False, + tracking_window_size=0 + ): + self.height = height + self.width = width + self.channel = channel + self.minimum_patch_size = minimum_patch_size + self.threads_per_block = threads_per_block + self.num_iter = num_iter + self.gpu_id = gpu_id + self.guide_weight = guide_weight + self.random_search_steps = random_search_steps + self.random_search_range = random_search_range + self.use_mean_target_style = use_mean_target_style + self.use_pairwise_patch_error = use_pairwise_patch_error + self.tracking_window_size = tracking_window_size + + self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1] + self.pad_size = self.patch_size_list[0] // 2 + self.grid = ( + (height + threads_per_block - 1) // threads_per_block, + (width + threads_per_block - 1) // threads_per_block + ) + self.block = (threads_per_block, threads_per_block) + + def pad_image(self, image): + return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0))) + + def unpad_image(self, image): + return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :] + + def apply_nnf_to_image(self, nnf, source): + batch_size = source.shape[0] + target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32) + remapping_kernel( + self.grid + (batch_size,), 
+ self.block, + (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target) + ) + return target + + def get_patch_error(self, source, nnf, target): + batch_size = source.shape[0] + error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32) + patch_error_kernel( + self.grid + (batch_size,), + self.block, + (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error) + ) + return error + + def get_pairwise_patch_error(self, source, nnf): + batch_size = source.shape[0]//2 + error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32) + source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy() + source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy() + pairwise_patch_error_kernel( + self.grid + (batch_size,), + self.block, + (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error) + ) + error = error.repeat(2, axis=0) + return error + + def get_error(self, source_guide, target_guide, source_style, target_style, nnf): + error_guide = self.get_patch_error(source_guide, nnf, target_guide) + if self.use_mean_target_style: + target_style = self.apply_nnf_to_image(nnf, source_style) + target_style = target_style.mean(axis=0, keepdims=True) + target_style = target_style.repeat(source_guide.shape[0], axis=0) + if self.use_pairwise_patch_error: + error_style = self.get_pairwise_patch_error(source_style, nnf) + else: + error_style = self.get_patch_error(source_style, nnf, target_style) + error = error_guide * self.guide_weight + error_style + return error + + def clamp_bound(self, nnf): + nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1) + nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1) + return nnf + + def random_step(self, nnf, r): + batch_size = nnf.shape[0] + step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32) + upd_nnf = self.clamp_bound(nnf + step) + return upd_nnf + + def neighboor_step(self, nnf, d): + if d==0: + upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1) + upd_nnf[:, :, :, 0] += 1 + elif d==1: + upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2) + upd_nnf[:, :, :, 1] += 1 + elif d==2: + upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1) + upd_nnf[:, :, :, 0] -= 1 + elif d==3: + upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2) + upd_nnf[:, :, :, 1] -= 1 + upd_nnf = self.clamp_bound(upd_nnf) + return upd_nnf + + def shift_nnf(self, nnf, d): + if d>0: + d = min(nnf.shape[0], d) + upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0) + else: + d = max(-nnf.shape[0], d) + upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0) + return upd_nnf + + def track_step(self, nnf, d): + if self.use_pairwise_patch_error: + upd_nnf = cp.zeros_like(nnf) + upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d) + upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d) + else: + upd_nnf = self.shift_nnf(nnf, d) + return upd_nnf + + def C(self, n, m): + # not used + c = 1 + for i in range(1, n+1): + c *= i + for i in range(1, m+1): + c //= i + for i in range(1, n-m+1): + c //= i + return c + + def bezier_step(self, nnf, r): + # not used + n = r * 2 - 1 + upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32) + for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))): + if d>0: + ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0) + elif d<0: + ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0) + upd_nnf += ctl_nnf * 
(self.C(n, i) / 2**n) + upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype) + return upd_nnf + + def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf): + upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf) + upd_idx = (upd_err < err) + nnf[upd_idx] = upd_nnf[upd_idx] + err[upd_idx] = upd_err[upd_idx] + return nnf, err + + def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err): + for d in cp.random.permutation(4): + upd_nnf = self.neighboor_step(nnf, d) + nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf) + return nnf, err + + def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err): + for i in range(self.random_search_steps): + upd_nnf = self.random_step(nnf, self.random_search_range) + nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf) + return nnf, err + + def track(self, source_guide, target_guide, source_style, target_style, nnf, err): + for d in range(1, self.tracking_window_size + 1): + upd_nnf = self.track_step(nnf, d) + nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf) + upd_nnf = self.track_step(nnf, -d) + nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf) + return nnf, err + + def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err): + nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err) + nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err) + nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err) + return nnf, err + + def estimate_nnf(self, source_guide, target_guide, source_style, nnf): + with cp.cuda.Device(self.gpu_id): + source_guide = self.pad_image(source_guide) + target_guide = self.pad_image(target_guide) + source_style = self.pad_image(source_style) + for it in range(self.num_iter): + self.patch_size = self.patch_size_list[it] + target_style = self.apply_nnf_to_image(nnf, source_style) + err = self.get_error(source_guide, target_guide, source_style, target_style, nnf) + nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err) + target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style)) + return nnf, target_style + + +class PyramidPatchMatcher: + def __init__( + self, image_height, image_width, channel, minimum_patch_size, + threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, + use_mean_target_style=False, use_pairwise_patch_error=False, + tracking_window_size=0, + initialize="identity" + ): + maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2 + self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size)) + self.pyramid_heights = [] + self.pyramid_widths = [] + self.patch_matchers = [] + self.minimum_patch_size = minimum_patch_size + self.num_iter = num_iter + self.gpu_id = gpu_id + self.initialize = initialize + for level in range(self.pyramid_level): + height = image_height//(2**(self.pyramid_level - 1 - level)) + width = image_width//(2**(self.pyramid_level - 1 - level)) + self.pyramid_heights.append(height) + self.pyramid_widths.append(width) + self.patch_matchers.append(PatchMatcher( + height, width, channel, minimum_patch_size=minimum_patch_size, + threads_per_block=threads_per_block, 
num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight, + use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error, + tracking_window_size=tracking_window_size + )) + + def resample_image(self, images, level): + height, width = self.pyramid_heights[level], self.pyramid_widths[level] + images = images.get() + images_resample = [] + for image in images: + image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA) + images_resample.append(image_resample) + images_resample = cp.array(np.stack(images_resample), dtype=cp.float32) + return images_resample + + def initialize_nnf(self, batch_size): + if self.initialize == "random": + height, width = self.pyramid_heights[0], self.pyramid_widths[0] + nnf = cp.stack([ + cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32), + cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32) + ], axis=3) + elif self.initialize == "identity": + height, width = self.pyramid_heights[0], self.pyramid_widths[0] + nnf = cp.stack([ + cp.repeat(cp.arange(height), width).reshape(height, width), + cp.tile(cp.arange(width), height).reshape(height, width) + ], axis=2) + nnf = cp.stack([nnf] * batch_size) + else: + raise NotImplementedError() + return nnf + + def update_nnf(self, nnf, level): + # upscale + nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2 + nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1 + nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1 + # check if scale is 2 + height, width = self.pyramid_heights[level], self.pyramid_widths[level] + if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2: + nnf = nnf.get().astype(np.float32) + nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf] + nnf = cp.array(np.stack(nnf), dtype=cp.int32) + nnf = self.patch_matchers[level].clamp_bound(nnf) + return nnf + + def apply_nnf_to_image(self, nnf, image): + with cp.cuda.Device(self.gpu_id): + image = self.patch_matchers[-1].pad_image(image) + image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image) + return image + + def estimate_nnf(self, source_guide, target_guide, source_style): + with cp.cuda.Device(self.gpu_id): + if not isinstance(source_guide, cp.ndarray): + source_guide = cp.array(source_guide, dtype=cp.float32) + if not isinstance(target_guide, cp.ndarray): + target_guide = cp.array(target_guide, dtype=cp.float32) + if not isinstance(source_style, cp.ndarray): + source_style = cp.array(source_style, dtype=cp.float32) + for level in range(self.pyramid_level): + nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level) + source_guide_ = self.resample_image(source_guide, level) + target_guide_ = self.resample_image(target_guide, level) + source_style_ = self.resample_image(source_style, level) + nnf, target_style = self.patch_matchers[level].estimate_nnf( + source_guide_, target_guide_, source_style_, nnf + ) + return nnf.get(), target_style.get() diff --git a/diffsynth/extensions/FastBlend/runners/__init__.py b/diffsynth/extensions/FastBlend/runners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..078382729690d282436411661693ce22f3dcc033 --- /dev/null +++ b/diffsynth/extensions/FastBlend/runners/__init__.py @@ -0,0 +1,4 @@ +from .accurate import AccurateModeRunner +from .fast import FastModeRunner +from .balanced import BalancedModeRunner +from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner diff --git 
a/diffsynth/extensions/FastBlend/runners/accurate.py b/diffsynth/extensions/FastBlend/runners/accurate.py new file mode 100644 index 0000000000000000000000000000000000000000..2e4a47f1981ebc1ec9a034a814dfc1130955c2e1 --- /dev/null +++ b/diffsynth/extensions/FastBlend/runners/accurate.py @@ -0,0 +1,35 @@ +from ..patch_match import PyramidPatchMatcher +import os +import numpy as np +from PIL import Image +from tqdm import tqdm + + +class AccurateModeRunner: + def __init__(self): + pass + + def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None): + patch_match_engine = PyramidPatchMatcher( + image_height=frames_style[0].shape[0], + image_width=frames_style[0].shape[1], + channel=3, + use_mean_target_style=True, + **ebsynth_config + ) + # run + n = len(frames_style) + for target in tqdm(range(n), desc=desc): + l, r = max(target - window_size, 0), min(target + window_size + 1, n) + remapped_frames = [] + for i in range(l, r, batch_size): + j = min(i + batch_size, r) + source_guide = np.stack([frames_guide[source] for source in range(i, j)]) + target_guide = np.stack([frames_guide[target]] * (j - i)) + source_style = np.stack([frames_style[source] for source in range(i, j)]) + _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) + remapped_frames.append(target_style) + frame = np.concatenate(remapped_frames, axis=0).mean(axis=0) + frame = frame.clip(0, 255).astype("uint8") + if save_path is not None: + Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target)) \ No newline at end of file diff --git a/diffsynth/extensions/FastBlend/runners/balanced.py b/diffsynth/extensions/FastBlend/runners/balanced.py new file mode 100644 index 0000000000000000000000000000000000000000..1c9a2bb7e438b49c89d0786e858ccf03302fab35 --- /dev/null +++ b/diffsynth/extensions/FastBlend/runners/balanced.py @@ -0,0 +1,46 @@ +from ..patch_match import PyramidPatchMatcher +import os +import numpy as np +from PIL import Image +from tqdm import tqdm + + +class BalancedModeRunner: + def __init__(self): + pass + + def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None): + patch_match_engine = PyramidPatchMatcher( + image_height=frames_style[0].shape[0], + image_width=frames_style[0].shape[1], + channel=3, + **ebsynth_config + ) + # tasks + n = len(frames_style) + tasks = [] + for target in range(n): + for source in range(target - window_size, target + window_size + 1): + if source >= 0 and source < n and source != target: + tasks.append((source, target)) + # run + frames = [(None, 1) for i in range(n)] + for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc): + tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))] + source_guide = np.stack([frames_guide[source] for source, target in tasks_batch]) + target_guide = np.stack([frames_guide[target] for source, target in tasks_batch]) + source_style = np.stack([frames_style[source] for source, target in tasks_batch]) + _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) + for (source, target), result in zip(tasks_batch, target_style): + frame, weight = frames[target] + if frame is None: + frame = frames_style[target] + frames[target] = ( + frame * (weight / (weight + 1)) + result / (weight + 1), + weight + 1 + ) + if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size): + frame = frame.clip(0, 255).astype("uint8") + 
if save_path is not None: + Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target)) + frames[target] = (None, 1) diff --git a/diffsynth/extensions/FastBlend/runners/fast.py b/diffsynth/extensions/FastBlend/runners/fast.py new file mode 100644 index 0000000000000000000000000000000000000000..2ba5731475ab875929b14181e0c22f4fd466c591 --- /dev/null +++ b/diffsynth/extensions/FastBlend/runners/fast.py @@ -0,0 +1,141 @@ +from ..patch_match import PyramidPatchMatcher +import functools, os +import numpy as np +from PIL import Image +from tqdm import tqdm + + +class TableManager: + def __init__(self): + pass + + def task_list(self, n): + tasks = [] + max_level = 1 + while (1<=n: + break + meta_data = { + "source": i, + "target": j, + "level": level + 1 + } + tasks.append(meta_data) + tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"])) + return tasks + + def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""): + n = len(frames_guide) + tasks = self.task_list(n) + remapping_table = [[(frames_style[i], 1)] for i in range(n)] + for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc): + tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))] + source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch]) + target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch]) + source_style = np.stack([frames_style[task["source"]] for task in tasks_batch]) + _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) + for task, result in zip(tasks_batch, target_style): + target, level = task["target"], task["level"] + if len(remapping_table[target])==level: + remapping_table[target].append((result, 1)) + else: + frame, weight = remapping_table[target][level] + remapping_table[target][level] = ( + frame * (weight / (weight + 1)) + result / (weight + 1), + weight + 1 + ) + return remapping_table + + def remapping_table_to_blending_table(self, table): + for i in range(len(table)): + for j in range(1, len(table[i])): + frame_1, weight_1 = table[i][j-1] + frame_2, weight_2 = table[i][j] + frame = (frame_1 + frame_2) / 2 + weight = weight_1 + weight_2 + table[i][j] = (frame, weight) + return table + + def tree_query(self, leftbound, rightbound): + node_list = [] + node_index = rightbound + while node_index>=leftbound: + node_level = 0 + while (1<=leftbound: + node_level += 1 + node_list.append((node_index, node_level)) + node_index -= 1<0: + tasks = [] + for m in range(index_style[0]): + tasks.append((index_style[0], m, index_style[0])) + task_group.append(tasks) + # middle frames + for l, r in zip(index_style[:-1], index_style[1:]): + tasks = [] + for m in range(l, r): + tasks.append((l, m, r)) + task_group.append(tasks) + # last frame + tasks = [] + for m in range(index_style[-1], n): + tasks.append((index_style[-1], m, index_style[-1])) + task_group.append(tasks) + return task_group + + def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None): + patch_match_engine = PyramidPatchMatcher( + image_height=frames_style[0].shape[0], + image_width=frames_style[0].shape[1], + channel=3, + use_mean_target_style=False, + use_pairwise_patch_error=True, + **ebsynth_config + ) + # task + index_dict = self.get_index_dict(index_style) + task_group = self.get_task_group(index_style, len(frames_guide)) + # run + for tasks in task_group: + index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks]) + 
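For orientation before the rendering loop below, here is a standalone sketch, not part of the diff, that reimplements the grouping logic visible in get_task_group above: every frame index m is paired with the style keyframes to its left and right, and frames outside the keyframe range reuse the nearest keyframe on both sides. The keyframe positions [2, 5] and the frame count 8 are made up for illustration.

index_style, n = [2, 5], 8

task_group = []
# frames before the first keyframe -> (first, m, first)
task_group.append([(index_style[0], m, index_style[0]) for m in range(index_style[0])])
# frames between two keyframes l and r -> (l, m, r)
for l, r in zip(index_style[:-1], index_style[1:]):
    task_group.append([(l, m, r) for m in range(l, r)])
# frames from the last keyframe onward -> (last, m, last)
task_group.append([(index_style[-1], m, index_style[-1]) for m in range(index_style[-1], n)])

for tasks in task_group:
    print(tasks)
# [(2, 0, 2), (2, 1, 2)]
# [(2, 2, 5), (2, 3, 5), (2, 4, 5)]
# [(5, 5, 5), (5, 6, 5), (5, 7, 5)]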
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"): + tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))] + source_guide, target_guide, source_style = [], [], [] + for l, m, r in tasks_batch: + # l -> m + source_guide.append(frames_guide[l]) + target_guide.append(frames_guide[m]) + source_style.append(frames_style[index_dict[l]]) + # r -> m + source_guide.append(frames_guide[r]) + target_guide.append(frames_guide[m]) + source_style.append(frames_style[index_dict[r]]) + source_guide = np.stack(source_guide) + target_guide = np.stack(target_guide) + source_style = np.stack(source_style) + _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) + if save_path is not None: + for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch): + weight_l, weight_r = self.get_weight(l, m, r) + frame = frame_l * weight_l + frame_r * weight_r + frame = frame.clip(0, 255).astype("uint8") + Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m)) + + +class InterpolationModeSingleFrameRunner: + def __init__(self): + pass + + def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None): + # check input + tracking_window_size = ebsynth_config["tracking_window_size"] + if tracking_window_size * 2 >= batch_size: + raise ValueError("batch_size should be larger than track_window_size * 2") + frame_style = frames_style[0] + frame_guide = frames_guide[index_style[0]] + patch_match_engine = PyramidPatchMatcher( + image_height=frame_style.shape[0], + image_width=frame_style.shape[1], + channel=3, + **ebsynth_config + ) + # run + frame_id, n = 0, len(frames_guide) + for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"): + if i + batch_size > n: + l, r = max(n - batch_size, 0), n + else: + l, r = i, i + batch_size + source_guide = np.stack([frame_guide] * (r-l)) + target_guide = np.stack([frames_guide[i] for i in range(l, r)]) + source_style = np.stack([frame_style] * (r-l)) + _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) + for i, frame in zip(range(l, r), target_style): + if i==frame_id: + frame = frame.clip(0, 255).astype("uint8") + Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id)) + frame_id += 1 + if r < n and r-frame_id <= tracking_window_size: + break diff --git a/diffsynth/extensions/RIFE/__init__.py b/diffsynth/extensions/RIFE/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..410550ed946348b9e537138a8b391df97f5ad84e --- /dev/null +++ b/diffsynth/extensions/RIFE/__init__.py @@ -0,0 +1,241 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from PIL import Image + + +def warp(tenInput, tenFlow, device): + backwarp_tenGrid = {} + k = (str(tenFlow.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( + 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( + 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + backwarp_tenGrid[k] = torch.cat( + [tenHorizontal, tenVertical], 1).to(device) + + tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) 
/ 2.0)], 1) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.PReLU(out_planes) + ) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),) + self.convblock0 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock1 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock2 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock3 = nn.Sequential(conv(c, c), conv(c, c)) + self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1)) + self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1)) + + def forward(self, x, flow, scale=1): + x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) + flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale + feat = self.conv0(torch.cat((x, flow), 1)) + feat = self.convblock0(feat) + feat + feat = self.convblock1(feat) + feat + feat = self.convblock2(feat) + feat + feat = self.convblock3(feat) + feat + flow = self.conv1(feat) + mask = self.conv2(feat) + flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale + mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) + return flow, mask + + +class IFNet(nn.Module): + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(7+4, c=90) + self.block1 = IFBlock(7+4, c=90) + self.block2 = IFBlock(7+4, c=90) + self.block_tea = IFBlock(10+4, c=90) + + def forward(self, x, scale_list=[4, 2, 1], training=False): + if training == False: + channel = x.shape[1] // 2 + img0 = x[:, :channel] + img1 = x[:, channel:] + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = (x[:, :4]).detach() * 0 + mask = (x[:, :1]).detach() * 0 + block = [self.block0, self.block1, self.block2] + for i in range(3): + f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i]) + f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i]) + flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 + mask = mask + (m0 + (-m1)) / 2 + mask_list.append(mask) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2], device=x.device) + warped_img1 = warp(img1, flow[:, 2:4], device=x.device) + merged.append((warped_img0, warped_img1)) + ''' + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, 1:4] * 2 - 1 + ''' + for i in range(3): + mask_list[i] = torch.sigmoid(mask_list[i]) + merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) + return flow_list, mask_list[2], merged + + def state_dict_converter(self): + return 
IFNetStateDictConverter() + + +class IFNetStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()} + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) + + +class RIFEInterpolater: + def __init__(self, model, device="cuda"): + self.model = model + self.device = device + # IFNet only does not support float16 + self.torch_dtype = torch.float32 + + @staticmethod + def from_model_manager(model_manager): + return RIFEInterpolater(model_manager.RIFE, device=model_manager.device) + + def process_image(self, image): + width, height = image.size + if width % 32 != 0 or height % 32 != 0: + width = (width + 31) // 32 + height = (height + 31) // 32 + image = image.resize((width, height)) + image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1) + return image + + def process_images(self, images): + images = [self.process_image(image) for image in images] + images = torch.stack(images) + return images + + def decode_images(self, images): + images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8) + images = [Image.fromarray(image) for image in images] + return images + + def add_interpolated_images(self, images, interpolated_images): + output_images = [] + for image, interpolated_image in zip(images, interpolated_images): + output_images.append(image) + output_images.append(interpolated_image) + output_images.append(images[-1]) + return output_images + + + @torch.no_grad() + def interpolate_(self, images, scale=1.0): + input_tensor = self.process_images(images) + input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1) + input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype) + flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale]) + output_images = self.decode_images(merged[2].cpu()) + if output_images[0].size != images[0].size: + output_images = [image.resize(images[0].size) for image in output_images] + return output_images + + + @torch.no_grad() + def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x): + # Preprocess + processed_images = self.process_images(images) + + for iter in range(num_iter): + # Input + input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1) + + # Interpolate + output_tensor = [] + for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)): + batch_id_ = min(batch_id + batch_size, input_tensor.shape[0]) + batch_input_tensor = input_tensor[batch_id: batch_id_] + batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype) + flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale]) + output_tensor.append(merged[2].cpu()) + + # Output + output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1) + processed_images = self.add_interpolated_images(processed_images, output_tensor) + processed_images = torch.stack(processed_images) + + # To images + output_images = self.decode_images(processed_images) + if output_images[0].size != images[0].size: + output_images = [image.resize(images[0].size) for image in output_images] + return output_images + + +class RIFESmoother(RIFEInterpolater): + def __init__(self, model, device="cuda"): + super(RIFESmoother, self).__init__(model, device=device) + + @staticmethod + def from_model_manager(model_manager): + return 
RIFESmoother(model_manager.RIFE, device=model_manager.device) + + def process_tensors(self, input_tensor, scale=1.0, batch_size=4): + output_tensor = [] + for batch_id in range(0, input_tensor.shape[0], batch_size): + batch_id_ = min(batch_id + batch_size, input_tensor.shape[0]) + batch_input_tensor = input_tensor[batch_id: batch_id_] + batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype) + flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale]) + output_tensor.append(merged[2].cpu()) + output_tensor = torch.concat(output_tensor, dim=0) + return output_tensor + + @torch.no_grad() + def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs): + # Preprocess + processed_images = self.process_images(rendered_frames) + + for iter in range(num_iter): + # Input + input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1) + + # Interpolate + output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size) + + # Blend + input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1) + output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size) + + # Add to frames + processed_images[1:-1] = output_tensor + + # To images + output_images = self.decode_images(processed_images) + if output_images[0].size != rendered_frames[0].size: + output_images = [image.resize(rendered_frames[0].size) for image in output_images] + return output_images diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..21757e97ce926e9cc6ff0f896c1ca3230ff37660 --- /dev/null +++ b/diffsynth/models/__init__.py @@ -0,0 +1,482 @@ +import torch, os +from safetensors import safe_open + +from .sd_text_encoder import SDTextEncoder +from .sd_unet import SDUNet +from .sd_vae_encoder import SDVAEEncoder +from .sd_vae_decoder import SDVAEDecoder +from .sd_lora import SDLoRA + +from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2 +from .sdxl_unet import SDXLUNet +from .sdxl_vae_decoder import SDXLVAEDecoder +from .sdxl_vae_encoder import SDXLVAEEncoder + +from .sd_controlnet import SDControlNet + +from .sd_motion import SDMotionModel +from .sdxl_motion import SDXLMotionModel + +from .svd_image_encoder import SVDImageEncoder +from .svd_unet import SVDUNet +from .svd_vae_decoder import SVDVAEDecoder +from .svd_vae_encoder import SVDVAEEncoder + +from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder +from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder + +from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder +from .hunyuan_dit import HunyuanDiT + + +class ModelManager: + def __init__(self, torch_dtype=torch.float16, device="cuda"): + self.torch_dtype = torch_dtype + self.device = device + self.model = {} + self.model_path = {} + self.textual_inversion_dict = {} + + def is_stable_video_diffusion(self, state_dict): + param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight" + return param_name in state_dict + + def is_RIFE(self, state_dict): + param_name = "block_tea.convblock3.0.1.weight" + return param_name in state_dict or ("module." 
+ param_name) in state_dict + + def is_beautiful_prompt(self, state_dict): + param_name = "transformer.h.9.self_attention.query_key_value.weight" + return param_name in state_dict + + def is_stabe_diffusion_xl(self, state_dict): + param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight" + return param_name in state_dict + + def is_stable_diffusion(self, state_dict): + if self.is_stabe_diffusion_xl(state_dict): + return False + param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight" + return param_name in state_dict + + def is_controlnet(self, state_dict): + param_name = "control_model.time_embed.0.weight" + param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format + return param_name in state_dict or param_name_2 in state_dict + + def is_animatediff(self, state_dict): + param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight" + return param_name in state_dict + + def is_animatediff_xl(self, state_dict): + param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight" + return param_name in state_dict + + def is_sd_lora(self, state_dict): + param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight" + return param_name in state_dict + + def is_translator(self, state_dict): + param_name = "model.encoder.layers.5.self_attn_layer_norm.weight" + return param_name in state_dict and len(state_dict) == 254 + + def is_ipadapter(self, state_dict): + return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024]) + + def is_ipadapter_image_encoder(self, state_dict): + param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight" + return param_name in state_dict and len(state_dict) == 521 + + def is_ipadapter_xl(self, state_dict): + return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280]) + + def is_ipadapter_xl_image_encoder(self, state_dict): + param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight" + return param_name in state_dict and len(state_dict) == 777 + + def is_hunyuan_dit_clip_text_encoder(self, state_dict): + param_name = "bert.encoder.layer.23.attention.output.dense.weight" + return param_name in state_dict + + def is_hunyuan_dit_t5_text_encoder(self, state_dict): + param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" + return param_name in state_dict + + def is_hunyuan_dit(self, state_dict): + param_name = "final_layer.adaLN_modulation.1.weight" + return param_name in state_dict + + def is_diffusers_vae(self, state_dict): + param_name = "quant_conv.weight" + return param_name in state_dict + + def is_ExVideo_StableVideoDiffusion(self, state_dict): + param_name = "blocks.185.positional_embedding.embeddings" + return param_name in state_dict + + def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None): + component_dict = { + "image_encoder": SVDImageEncoder, + "unet": SVDUNet, + "vae_decoder": SVDVAEDecoder, + "vae_encoder": SVDVAEEncoder, + } + if components is None: + components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"] + for component in components: + if component == "unet": + self.model[component] = component_dict[component](add_positional_conv=add_positional_conv) + 
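As an aside on the detection helpers above (is_stable_video_diffusion, is_RIFE, and friends): each one identifies a checkpoint purely by probing for a parameter name that only that architecture contains. A minimal standalone sketch of the idea, not part of the diff; the function name looks_like_sd15 is made up, but the key is the one the diff itself uses in is_stable_diffusion.

import torch

def looks_like_sd15(state_dict):
    # Same trick as the is_* methods: one architecture-specific key is enough.
    return "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight" in state_dict

checkpoint = {"model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight": torch.zeros(320)}
print(looks_like_sd15(checkpoint))                    # True
print(looks_like_sd15({"unrelated.weight": torch.zeros(1)}))  # False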
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False) + else: + self.model[component] = component_dict[component]() + self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict)) + self.model[component].to(self.torch_dtype).to(self.device) + self.model_path[component] = file_path + + def load_stable_diffusion(self, state_dict, components=None, file_path=""): + component_dict = { + "text_encoder": SDTextEncoder, + "unet": SDUNet, + "vae_decoder": SDVAEDecoder, + "vae_encoder": SDVAEEncoder, + "refiner": SDXLUNet, + } + if components is None: + components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"] + for component in components: + if component == "text_encoder": + # Add additional token embeddings to text encoder + token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]] + for keyword in self.textual_inversion_dict: + _, embeddings = self.textual_inversion_dict[keyword] + token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype)) + token_embeddings = torch.concat(token_embeddings, dim=0) + state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings + self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0]) + self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict)) + self.model[component].to(self.torch_dtype).to(self.device) + else: + self.model[component] = component_dict[component]() + self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict)) + self.model[component].to(self.torch_dtype).to(self.device) + self.model_path[component] = file_path + + def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""): + component_dict = { + "text_encoder": SDXLTextEncoder, + "text_encoder_2": SDXLTextEncoder2, + "unet": SDXLUNet, + "vae_decoder": SDXLVAEDecoder, + "vae_encoder": SDXLVAEEncoder, + } + if components is None: + components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"] + for component in components: + self.model[component] = component_dict[component]() + self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict)) + if component in ["vae_decoder", "vae_encoder"]: + # These two model will output nan when float16 is enabled. + # The precision problem happens in the last three resnet blocks. + # I do not know how to solve this problem. 
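Regarding the float16 problem noted in the comment above: the sketch below, not part of the diff, only illustrates the general hazard, namely that float16 saturates to inf above roughly 65504 while float32 does not, which is why these VAE components are kept in float32; it does not reproduce the specific failure inside the resnet blocks.

import torch

x = torch.tensor([70000.0, -70000.0])
print(x.to(torch.float16))  # tensor([inf, -inf], dtype=torch.float16) -- overflows
print(x.to(torch.float32))  # tensor([ 70000., -70000.]) -- fits comfortably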
+ self.model[component].to(torch.float32).to(self.device) + else: + self.model[component].to(self.torch_dtype).to(self.device) + self.model_path[component] = file_path + + def load_controlnet(self, state_dict, file_path=""): + component = "controlnet" + if component not in self.model: + self.model[component] = [] + self.model_path[component] = [] + model = SDControlNet() + model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component].append(model) + self.model_path[component].append(file_path) + + def load_animatediff(self, state_dict, file_path=""): + component = "motion_modules" + model = SDMotionModel() + model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + + def load_animatediff_xl(self, state_dict, file_path=""): + component = "motion_modules_xl" + model = SDXLMotionModel() + model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + + def load_beautiful_prompt(self, state_dict, file_path=""): + component = "beautiful_prompt" + from transformers import AutoModelForCausalLM + model_folder = os.path.dirname(file_path) + model = AutoModelForCausalLM.from_pretrained( + model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype + ).to(self.device).eval() + self.model[component] = model + self.model_path[component] = file_path + + def load_RIFE(self, state_dict, file_path=""): + component = "RIFE" + from ..extensions.RIFE import IFNet + model = IFNet().eval() + model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) + model.to(torch.float32).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + + def load_sd_lora(self, state_dict, alpha): + SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device) + SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device) + + def load_translator(self, state_dict, file_path=""): + # This model is lightweight, we do not place it on GPU. 
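Stepping back to load_sd_lora a few lines above: the heart of applying a LoRA is merging a low-rank update into an existing weight matrix. The standalone sketch below is not part of the diff; the shapes are arbitrary, the scaling convention (a plain alpha factor here) varies between implementations, and the real key matching is handled by SDLoRA.

import torch

out_dim, in_dim, rank, alpha = 320, 768, 4, 1.0
weight = torch.randn(out_dim, in_dim)       # original linear weight
lora_up = torch.randn(out_dim, rank)        # "lora_up" factor
lora_down = torch.randn(rank, in_dim)       # "lora_down" factor

# W' = W + alpha * (up @ down): a rank-`rank` correction to the full matrix.
merged = weight + alpha * (lora_up @ lora_down)
print(merged.shape)  # torch.Size([320, 768])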
+ component = "translator" + from transformers import AutoModelForSeq2SeqLM + model_folder = os.path.dirname(file_path) + model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval() + self.model[component] = model + self.model_path[component] = file_path + + def load_ipadapter(self, state_dict, file_path=""): + component = "ipadapter" + model = SDIpAdapter() + model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + + def load_ipadapter_image_encoder(self, state_dict, file_path=""): + component = "ipadapter_image_encoder" + model = IpAdapterCLIPImageEmbedder() + model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + + def load_ipadapter_xl(self, state_dict, file_path=""): + component = "ipadapter_xl" + model = SDXLIpAdapter() + model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + + def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""): + component = "ipadapter_xl_image_encoder" + model = IpAdapterXLCLIPImageEmbedder() + model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + + def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""): + component = "hunyuan_dit_clip_text_encoder" + model = HunyuanDiTCLIPTextEncoder() + model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + + def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""): + component = "hunyuan_dit_t5_text_encoder" + model = HunyuanDiTT5TextEncoder() + model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + + def load_hunyuan_dit(self, state_dict, file_path=""): + component = "hunyuan_dit" + model = HunyuanDiT() + model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + + def load_diffusers_vae(self, state_dict, file_path=""): + # TODO: detect SD and SDXL + component = "vae_encoder" + model = SDXLVAEEncoder() + model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + component = "vae_decoder" + model = SDXLVAEDecoder() + model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + + def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""): + unet_state_dict = self.model["unet"].state_dict() + self.model["unet"].to("cpu") + del self.model["unet"] + add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0] + self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv) + 
self.model["unet"].load_state_dict(unet_state_dict, strict=False) + self.model["unet"].load_state_dict(state_dict, strict=False) + self.model["unet"].to(self.torch_dtype).to(self.device) + + def search_for_embeddings(self, state_dict): + embeddings = [] + for k in state_dict: + if isinstance(state_dict[k], torch.Tensor): + embeddings.append(state_dict[k]) + elif isinstance(state_dict[k], dict): + embeddings += self.search_for_embeddings(state_dict[k]) + return embeddings + + def load_textual_inversions(self, folder): + # Store additional tokens here + self.textual_inversion_dict = {} + + # Load every textual inversion file + for file_name in os.listdir(folder): + if file_name.endswith(".txt"): + continue + keyword = os.path.splitext(file_name)[0] + state_dict = load_state_dict(os.path.join(folder, file_name)) + + # Search for embeddings + for embeddings in self.search_for_embeddings(state_dict): + if len(embeddings.shape) == 2 and embeddings.shape[1] == 768: + tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])] + self.textual_inversion_dict[keyword] = (tokens, embeddings) + break + + def load_model(self, file_path, components=None, lora_alphas=[]): + state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype) + if self.is_stable_video_diffusion(state_dict): + self.load_stable_video_diffusion(state_dict, file_path=file_path) + elif self.is_animatediff(state_dict): + self.load_animatediff(state_dict, file_path=file_path) + elif self.is_animatediff_xl(state_dict): + self.load_animatediff_xl(state_dict, file_path=file_path) + elif self.is_controlnet(state_dict): + self.load_controlnet(state_dict, file_path=file_path) + elif self.is_stabe_diffusion_xl(state_dict): + self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path) + elif self.is_stable_diffusion(state_dict): + self.load_stable_diffusion(state_dict, components=components, file_path=file_path) + elif self.is_sd_lora(state_dict): + self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0)) + elif self.is_beautiful_prompt(state_dict): + self.load_beautiful_prompt(state_dict, file_path=file_path) + elif self.is_RIFE(state_dict): + self.load_RIFE(state_dict, file_path=file_path) + elif self.is_translator(state_dict): + self.load_translator(state_dict, file_path=file_path) + elif self.is_ipadapter(state_dict): + self.load_ipadapter(state_dict, file_path=file_path) + elif self.is_ipadapter_image_encoder(state_dict): + self.load_ipadapter_image_encoder(state_dict, file_path=file_path) + elif self.is_ipadapter_xl(state_dict): + self.load_ipadapter_xl(state_dict, file_path=file_path) + elif self.is_ipadapter_xl_image_encoder(state_dict): + self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path) + elif self.is_hunyuan_dit_clip_text_encoder(state_dict): + self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path) + elif self.is_hunyuan_dit_t5_text_encoder(state_dict): + self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path) + elif self.is_hunyuan_dit(state_dict): + self.load_hunyuan_dit(state_dict, file_path=file_path) + elif self.is_diffusers_vae(state_dict): + self.load_diffusers_vae(state_dict, file_path=file_path) + elif self.is_ExVideo_StableVideoDiffusion(state_dict): + self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path) + + def load_models(self, file_path_list, lora_alphas=[]): + for file_path in file_path_list: + self.load_model(file_path, lora_alphas=lora_alphas) + + def to(self, device): + for component in self.model: + if 
isinstance(self.model[component], list): + for model in self.model[component]: + model.to(device) + else: + self.model[component].to(device) + torch.cuda.empty_cache() + + def get_model_with_model_path(self, model_path): + for component in self.model_path: + if isinstance(self.model_path[component], str): + if os.path.samefile(self.model_path[component], model_path): + return self.model[component] + elif isinstance(self.model_path[component], list): + for i, model_path_ in enumerate(self.model_path[component]): + if os.path.samefile(model_path_, model_path): + return self.model[component][i] + raise ValueError(f"Please load model {model_path} before you use it.") + + def __getattr__(self, __name): + if __name in self.model: + return self.model[__name] + else: + return super.__getattribute__(__name) + + +def load_state_dict(file_path, torch_dtype=None): + if file_path.endswith(".safetensors"): + return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) + else: + return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) + + +def load_state_dict_from_safetensors(file_path, torch_dtype=None): + state_dict = {} + with safe_open(file_path, framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + if torch_dtype is not None: + state_dict[k] = state_dict[k].to(torch_dtype) + return state_dict + + +def load_state_dict_from_bin(file_path, torch_dtype=None): + state_dict = torch.load(file_path, map_location="cpu") + if torch_dtype is not None: + for i in state_dict: + if isinstance(state_dict[i], torch.Tensor): + state_dict[i] = state_dict[i].to(torch_dtype) + return state_dict + + +def search_parameter(param, state_dict): + for name, param_ in state_dict.items(): + if param.numel() == param_.numel(): + if param.shape == param_.shape: + if torch.dist(param, param_) < 1e-6: + return name + else: + if torch.dist(param.flatten(), param_.flatten()) < 1e-6: + return name + return None + + +def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False): + matched_keys = set() + with torch.no_grad(): + for name in source_state_dict: + rename = search_parameter(source_state_dict[name], target_state_dict) + if rename is not None: + print(f'"{name}": "{rename}",') + matched_keys.add(rename) + elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0: + length = source_state_dict[name].shape[0] // 3 + rename = [] + for i in range(3): + rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict)) + if None not in rename: + print(f'"{name}": {rename},') + for rename_ in rename: + matched_keys.add(rename_) + for name in target_state_dict: + if name not in matched_keys: + print("Cannot find", name, target_state_dict[name].shape) diff --git a/diffsynth/models/attention.py b/diffsynth/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..eb90e1ed1a28a0541a8d9df8313997a7d3f14da7 --- /dev/null +++ b/diffsynth/models/attention.py @@ -0,0 +1,89 @@ +import torch +from einops import rearrange + + +def low_version_attention(query, key, value, attn_bias=None): + scale = 1 / query.shape[-1] ** 0.5 + query = query * scale + attn = torch.matmul(query, key.transpose(-2, -1)) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + return attn @ value + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + 
super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size = q.shape[0] + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if qkv_preprocessor is not None: + q, k, v = qkv_preprocessor(q, k, v) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + if ipadapter_kwargs is not None: + hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) + k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) + v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) + + if attn_mask is not None: + hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) + else: + import xformers.ops as xops + hidden_states = xops.memory_efficient_attention(q, k, v) + hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) + + hidden_states = hidden_states.to(q.dtype) + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor) \ No newline at end of file diff --git a/diffsynth/models/hunyuan_dit.py b/diffsynth/models/hunyuan_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..ca797e823cbe69ec67fd7a3d5519e850e4039a8c --- /dev/null +++ b/diffsynth/models/hunyuan_dit.py @@ -0,0 +1,451 @@ +from .attention import Attention +from .tiler import TileWorker +from einops import repeat, rearrange +import math +import 
torch + + +class HunyuanDiTRotaryEmbedding(torch.nn.Module): + + def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True): + super().__init__() + self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06) + self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06) + self.rotary_emb_on_k = rotary_emb_on_k + self.k_cache, self.v_cache = [], [] + + def reshape_for_broadcast(self, freqs_cis, x): + ndim = x.ndim + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + + def rotate_half(self, x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + def apply_rotary_emb(self, xq, xk, freqs_cis): + xk_out = None + cos, sin = self.reshape_for_broadcast(freqs_cis, xq) + cos, sin = cos.to(xq.device), sin.to(xq.device) + xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq) + if xk is not None: + xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk) + return xq_out, xk_out + + def forward(self, q, k, v, freqs_cis_img, to_cache=False): + # norm + q = self.q_norm(q) + k = self.k_norm(k) + + # RoPE + if self.rotary_emb_on_k: + q, k = self.apply_rotary_emb(q, k, freqs_cis_img) + else: + q, _ = self.apply_rotary_emb(q, None, freqs_cis_img) + + if to_cache: + self.k_cache.append(k) + self.v_cache.append(v) + elif len(self.k_cache) > 0 and len(self.v_cache) > 0: + k = torch.concat([k] + self.k_cache, dim=2) + v = torch.concat([v] + self.v_cache, dim=2) + self.k_cache, self.v_cache = [], [] + return q, k, v + + +class FP32_Layernorm(torch.nn.LayerNorm): + def forward(self, inputs): + origin_dtype = inputs.dtype + return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype) + + +class FP32_SiLU(torch.nn.SiLU): + def forward(self, inputs): + origin_dtype = inputs.dtype + return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype) + + +class HunyuanDiTFinalLayer(torch.nn.Module): + def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8): + super().__init__() + self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = torch.nn.Sequential( + FP32_SiLU(), + torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True) + ) + + def modulate(self, x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + def forward(self, hidden_states, condition_emb): + shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1) + hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale) + hidden_states = self.linear(hidden_states) + return hidden_states + + +class HunyuanDiTBlock(torch.nn.Module): + + def __init__( + self, + hidden_dim=1408, + condition_dim=1408, + num_heads=16, + mlp_ratio=4.3637, + text_dim=1024, + skip_connection=False + ): + super().__init__() + self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True) + self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads) + self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True) + self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, 
elementwise_affine=True) + self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False) + self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True) + self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True) + self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True)) + self.mlp = torch.nn.Sequential( + torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True) + ) + if skip_connection: + self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True) + self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True) + else: + self.skip_norm, self.skip_linear = None, None + + def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False): + # Long Skip Connection + if self.skip_norm is not None and self.skip_linear is not None: + hidden_states = torch.cat([hidden_states, residual], dim=-1) + hidden_states = self.skip_norm(hidden_states) + hidden_states = self.skip_linear(hidden_states) + + # Self-Attention + shift_msa = self.modulation(condition_emb).unsqueeze(dim=1) + attn_input = self.norm1(hidden_states) + shift_msa + hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache)) + + # Cross-Attention + attn_input = self.norm3(hidden_states) + hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img)) + + # FFN Layer + mlp_input = self.norm2(hidden_states) + hidden_states = hidden_states + self.mlp(mlp_input) + return hidden_states + + +class AttentionPool(torch.nn.Module): + def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None): + super().__init__() + self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = torch.nn.Linear(embed_dim, embed_dim) + self.q_proj = torch.nn.Linear(embed_dim, embed_dim) + self.v_proj = torch.nn.Linear(embed_dim, embed_dim) + self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.permute(1, 0, 2) # NLC -> LNC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = torch.nn.functional.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class PatchEmbed(torch.nn.Module): + def __init__( + self, + patch_size=(2, 2), + in_chans=4, + embed_dim=1408, + bias=True, + ): + super().__init__() + self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + + def forward(self, x): + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + 
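A standalone shape check, not part of the diff, for the patch embedding step above: a 2x2 convolution with stride 2 turns a (B, 4, H, W) latent into (B, H/2 * W/2, embed_dim) tokens. The 64x64 latent below is arbitrary; the channel and embedding sizes match the defaults in this file.

import torch

proj = torch.nn.Conv2d(4, 1408, kernel_size=(2, 2), stride=(2, 2))
latent = torch.randn(1, 4, 64, 64)

# BCHW -> BNC, exactly as in PatchEmbed.forward
tokens = proj(latent).flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 1024, 1408]) -- 1024 = (64/2) * (64/2)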
return x + + +def timestep_embedding(t, dim, max_period=10000, repeat_only=False): + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线 + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(t, "b -> b d", d=dim) + return embedding + + +class TimestepEmbedder(torch.nn.Module): + def __init__(self, hidden_size=1408, frequency_embedding_size=256): + super().__init__() + self.mlp = torch.nn.Sequential( + torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True), + torch.nn.SiLU(), + torch.nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class HunyuanDiT(torch.nn.Module): + def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256): + super().__init__() + + # Embedders + self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32)) + self.t5_embedder = torch.nn.Sequential( + torch.nn.Linear(t5_dim, t5_dim * 4, bias=True), + FP32_SiLU(), + torch.nn.Linear(t5_dim * 4, text_dim, bias=True), + ) + self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024) + self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim)) + self.patch_embedder = PatchEmbed(in_chans=in_channels) + self.timestep_embedder = TimestepEmbedder() + self.extra_embedder = torch.nn.Sequential( + torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4), + FP32_SiLU(), + torch.nn.Linear(hidden_dim * 4, hidden_dim), + ) + + # Transformer blocks + self.num_layers_down = num_layers_down + self.num_layers_up = num_layers_up + self.blocks = torch.nn.ModuleList( + [HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \ + [HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)] + ) + + # Output layers + self.final_layer = HunyuanDiTFinalLayer() + self.out_channels = out_channels + + def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5): + text_emb_mask = text_emb_mask.bool() + text_emb_mask_t5 = text_emb_mask_t5.bool() + text_emb_t5 = self.t5_embedder(text_emb_t5) + text_emb = torch.cat([text_emb, text_emb_t5], dim=1) + text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1) + text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb)) + return text_emb + + def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size): + # Text embedding + pooled_text_emb_t5 = self.t5_pooler(text_emb_t5) + + # Timestep embedding + timestep_emb = self.timestep_embedder(timestep) + + # Size embedding + size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype) + size_emb = size_emb.view(-1, 6 * 256) + + # Style embedding + style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size) + + # Concatenate all extra vectors + extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1) + 
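A standalone check, not part of the diff, of the dimension bookkeeping at this point: the concatenated vector is exactly what the first Linear of extra_embedder (256 * 6 + 1024 + hidden_dim input features) expects. The batch size is arbitrary; hidden_dim follows the 1408 default used in this file.

import torch

batch, hidden_dim = 2, 1408
pooled_text_emb_t5 = torch.randn(batch, 1024)      # output of the T5 attention pool
size_emb = torch.randn(batch, 6 * 256)             # six 256-d sinusoidal size embeddings
style_emb = torch.randn(batch, hidden_dim)         # learned style vector

extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
print(extra_emb.shape[1] == 256 * 6 + 1024 + hidden_dim)  # True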
condition_emb = timestep_emb + self.extra_embedder(extra_emb) + + return condition_emb + + def unpatchify(self, x, h, w): + return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2) + + def build_mask(self, data, is_bound): + _, _, H, W = data.shape + h = repeat(torch.arange(H), "H -> H W", H=H, W=W) + w = repeat(torch.arange(W), "W -> H W", H=H, W=W) + border_width = (H + W) // 4 + pad = torch.ones_like(h) * border_width + mask = torch.stack([ + pad if is_bound[0] else h + 1, + pad if is_bound[1] else H - h, + pad if is_bound[2] else w + 1, + pad if is_bound[3] else W - w + ]).min(dim=0).values + mask = mask.clip(1, border_width) + mask = (mask / border_width).to(dtype=data.dtype, device=data.device) + mask = rearrange(mask, "H W -> 1 H W") + return mask + + def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride): + B, C, H, W = hidden_states.shape + + weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device) + values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device) + + # Split tasks + tasks = [] + for h in range(0, H, tile_stride): + for w in range(0, W, tile_stride): + if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W): + continue + h_, w_ = h + tile_size, w + tile_size + if h_ > H: h, h_ = H - tile_size, H + if w_ > W: w, w_ = W - tile_size, W + tasks.append((h, h_, w, w_)) + + # Run + for hl, hr, wl, wr in tasks: + hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device) + hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C") + if residual is not None: + residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device) + residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C") + else: + residual_batch = None + + # Forward + hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device) + hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl) + + mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W)) + values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask + weight[:, :, hl:hr, wl:wr] += mask + values /= weight + return values + + def forward( + self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img, + tiled=False, tile_size=64, tile_stride=32, + to_cache=False, + use_gradient_checkpointing=False, + ): + # Embeddings + text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5) + condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0]) + + # Input + height, width = hidden_states.shape[-2], hidden_states.shape[-1] + hidden_states = self.patch_embedder(hidden_states) + + # Blocks + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + if tiled: + hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2) + residuals = [] + for block_id, block in enumerate(self.blocks): + residual = residuals.pop() if block_id >= self.num_layers_down else None + hidden_states = self.tiled_block_forward( + block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, + torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device, 
+ tile_size=tile_size, tile_stride=tile_stride + ) + if block_id < self.num_layers_down - 2: + residuals.append(hidden_states) + hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C") + else: + residuals = [] + for block_id, block in enumerate(self.blocks): + residual = residuals.pop() if block_id >= self.num_layers_down else None + if self.training and use_gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, condition_emb, text_emb, freq_cis_img, residual, + use_reentrant=False, + ) + else: + hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache) + if block_id < self.num_layers_down - 2: + residuals.append(hidden_states) + + # Output + hidden_states = self.final_layer(hidden_states, condition_emb) + hidden_states = self.unpatchify(hidden_states, height//2, width//2) + hidden_states, _ = hidden_states.chunk(2, dim=1) + return hidden_states + + def state_dict_converter(self): + return HunyuanDiTStateDictConverter() + + + +class HunyuanDiTStateDictConverter(): + def __init__(self): + pass + + def from_diffusers(self, state_dict): + state_dict_ = {} + for name, param in state_dict.items(): + name_ = name + name_ = name_.replace(".default_modulation.", ".modulation.") + name_ = name_.replace(".mlp.fc1.", ".mlp.0.") + name_ = name_.replace(".mlp.fc2.", ".mlp.2.") + name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.") + name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.") + name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.") + name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.") + name_ = name_.replace(".q_proj.", ".to_q.") + name_ = name_.replace(".out_proj.", ".to_out.") + name_ = name_.replace("text_embedding_padding", "text_emb_padding") + name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.") + name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.") + name_ = name_.replace("pooler.", "t5_pooler.") + name_ = name_.replace("x_embedder.", "patch_embedder.") + name_ = name_.replace("t_embedder.", "timestep_embedder.") + name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.") + name_ = name_.replace("style_embedder.weight", "style_embedder") + if ".kv_proj." in name_: + param_k = param[:param.shape[0]//2] + param_v = param[param.shape[0]//2:] + state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k + state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v + elif ".Wqkv." 
in name_: + param_q = param[:param.shape[0]//3] + param_k = param[param.shape[0]//3:param.shape[0]//3*2] + param_v = param[param.shape[0]//3*2:] + state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q + state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k + state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v + elif "style_embedder" in name_: + state_dict_[name_] = param.squeeze() + else: + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/diffsynth/models/hunyuan_dit_text_encoder.py b/diffsynth/models/hunyuan_dit_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..efa5d39148590352b2d4919f6bb31f070e40a8f5 --- /dev/null +++ b/diffsynth/models/hunyuan_dit_text_encoder.py @@ -0,0 +1,161 @@ +from transformers import BertModel, BertConfig, T5EncoderModel, T5Config +import torch + + + +class HunyuanDiTCLIPTextEncoder(BertModel): + def __init__(self): + config = BertConfig( + _name_or_path = "", + architectures = ["BertModel"], + attention_probs_dropout_prob = 0.1, + bos_token_id = 0, + classifier_dropout = None, + directionality = "bidi", + eos_token_id = 2, + hidden_act = "gelu", + hidden_dropout_prob = 0.1, + hidden_size = 1024, + initializer_range = 0.02, + intermediate_size = 4096, + layer_norm_eps = 1e-12, + max_position_embeddings = 512, + model_type = "bert", + num_attention_heads = 16, + num_hidden_layers = 24, + output_past = True, + pad_token_id = 0, + pooler_fc_size = 768, + pooler_num_attention_heads = 12, + pooler_num_fc_layers = 3, + pooler_size_per_head = 128, + pooler_type = "first_token_transform", + position_embedding_type = "absolute", + torch_dtype = "float32", + transformers_version = "4.37.2", + type_vocab_size = 2, + use_cache = True, + vocab_size = 47020 + ) + super().__init__(config, add_pooling_layer=False) + self.eval() + + def forward(self, input_ids, attention_mask, clip_skip=1): + input_shape = input_ids.size() + + batch_size, seq_length = input_shape + device = input_ids.device + + past_key_values_length = 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=None, + token_type_ids=None, + inputs_embeds=None, + past_key_values_length=0, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=False, + output_attentions=False, + output_hidden_states=True, + return_dict=True, + ) + all_hidden_states = encoder_outputs.hidden_states + prompt_emb = all_hidden_states[-clip_skip] + if clip_skip > 1: + mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std() + prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean + return prompt_emb + + def state_dict_converter(self): + return HunyuanDiTCLIPTextEncoderStateDictConverter() + + + +class HunyuanDiTT5TextEncoder(T5EncoderModel): + def __init__(self): + config = T5Config( + _name_or_path = "../HunyuanDiT/t2i/mt5", + architectures = ["MT5ForConditionalGeneration"], + classifier_dropout = 0.0, + d_ff = 5120, + d_kv = 64, + d_model = 2048, + decoder_start_token_id = 0, + dense_act_fn = "gelu_new", + dropout_rate = 0.1, + eos_token_id = 1, + 
feed_forward_proj = "gated-gelu", + initializer_factor = 1.0, + is_encoder_decoder = True, + is_gated_act = True, + layer_norm_epsilon = 1e-06, + model_type = "t5", + num_decoder_layers = 24, + num_heads = 32, + num_layers = 24, + output_past = True, + pad_token_id = 0, + relative_attention_max_distance = 128, + relative_attention_num_buckets = 32, + tie_word_embeddings = False, + tokenizer_class = "T5Tokenizer", + transformers_version = "4.37.2", + use_cache = True, + vocab_size = 250112 + ) + super().__init__(config) + self.eval() + + def forward(self, input_ids, attention_mask, clip_skip=1): + outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + prompt_emb = outputs.hidden_states[-clip_skip] + if clip_skip > 1: + mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std() + prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean + return prompt_emb + + def state_dict_converter(self): + return HunyuanDiTT5TextEncoderStateDictConverter() + + + +class HunyuanDiTCLIPTextEncoderStateDictConverter(): + def __init__(self): + pass + + def from_diffusers(self, state_dict): + state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")} + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) + + +class HunyuanDiTT5TextEncoderStateDictConverter(): + def __init__(self): + pass + + def from_diffusers(self, state_dict): + state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")} + state_dict_["shared.weight"] = state_dict["shared.weight"] + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/diffsynth/models/sd_controlnet.py b/diffsynth/models/sd_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..15a483860999e089e0b7f6a4170252bbd21b7903 --- /dev/null +++ b/diffsynth/models/sd_controlnet.py @@ -0,0 +1,587 @@ +import torch +from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler +from .tiler import TileWorker + + +class ControlNetConditioningLayer(torch.nn.Module): + def __init__(self, channels = (3, 16, 32, 96, 256, 320)): + super().__init__() + self.blocks = torch.nn.ModuleList([]) + self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1)) + self.blocks.append(torch.nn.SiLU()) + for i in range(1, len(channels) - 2): + self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1)) + self.blocks.append(torch.nn.SiLU()) + self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2)) + self.blocks.append(torch.nn.SiLU()) + self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1)) + + def forward(self, conditioning): + for block in self.blocks: + conditioning = block(conditioning) + return conditioning + + +class SDControlNet(torch.nn.Module): + def __init__(self, global_pool=False): + super().__init__() + self.time_proj = Timesteps(320) + self.time_embedding = torch.nn.Sequential( + torch.nn.Linear(320, 1280), + torch.nn.SiLU(), + torch.nn.Linear(1280, 1280) + ) + self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1) + + self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320)) + + self.blocks = torch.nn.ModuleList([ + # CrossAttnDownBlock2D + ResnetBlock(320, 320, 1280), + AttentionBlock(8, 40, 
320, 1, 768), + PushBlock(), + ResnetBlock(320, 320, 1280), + AttentionBlock(8, 40, 320, 1, 768), + PushBlock(), + DownSampler(320), + PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(320, 640, 1280), + AttentionBlock(8, 80, 640, 1, 768), + PushBlock(), + ResnetBlock(640, 640, 1280), + AttentionBlock(8, 80, 640, 1, 768), + PushBlock(), + DownSampler(640), + PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(640, 1280, 1280), + AttentionBlock(8, 160, 1280, 1, 768), + PushBlock(), + ResnetBlock(1280, 1280, 1280), + AttentionBlock(8, 160, 1280, 1, 768), + PushBlock(), + DownSampler(1280), + PushBlock(), + # DownBlock2D + ResnetBlock(1280, 1280, 1280), + PushBlock(), + ResnetBlock(1280, 1280, 1280), + PushBlock(), + # UNetMidBlock2DCrossAttn + ResnetBlock(1280, 1280, 1280), + AttentionBlock(8, 160, 1280, 1, 768), + ResnetBlock(1280, 1280, 1280), + PushBlock() + ]) + + self.controlnet_blocks = torch.nn.ModuleList([ + torch.nn.Conv2d(320, 320, kernel_size=(1, 1)), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1)), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), + ]) + + self.global_pool = global_pool + + def forward( + self, + sample, timestep, encoder_hidden_states, conditioning, + tiled=False, tile_size=64, tile_stride=32, + ): + # 1. time + time_emb = self.time_proj(timestep[None]).to(sample.dtype) + time_emb = self.time_embedding(time_emb) + time_emb = time_emb.repeat(sample.shape[0], 1) + + # 2. pre-process + height, width = sample.shape[2], sample.shape[3] + hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning) + text_emb = encoder_hidden_states + res_stack = [hidden_states] + + # 3. blocks + for i, block in enumerate(self.blocks): + if tiled and not isinstance(block, PushBlock): + _, _, inter_height, _ = hidden_states.shape + resize_scale = inter_height / height + hidden_states = TileWorker().tiled_forward( + lambda x: block(x, time_emb, text_emb, res_stack)[0], + hidden_states, + int(tile_size * resize_scale), + int(tile_stride * resize_scale), + tile_device=hidden_states.device, + tile_dtype=hidden_states.dtype + ) + else: + hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack) + + # 4. 
ControlNet blocks + controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)] + + # pool + if self.global_pool: + controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack] + + return controlnet_res_stack + + def state_dict_converter(self): + return SDControlNetStateDictConverter() + + +class SDControlNetStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + # architecture + block_types = [ + 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock', + 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock', + 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock', + 'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock', + 'ResnetBlock', 'AttentionBlock', 'ResnetBlock', + 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler', + 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler', + 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler', + 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock' + ] + + # controlnet_rename_dict + controlnet_rename_dict = { + "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight", + "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias", + "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight", + "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias", + "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight", + "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias", + "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight", + "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias", + "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight", + "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias", + "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight", + "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias", + "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight", + "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias", + "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight", + "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias", + } + + # Rename each parameter + name_list = sorted([name for name in state_dict]) + rename_dict = {} + block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1} + last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""} + for name in name_list: + names = name.split(".") + if names[0] in ["conv_in", "conv_norm_out", "conv_out"]: + pass + elif name in controlnet_rename_dict: + names = controlnet_rename_dict[name].split(".") + elif names[0] == "controlnet_down_blocks": + names[0] = "controlnet_blocks" + elif names[0] == "controlnet_mid_block": + names = 
["controlnet_blocks", "12", names[-1]] + elif names[0] in ["time_embedding", "add_embedding"]: + if names[0] == "add_embedding": + names[0] = "add_time_embedding" + names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]] + elif names[0] in ["down_blocks", "mid_block", "up_blocks"]: + if names[0] == "mid_block": + names.insert(1, "0") + block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]] + block_type_with_id = ".".join(names[:4]) + if block_type_with_id != last_block_type_with_id[block_type]: + block_id[block_type] += 1 + last_block_type_with_id[block_type] = block_type_with_id + while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: + block_id[block_type] += 1 + block_type_with_id = ".".join(names[:4]) + names = ["blocks", str(block_id[block_type])] + names[4:] + if "ff" in names: + ff_index = names.index("ff") + component = ".".join(names[ff_index:ff_index+3]) + component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component] + names = names[:ff_index] + [component] + names[ff_index+3:] + if "to_out" in names: + names.pop(names.index("to_out") + 1) + else: + raise ValueError(f"Unknown parameters: {name}") + rename_dict[name] = ".".join(names) + + # Convert state_dict + state_dict_ = {} + for name, param in state_dict.items(): + if ".proj_in." in name or ".proj_out." in name: + param = param.squeeze() + if rename_dict[name] in [ + "controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias", + "controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias" + ]: + continue + state_dict_[rename_dict[name]] = param + return state_dict_ + + def from_civitai(self, state_dict): + if "mid_block.resnets.1.time_emb_proj.weight" in state_dict: + # For controlnets in diffusers format + return self.from_diffusers(state_dict) + rename_dict = { + "control_model.time_embed.0.weight": "time_embedding.0.weight", + "control_model.time_embed.0.bias": "time_embedding.0.bias", + "control_model.time_embed.2.weight": "time_embedding.2.weight", + "control_model.time_embed.2.bias": "time_embedding.2.bias", + "control_model.input_blocks.0.0.weight": "conv_in.weight", + "control_model.input_blocks.0.0.bias": "conv_in.bias", + "control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight", + "control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias", + "control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight", + "control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias", + "control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight", + "control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias", + "control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight", + "control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias", + "control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight", + "control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias", + "control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight", + "control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias", + "control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight", + "control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias", + 
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias", + "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias", + "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias", + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias", + "control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias", + "control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias", + "control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight", + "control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias", + "control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight", + "control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias", + "control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight", + "control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias", + "control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight", + "control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias", + "control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight", + "control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias", + "control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight", + "control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias", + "control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight", + "control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias", + "control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight", + "control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias", + 
"control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight", + "control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias", + "control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight", + "control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias", + "control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight", + "control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias", + "control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight", + "control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias", + "control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight", + "control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias", + "control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight", + "control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias", + "control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight", + "control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias", + "control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight", + "control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias", + 
"control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight", + "control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias", + "control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight", + "control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias", + "control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight", + "control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias", + "control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight", + "control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias", + "control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight", + "control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias", + "control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight", + "control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias", + "control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight", + "control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias", + "control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight", + "control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias", + 
"control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight", + "control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias", + "control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight", + "control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias", + "control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight", + "control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias", + "control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight", + "control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias", + "control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight", + "control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias", + "control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight", + "control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias", + "control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight", + "control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias", + "control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight", + "control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias", + 
"control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight", + "control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias", + "control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight", + "control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias", + "control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight", + "control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias", + "control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight", + "control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias", + "control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight", + "control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias", + "control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight", + "control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias", + "control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight", + "control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias", + "control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight", + "control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias", + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias", + "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias", + "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias", + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias", + "control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias", + "control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias", + 
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight", + "control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias", + "control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight", + "control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias", + "control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight", + "control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias", + "control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight", + "control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias", + "control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight", + "control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias", + "control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight", + "control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias", + "control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight", + "control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias", + "control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight", + "control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias", + "control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight", + "control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias", + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias", + "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias", + "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias", + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias", + "control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias", + 
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias", + "control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight", + "control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias", + "control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight", + "control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias", + "control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight", + "control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias", + "control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight", + "control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias", + "control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight", + "control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias", + "control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight", + "control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias", + "control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight", + "control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias", + "control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight", + "control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias", + "control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight", + "control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias", + "control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight", + "control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias", + "control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight", + "control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias", + "control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight", + "control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias", + "control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight", + "control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias", + "control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight", + "control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias", + "control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight", + "control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias", + "control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight", + "control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias", + "control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight", + "control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias", + "control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight", + "control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias", + "control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight", + "control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias", + "control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight", + "control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias", + "control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight", + "control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias", + "control_model.zero_convs.8.0.weight": 
"controlnet_blocks.8.weight", + "control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias", + "control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight", + "control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias", + "control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight", + "control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias", + "control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight", + "control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias", + "control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight", + "control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias", + "control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight", + "control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias", + "control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight", + "control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias", + "control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight", + "control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias", + "control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight", + "control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias", + "control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight", + "control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias", + "control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight", + "control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias", + "control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight", + "control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias", + "control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight", + "control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias", + "control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight", + "control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias", + "control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight", + "control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias", + "control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight", + "control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias", + "control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight", + "control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias", + "control_model.middle_block.1.norm.weight": "blocks.29.norm.weight", + "control_model.middle_block.1.norm.bias": "blocks.29.norm.bias", + "control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight", + "control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias", + "control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight", + "control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight", + "control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight", + "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight", + "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias", + 
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight", + "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias", + "control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight", + "control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias", + "control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight", + "control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight", + "control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight", + "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight", + "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias", + "control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight", + "control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias", + "control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight", + "control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias", + "control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight", + "control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias", + "control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight", + "control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias", + "control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight", + "control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias", + "control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight", + "control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias", + "control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight", + "control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias", + "control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight", + "control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias", + "control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight", + "control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias", + "control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight", + "control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if ".proj_in." in name or ".proj_out." 
in name: + param = param.squeeze() + state_dict_[rename_dict[name]] = param + return state_dict_ diff --git a/diffsynth/models/sd_ipadapter.py b/diffsynth/models/sd_ipadapter.py new file mode 100644 index 0000000000000000000000000000000000000000..723d53df6fd190f521f75c9b14a4060950f18f37 --- /dev/null +++ b/diffsynth/models/sd_ipadapter.py @@ -0,0 +1,56 @@ +from .svd_image_encoder import SVDImageEncoder +from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter +from transformers import CLIPImageProcessor +import torch + + +class IpAdapterCLIPImageEmbedder(SVDImageEncoder): + def __init__(self): + super().__init__() + self.image_processor = CLIPImageProcessor() + + def forward(self, image): + pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values + pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype) + return super().forward(pixel_values) + + +class SDIpAdapter(torch.nn.Module): + def __init__(self): + super().__init__() + shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1 + self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list]) + self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4) + self.set_full_adapter() + + def set_full_adapter(self): + block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29] + self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)} + + def set_less_adapter(self): + # IP-Adapter for SD v1.5 doesn't support this feature. + self.set_full_adapter(self) + + def forward(self, hidden_states, scale=1.0): + hidden_states = self.image_proj(hidden_states) + hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1]) + ip_kv_dict = {} + for (block_id, transformer_id) in self.call_block_id: + ipadapter_id = self.call_block_id[(block_id, transformer_id)] + ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states) + if block_id not in ip_kv_dict: + ip_kv_dict[block_id] = {} + ip_kv_dict[block_id][transformer_id] = { + "ip_k": ip_k, + "ip_v": ip_v, + "scale": scale + } + return ip_kv_dict + + def state_dict_converter(self): + return SDIpAdapterStateDictConverter() + + +class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter): + def __init__(self): + pass diff --git a/diffsynth/models/sd_lora.py b/diffsynth/models/sd_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..3b7ecacc3f19040152ef4036988082eff35d8f4c --- /dev/null +++ b/diffsynth/models/sd_lora.py @@ -0,0 +1,60 @@ +import torch +from .sd_unet import SDUNetStateDictConverter, SDUNet +from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder + + +class SDLoRA: + def __init__(self): + pass + + def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"): + special_keys = { + "down.blocks": "down_blocks", + "up.blocks": "up_blocks", + "mid.block": "mid_block", + "proj.in": "proj_in", + "proj.out": "proj_out", + "transformer.blocks": "transformer_blocks", + "to.q": "to_q", + "to.k": "to_k", + "to.v": "to_v", + "to.out": "to_out", + } + state_dict_ = {} + for key in state_dict: + if ".lora_up" not in key: + continue + if not key.startswith(lora_prefix): + continue + weight_up = state_dict[key].to(device="cuda", dtype=torch.float16) + weight_down = state_dict[key.replace(".lora_up", 
".lora_down")].to(device="cuda", dtype=torch.float16) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32) + weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32) + lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + lora_weight = alpha * torch.mm(weight_up, weight_down) + target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight" + for special_key in special_keys: + target_name = target_name.replace(special_key, special_keys[special_key]) + state_dict_[target_name] = lora_weight.cpu() + return state_dict_ + + def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"): + state_dict_unet = unet.state_dict() + state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device) + state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora) + if len(state_dict_lora) > 0: + for name in state_dict_lora: + state_dict_unet[name] += state_dict_lora[name].to(device=device) + unet.load_state_dict(state_dict_unet) + + def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"): + state_dict_text_encoder = text_encoder.state_dict() + state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device) + state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora) + if len(state_dict_lora) > 0: + for name in state_dict_lora: + state_dict_text_encoder[name] += state_dict_lora[name].to(device=device) + text_encoder.load_state_dict(state_dict_text_encoder) + diff --git a/diffsynth/models/sd_motion.py b/diffsynth/models/sd_motion.py new file mode 100644 index 0000000000000000000000000000000000000000..b313e62c8c6382ffbe5d30a2601583191209f847 --- /dev/null +++ b/diffsynth/models/sd_motion.py @@ -0,0 +1,198 @@ +from .sd_unet import SDUNet, Attention, GEGLU +import torch +from einops import rearrange, repeat + + +class TemporalTransformerBlock(torch.nn.Module): + + def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32): + super().__init__() + + # 1. Self-Attn + self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim)) + self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True) + self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) + + # 2. Cross-Attn + self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim)) + self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True) + self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) + + # 3. Feed-forward + self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True) + self.act_fn = GEGLU(dim, dim * 4) + self.ff = torch.nn.Linear(dim * 4, dim) + + + def forward(self, hidden_states, batch_size=1): + + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) + attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]]) + attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) + hidden_states = attn_output + hidden_states + + # 2. 
Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) + attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]]) + attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.act_fn(norm_hidden_states) + ff_output = self.ff(ff_output) + hidden_states = ff_output + hidden_states + + return hidden_states + + +class TemporalBlock(torch.nn.Module): + + def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) + self.proj_in = torch.nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = torch.nn.ModuleList([ + TemporalTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim + ) + for d in range(num_layers) + ]) + + self.proj_out = torch.nn.Linear(inner_dim, in_channels) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1): + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + batch_size=batch_size + ) + + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = hidden_states + residual + + return hidden_states, time_emb, text_emb, res_stack + + +class SDMotionModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.motion_modules = torch.nn.ModuleList([ + TemporalBlock(8, 40, 320, eps=1e-6), + TemporalBlock(8, 40, 320, eps=1e-6), + TemporalBlock(8, 80, 640, eps=1e-6), + TemporalBlock(8, 80, 640, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 160, 1280, eps=1e-6), + TemporalBlock(8, 80, 640, eps=1e-6), + TemporalBlock(8, 80, 640, eps=1e-6), + TemporalBlock(8, 80, 640, eps=1e-6), + TemporalBlock(8, 40, 320, eps=1e-6), + TemporalBlock(8, 40, 320, eps=1e-6), + TemporalBlock(8, 40, 320, eps=1e-6), + ]) + self.call_block_id = { + 1: 0, + 4: 1, + 9: 2, + 12: 3, + 17: 4, + 20: 5, + 24: 6, + 26: 7, + 29: 8, + 32: 9, + 34: 10, + 36: 11, + 40: 12, + 43: 13, + 46: 14, + 50: 15, + 53: 16, + 56: 17, + 60: 18, + 63: 19, + 66: 20 + } + + def forward(self): + pass + + def state_dict_converter(self): + return SDMotionModelStateDictConverter() + + +class SDMotionModelStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "norm": "norm", + "proj_in": "proj_in", + "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q", 
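+            # AnimateDiff-style motion-module names -> this module's fields:
+            # attention_blocks.0/1 become attn1/attn2, pos_encoder becomes pe1/pe2,
+            # ff.net.0.proj/ff.net.2 become act_fn.proj/ff, and ff_norm becomes norm3.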
+ "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k", + "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v", + "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out", + "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1", + "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q", + "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k", + "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v", + "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out", + "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2", + "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1", + "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2", + "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj", + "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff", + "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3", + "proj_out": "proj_out", + } + name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")]) + name_list += sorted([i for i in state_dict if i.startswith("mid_block.")]) + name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")]) + state_dict_ = {} + last_prefix, module_id = "", -1 + for name in name_list: + names = name.split(".") + prefix_index = names.index("temporal_transformer") + 1 + prefix = ".".join(names[:prefix_index]) + if prefix != last_prefix: + last_prefix = prefix + module_id += 1 + middle_name = ".".join(names[prefix_index:-1]) + suffix = names[-1] + if "pos_encoder" in names: + rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]]) + else: + rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix]) + state_dict_[rename] = state_dict[name] + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/diffsynth/models/sd_text_encoder.py b/diffsynth/models/sd_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c58e47710a5661d08920169f1d9849d39d45e73e --- /dev/null +++ b/diffsynth/models/sd_text_encoder.py @@ -0,0 +1,320 @@ +import torch +from .attention import Attention + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + 
hidden_states = residual + hidden_states + + return hidden_states + + +class SDTextEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=1): + embeds = self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + break + embeds = self.final_layer_norm(embeds) + return embeds + + def state_dict_converter(self): + return SDTextEncoderStateDictConverter() + + +class SDTextEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight", + 
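A minimal encoding sketch for the SDTextEncoder defined above. The transformers CLIPTokenizer and the "openai/clip-vit-large-patch14" tokenizer name are assumptions (any CLIP ViT-L/14 tokenizer that pads to 77 tokens works); the encoder weights stay random until a converted state dict is loaded.

import torch
from transformers import CLIPTokenizer  # assumed tokenizer dependency
from diffsynth.models.sd_text_encoder import SDTextEncoder

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = SDTextEncoder().eval()

input_ids = tokenizer("a photo of a cat", padding="max_length", max_length=77, return_tensors="pt").input_ids
with torch.no_grad():
    prompt_emb = text_encoder(input_ids)                    # (1, 77, 768); all 12 layers, then final LayerNorm
    prompt_emb_skip = text_encoder(input_ids, clip_skip=2)  # stops after layer 11; final LayerNorm is still applied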
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": 
"encoders.10.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight", + "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias", + "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight", + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds" + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + 
state_dict_[rename_dict[name]] = param + return state_dict_ diff --git a/diffsynth/models/sd_unet.py b/diffsynth/models/sd_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..8a8b17e76451de27cf2933f5cdd2cfd82eb87977 --- /dev/null +++ b/diffsynth/models/sd_unet.py @@ -0,0 +1,1107 @@ +import torch, math +from .attention import Attention +from .tiler import TileWorker + + +class Timesteps(torch.nn.Module): + def __init__(self, num_channels): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps): + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) / half_dim + timesteps = timesteps.unsqueeze(-1) + emb = timesteps.float() * torch.exp(exponent) + emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1) + return emb + + +class GEGLU(torch.nn.Module): + + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = torch.nn.Linear(dim_in, dim_out * 2) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * torch.nn.functional.gelu(gate) + + +class BasicTransformerBlock(torch.nn.Module): + + def __init__(self, dim, num_attention_heads, attention_head_dim, cross_attention_dim): + super().__init__() + + # 1. Self-Attn + self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True) + self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) + + # 2. Cross-Attn + self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True) + self.attn2 = Attention(q_dim=dim, kv_dim=cross_attention_dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) + + # 3. Feed-forward + self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True) + self.act_fn = GEGLU(dim, dim * 4) + self.ff = torch.nn.Linear(dim * 4, dim) + + + def forward(self, hidden_states, encoder_hidden_states, ipadapter_kwargs=None): + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, ipadapter_kwargs=ipadapter_kwargs) + hidden_states = attn_output + hidden_states + + # 3. 
Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.act_fn(norm_hidden_states) + ff_output = self.ff(ff_output) + hidden_states = ff_output + hidden_states + + return hidden_states + + +class DownSampler(torch.nn.Module): + def __init__(self, channels, padding=1, extra_padding=False): + super().__init__() + self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding) + self.extra_padding = extra_padding + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + if self.extra_padding: + hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0) + hidden_states = self.conv(hidden_states) + return hidden_states, time_emb, text_emb, res_stack + + +class UpSampler(torch.nn.Module): + def __init__(self, channels): + super().__init__() + self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states, time_emb, text_emb, res_stack + + +class ResnetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5): + super().__init__() + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.nonlinearity = torch.nn.SiLU() + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + x = hidden_states + x = self.norm1(x) + x = self.nonlinearity(x) + x = self.conv1(x) + if time_emb is not None: + emb = self.nonlinearity(time_emb) + emb = self.time_emb_proj(emb)[:, :, None, None] + x = x + emb + x = self.norm2(x) + x = self.nonlinearity(x) + x = self.conv2(x) + if self.conv_shortcut is not None: + hidden_states = self.conv_shortcut(hidden_states) + hidden_states = hidden_states + x + return hidden_states, time_emb, text_emb, res_stack + + +class AttentionBlock(torch.nn.Module): + + def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, cross_attention_dim=None, norm_num_groups=32, eps=1e-5, need_proj_out=True): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) + self.proj_in = torch.nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = torch.nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim + ) + for d in range(num_layers) + ]) + self.need_proj_out = need_proj_out + if need_proj_out: + self.proj_out = torch.nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states, time_emb, text_emb, res_stack, + cross_frame_attention=False, + tiled=False, tile_size=64, tile_stride=32, + ipadapter_kwargs_list={}, + **kwargs + 
): + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + if cross_frame_attention: + hidden_states = hidden_states.reshape(1, batch * height * width, inner_dim) + encoder_hidden_states = text_emb.mean(dim=0, keepdim=True) + else: + encoder_hidden_states = text_emb + if encoder_hidden_states.shape[0] != hidden_states.shape[0]: + encoder_hidden_states = encoder_hidden_states.repeat(hidden_states.shape[0], 1, 1) + + if tiled: + tile_size = min(tile_size, min(height, width)) + hidden_states = hidden_states.permute(0, 2, 1).reshape(batch, inner_dim, height, width) + def block_tile_forward(x): + b, c, h, w = x.shape + x = x.permute(0, 2, 3, 1).reshape(b, h*w, c) + x = block(x, encoder_hidden_states) + x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) + return x + for block in self.transformer_blocks: + hidden_states = TileWorker().tiled_forward( + block_tile_forward, + hidden_states, + tile_size, + tile_stride, + tile_device=hidden_states.device, + tile_dtype=hidden_states.dtype + ) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + for block_id, block in enumerate(self.transformer_blocks): + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ipadapter_kwargs=ipadapter_kwargs_list.get(block_id, None) + ) + if cross_frame_attention: + hidden_states = hidden_states.reshape(batch, height * width, inner_dim) + + if self.need_proj_out: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = hidden_states + residual + else: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + return hidden_states, time_emb, text_emb, res_stack + + +class PushBlock(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + res_stack.append(hidden_states) + return hidden_states, time_emb, text_emb, res_stack + + +class PopBlock(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + res_hidden_states = res_stack.pop() + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + return hidden_states, time_emb, text_emb, res_stack + + +class SDUNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.time_proj = Timesteps(320) + self.time_embedding = torch.nn.Sequential( + torch.nn.Linear(320, 1280), + torch.nn.SiLU(), + torch.nn.Linear(1280, 1280) + ) + self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1) + + self.blocks = torch.nn.ModuleList([ + # CrossAttnDownBlock2D + ResnetBlock(320, 320, 1280), + AttentionBlock(8, 40, 320, 1, 768, eps=1e-6), + PushBlock(), + ResnetBlock(320, 320, 1280), + AttentionBlock(8, 40, 320, 1, 768, eps=1e-6), + PushBlock(), + DownSampler(320), + PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(320, 640, 1280), + AttentionBlock(8, 80, 640, 1, 768, eps=1e-6), + PushBlock(), + ResnetBlock(640, 640, 1280), + AttentionBlock(8, 80, 640, 1, 768, eps=1e-6), + PushBlock(), + DownSampler(640), + PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(640, 1280, 1280), + AttentionBlock(8, 160, 1280, 1, 
768, eps=1e-6), + PushBlock(), + ResnetBlock(1280, 1280, 1280), + AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6), + PushBlock(), + DownSampler(1280), + PushBlock(), + # DownBlock2D + ResnetBlock(1280, 1280, 1280), + PushBlock(), + ResnetBlock(1280, 1280, 1280), + PushBlock(), + # UNetMidBlock2DCrossAttn + ResnetBlock(1280, 1280, 1280), + AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6), + ResnetBlock(1280, 1280, 1280), + # UpBlock2D + PopBlock(), + ResnetBlock(2560, 1280, 1280), + PopBlock(), + ResnetBlock(2560, 1280, 1280), + PopBlock(), + ResnetBlock(2560, 1280, 1280), + UpSampler(1280), + # CrossAttnUpBlock2D + PopBlock(), + ResnetBlock(2560, 1280, 1280), + AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6), + PopBlock(), + ResnetBlock(2560, 1280, 1280), + AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6), + PopBlock(), + ResnetBlock(1920, 1280, 1280), + AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6), + UpSampler(1280), + # CrossAttnUpBlock2D + PopBlock(), + ResnetBlock(1920, 640, 1280), + AttentionBlock(8, 80, 640, 1, 768, eps=1e-6), + PopBlock(), + ResnetBlock(1280, 640, 1280), + AttentionBlock(8, 80, 640, 1, 768, eps=1e-6), + PopBlock(), + ResnetBlock(960, 640, 1280), + AttentionBlock(8, 80, 640, 1, 768, eps=1e-6), + UpSampler(640), + # CrossAttnUpBlock2D + PopBlock(), + ResnetBlock(960, 320, 1280), + AttentionBlock(8, 40, 320, 1, 768, eps=1e-6), + PopBlock(), + ResnetBlock(640, 320, 1280), + AttentionBlock(8, 40, 320, 1, 768, eps=1e-6), + PopBlock(), + ResnetBlock(640, 320, 1280), + AttentionBlock(8, 40, 320, 1, 768, eps=1e-6), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1) + + def forward(self, sample, timestep, encoder_hidden_states, **kwargs): + # 1. time + time_emb = self.time_proj(timestep[None]).to(sample.dtype) + time_emb = self.time_embedding(time_emb) + + # 2. pre-process + hidden_states = self.conv_in(sample) + text_emb = encoder_hidden_states + res_stack = [hidden_states] + + # 3. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 4. 
output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + def state_dict_converter(self): + return SDUNetStateDictConverter() + + +class SDUNetStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + # architecture + block_types = [ + 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock', + 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock', + 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock', + 'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock', + 'ResnetBlock', 'AttentionBlock', 'ResnetBlock', + 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler', + 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler', + 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler', + 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock' + ] + + # Rename each parameter + name_list = sorted([name for name in state_dict]) + rename_dict = {} + block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1} + last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""} + for name in name_list: + names = name.split(".") + if names[0] in ["conv_in", "conv_norm_out", "conv_out"]: + pass + elif names[0] in ["time_embedding", "add_embedding"]: + if names[0] == "add_embedding": + names[0] = "add_time_embedding" + names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]] + elif names[0] in ["down_blocks", "mid_block", "up_blocks"]: + if names[0] == "mid_block": + names.insert(1, "0") + block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]] + block_type_with_id = ".".join(names[:4]) + if block_type_with_id != last_block_type_with_id[block_type]: + block_id[block_type] += 1 + last_block_type_with_id[block_type] = block_type_with_id + while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: + block_id[block_type] += 1 + block_type_with_id = ".".join(names[:4]) + names = ["blocks", str(block_id[block_type])] + names[4:] + if "ff" in names: + ff_index = names.index("ff") + component = ".".join(names[ff_index:ff_index+3]) + component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component] + names = names[:ff_index] + [component] + names[ff_index+3:] + if "to_out" in names: + names.pop(names.index("to_out") + 1) + else: + raise ValueError(f"Unknown parameters: {name}") + rename_dict[name] = ".".join(names) + + # Convert state_dict + state_dict_ = {} + for name, param in state_dict.items(): + if ".proj_in." in name or ".proj_out." 
in name: + param = param.squeeze() + state_dict_[rename_dict[name]] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight", + "model.diffusion_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias", + "model.diffusion_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight", + "model.diffusion_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias", + "model.diffusion_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight", + "model.diffusion_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias", + "model.diffusion_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight", + 
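To show how from_diffusers above derives these target names without a hand-written table, a minimal self-contained check on a single diffusers-style key (the zero tensor is only a placeholder):

import torch
from diffsynth.models.sd_unet import SDUNetStateDictConverter

converter = SDUNetStateDictConverter()
dummy = {"down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight": torch.zeros(320, 320)}
# "attentions" maps to AttentionBlock; scanning block_types finds its first occurrence at index 1,
# so the parameter lands in blocks.1, consistent with the explicit from_civitai table in this file.
print(list(converter.from_diffusers(dummy)))
# ['blocks.1.transformer_blocks.0.attn1.to_q.weight']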
"model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight", + "model.diffusion_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias", + "model.diffusion_model.input_blocks.2.1.norm.weight": 
"blocks.4.norm.weight", + "model.diffusion_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias", + "model.diffusion_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight", + "model.diffusion_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias", + "model.diffusion_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": 
"blocks.8.conv1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias", + "model.diffusion_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": 
"blocks.17.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight", + "model.diffusion_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias", + "model.diffusion_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias", + "model.diffusion_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight", + "model.diffusion_model.middle_block.1.norm.bias": "blocks.29.norm.bias", + "model.diffusion_model.middle_block.1.norm.weight": "blocks.29.norm.weight", + "model.diffusion_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias", + "model.diffusion_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight", + "model.diffusion_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias", + 
"model.diffusion_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight", + "model.diffusion_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias", + "model.diffusion_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight", + "model.diffusion_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias", + "model.diffusion_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight", + "model.diffusion_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias", + "model.diffusion_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight", + "model.diffusion_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias", + "model.diffusion_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight", + "model.diffusion_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias", + "model.diffusion_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight", + 
"model.diffusion_model.out.0.bias": "conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "conv_out.bias", + "model.diffusion_model.out.2.weight": "conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "blocks.32.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "blocks.32.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "blocks.32.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "blocks.32.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "blocks.32.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "blocks.32.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "blocks.32.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "blocks.32.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "blocks.32.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "blocks.32.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "blocks.32.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "blocks.32.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "blocks.34.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "blocks.34.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "blocks.34.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "blocks.34.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "blocks.34.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "blocks.34.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "blocks.34.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "blocks.34.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "blocks.34.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "blocks.34.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "blocks.34.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "blocks.34.conv_shortcut.weight", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "blocks.62.time_emb_proj.bias", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "blocks.62.time_emb_proj.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "blocks.62.norm1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "blocks.62.norm1.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "blocks.62.conv1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "blocks.62.conv1.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "blocks.62.norm2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "blocks.62.norm2.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "blocks.62.conv2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "blocks.62.conv2.weight", + "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "blocks.62.conv_shortcut.bias", + "model.diffusion_model.output_blocks.10.0.skip_connection.weight": "blocks.62.conv_shortcut.weight", 
+ "model.diffusion_model.output_blocks.10.1.norm.bias": "blocks.63.norm.bias", + "model.diffusion_model.output_blocks.10.1.norm.weight": "blocks.63.norm.weight", + "model.diffusion_model.output_blocks.10.1.proj_in.bias": "blocks.63.proj_in.bias", + "model.diffusion_model.output_blocks.10.1.proj_in.weight": "blocks.63.proj_in.weight", + "model.diffusion_model.output_blocks.10.1.proj_out.bias": "blocks.63.proj_out.bias", + "model.diffusion_model.output_blocks.10.1.proj_out.weight": "blocks.63.proj_out.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight": "blocks.63.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.63.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.63.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight": "blocks.63.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight": "blocks.63.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight": "blocks.63.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.63.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.63.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight": "blocks.63.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight": "blocks.63.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.63.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.63.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias": "blocks.63.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight": "blocks.63.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias": "blocks.63.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight": "blocks.63.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias": "blocks.63.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight": "blocks.63.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias": "blocks.63.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight": "blocks.63.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "blocks.65.time_emb_proj.bias", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "blocks.65.time_emb_proj.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "blocks.65.norm1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "blocks.65.norm1.weight", + 
"model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "blocks.65.conv1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "blocks.65.conv1.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "blocks.65.norm2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "blocks.65.norm2.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "blocks.65.conv2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "blocks.65.conv2.weight", + "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "blocks.65.conv_shortcut.bias", + "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "blocks.65.conv_shortcut.weight", + "model.diffusion_model.output_blocks.11.1.norm.bias": "blocks.66.norm.bias", + "model.diffusion_model.output_blocks.11.1.norm.weight": "blocks.66.norm.weight", + "model.diffusion_model.output_blocks.11.1.proj_in.bias": "blocks.66.proj_in.bias", + "model.diffusion_model.output_blocks.11.1.proj_in.weight": "blocks.66.proj_in.weight", + "model.diffusion_model.output_blocks.11.1.proj_out.bias": "blocks.66.proj_out.bias", + "model.diffusion_model.output_blocks.11.1.proj_out.weight": "blocks.66.proj_out.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight": "blocks.66.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.66.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.66.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight": "blocks.66.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight": "blocks.66.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight": "blocks.66.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.66.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.66.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight": "blocks.66.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight": "blocks.66.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.66.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.66.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias": "blocks.66.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight": "blocks.66.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias": "blocks.66.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight": "blocks.66.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias": "blocks.66.transformer_blocks.0.norm2.bias", + 
"model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight": "blocks.66.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias": "blocks.66.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight": "blocks.66.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "blocks.36.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "blocks.36.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "blocks.36.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "blocks.36.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "blocks.36.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "blocks.36.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "blocks.36.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "blocks.36.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "blocks.36.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "blocks.36.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "blocks.36.conv_shortcut.bias", + "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "blocks.36.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.1.conv.bias": "blocks.37.conv.bias", + "model.diffusion_model.output_blocks.2.1.conv.weight": "blocks.37.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "blocks.39.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "blocks.39.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "blocks.39.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "blocks.39.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "blocks.39.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "blocks.39.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "blocks.39.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "blocks.39.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "blocks.39.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "blocks.39.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "blocks.39.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "blocks.39.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "blocks.40.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "blocks.40.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "blocks.40.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "blocks.40.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "blocks.40.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "blocks.40.proj_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "blocks.40.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.40.transformer_blocks.0.attn1.to_out.bias", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.40.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "blocks.40.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "blocks.40.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "blocks.40.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.40.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.40.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "blocks.40.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "blocks.40.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.40.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.40.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "blocks.40.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "blocks.40.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "blocks.40.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "blocks.40.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "blocks.40.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "blocks.40.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "blocks.40.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "blocks.40.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "blocks.42.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "blocks.42.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "blocks.42.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "blocks.42.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "blocks.42.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "blocks.42.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "blocks.42.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "blocks.42.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "blocks.42.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "blocks.42.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "blocks.42.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "blocks.42.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "blocks.43.norm.bias", + 
"model.diffusion_model.output_blocks.4.1.norm.weight": "blocks.43.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "blocks.43.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "blocks.43.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": "blocks.43.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "blocks.43.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.43.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.43.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.43.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.43.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.43.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.43.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.43.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.43.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.43.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.43.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.43.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.43.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.43.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.43.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.43.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.43.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.43.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.43.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.43.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.43.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "blocks.45.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "blocks.45.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "blocks.45.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "blocks.45.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "blocks.45.conv1.bias", + 
"model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "blocks.45.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "blocks.45.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "blocks.45.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "blocks.45.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "blocks.45.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "blocks.45.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "blocks.45.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "blocks.46.norm.bias", + "model.diffusion_model.output_blocks.5.1.norm.weight": "blocks.46.norm.weight", + "model.diffusion_model.output_blocks.5.1.proj_in.bias": "blocks.46.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "blocks.46.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "blocks.46.proj_out.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "blocks.46.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.46.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.46.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.46.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.46.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.46.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.46.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.46.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.46.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.46.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.46.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.46.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.46.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.46.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.46.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.46.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.46.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.46.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.46.transformer_blocks.0.norm2.weight", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.46.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.46.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "blocks.47.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "blocks.47.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "blocks.49.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "blocks.49.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "blocks.49.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "blocks.49.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "blocks.49.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "blocks.49.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "blocks.49.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "blocks.49.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "blocks.49.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "blocks.49.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "blocks.49.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "blocks.49.conv_shortcut.weight", + "model.diffusion_model.output_blocks.6.1.norm.bias": "blocks.50.norm.bias", + "model.diffusion_model.output_blocks.6.1.norm.weight": "blocks.50.norm.weight", + "model.diffusion_model.output_blocks.6.1.proj_in.bias": "blocks.50.proj_in.bias", + "model.diffusion_model.output_blocks.6.1.proj_in.weight": "blocks.50.proj_in.weight", + "model.diffusion_model.output_blocks.6.1.proj_out.bias": "blocks.50.proj_out.bias", + "model.diffusion_model.output_blocks.6.1.proj_out.weight": "blocks.50.proj_out.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "blocks.50.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.50.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.50.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": "blocks.50.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "blocks.50.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": "blocks.50.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.50.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.50.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "blocks.50.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "blocks.50.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.50.transformer_blocks.0.act_fn.proj.bias", + 
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.50.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "blocks.50.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": "blocks.50.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "blocks.50.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "blocks.50.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "blocks.50.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "blocks.50.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "blocks.50.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "blocks.50.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "blocks.52.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "blocks.52.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "blocks.52.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "blocks.52.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "blocks.52.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "blocks.52.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "blocks.52.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "blocks.52.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "blocks.52.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "blocks.52.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "blocks.52.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "blocks.52.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.1.norm.bias": "blocks.53.norm.bias", + "model.diffusion_model.output_blocks.7.1.norm.weight": "blocks.53.norm.weight", + "model.diffusion_model.output_blocks.7.1.proj_in.bias": "blocks.53.proj_in.bias", + "model.diffusion_model.output_blocks.7.1.proj_in.weight": "blocks.53.proj_in.weight", + "model.diffusion_model.output_blocks.7.1.proj_out.bias": "blocks.53.proj_out.bias", + "model.diffusion_model.output_blocks.7.1.proj_out.weight": "blocks.53.proj_out.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.53.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.53.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.53.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.53.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.53.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": 
"blocks.53.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.53.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.53.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.53.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.53.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.53.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.53.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.53.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.53.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.53.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.53.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.53.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.53.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.53.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.53.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "blocks.55.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "blocks.55.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "blocks.55.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "blocks.55.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "blocks.55.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "blocks.55.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "blocks.55.norm2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "blocks.55.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "blocks.55.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "blocks.55.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "blocks.55.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "blocks.55.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.1.norm.bias": "blocks.56.norm.bias", + "model.diffusion_model.output_blocks.8.1.norm.weight": "blocks.56.norm.weight", + "model.diffusion_model.output_blocks.8.1.proj_in.bias": "blocks.56.proj_in.bias", + "model.diffusion_model.output_blocks.8.1.proj_in.weight": "blocks.56.proj_in.weight", + "model.diffusion_model.output_blocks.8.1.proj_out.bias": "blocks.56.proj_out.bias", + "model.diffusion_model.output_blocks.8.1.proj_out.weight": "blocks.56.proj_out.weight", + 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.56.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.56.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.56.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.56.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.56.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.56.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.56.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.56.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.56.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.56.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.56.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.56.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.56.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.56.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.56.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.56.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.56.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.56.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.56.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.56.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.2.conv.bias": "blocks.57.conv.bias", + "model.diffusion_model.output_blocks.8.2.conv.weight": "blocks.57.conv.weight", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "blocks.59.time_emb_proj.bias", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "blocks.59.time_emb_proj.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "blocks.59.norm1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "blocks.59.norm1.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "blocks.59.conv1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "blocks.59.conv1.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "blocks.59.norm2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "blocks.59.norm2.weight", + 
"model.diffusion_model.output_blocks.9.0.out_layers.3.bias": "blocks.59.conv2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "blocks.59.conv2.weight", + "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "blocks.59.conv_shortcut.bias", + "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "blocks.59.conv_shortcut.weight", + "model.diffusion_model.output_blocks.9.1.norm.bias": "blocks.60.norm.bias", + "model.diffusion_model.output_blocks.9.1.norm.weight": "blocks.60.norm.weight", + "model.diffusion_model.output_blocks.9.1.proj_in.bias": "blocks.60.proj_in.bias", + "model.diffusion_model.output_blocks.9.1.proj_in.weight": "blocks.60.proj_in.weight", + "model.diffusion_model.output_blocks.9.1.proj_out.bias": "blocks.60.proj_out.bias", + "model.diffusion_model.output_blocks.9.1.proj_out.weight": "blocks.60.proj_out.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight": "blocks.60.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.60.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.60.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight": "blocks.60.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight": "blocks.60.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight": "blocks.60.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.60.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.60.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight": "blocks.60.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight": "blocks.60.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.60.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.60.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias": "blocks.60.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight": "blocks.60.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias": "blocks.60.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight": "blocks.60.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias": "blocks.60.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight": "blocks.60.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias": "blocks.60.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight": "blocks.60.transformer_blocks.0.norm3.weight", + 
"model.diffusion_model.time_embed.0.bias": "time_embedding.0.bias", + "model.diffusion_model.time_embed.0.weight": "time_embedding.0.weight", + "model.diffusion_model.time_embed.2.bias": "time_embedding.2.bias", + "model.diffusion_model.time_embed.2.weight": "time_embedding.2.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if ".proj_in." in name or ".proj_out." in name: + param = param.squeeze() + state_dict_[rename_dict[name]] = param + return state_dict_ \ No newline at end of file diff --git a/diffsynth/models/sd_vae_decoder.py b/diffsynth/models/sd_vae_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0fba92d763a3a6d0b0331b95a5bd7c5ac65135dc --- /dev/null +++ b/diffsynth/models/sd_vae_decoder.py @@ -0,0 +1,332 @@ +import torch +from .attention import Attention +from .sd_unet import ResnetBlock, UpSampler +from .tiler import TileWorker + + +class VAEAttentionBlock(torch.nn.Module): + + def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) + + self.transformer_blocks = torch.nn.ModuleList([ + Attention( + inner_dim, + num_attention_heads, + attention_head_dim, + bias_q=True, + bias_kv=True, + bias_out=True + ) + for d in range(num_layers) + ]) + + def forward(self, hidden_states, time_emb, text_emb, res_stack): + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states) + + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = hidden_states + residual + + return hidden_states, time_emb, text_emb, res_stack + + +class SDVAEDecoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.scaling_factor = 0.18215 + self.post_quant_conv = torch.nn.Conv2d(4, 4, kernel_size=1) + self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1) + + self.blocks = torch.nn.ModuleList([ + # UNetMidBlock2D + ResnetBlock(512, 512, eps=1e-6), + VAEAttentionBlock(1, 512, 512, 1, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + # UpDecoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + UpSampler(512), + # UpDecoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + UpSampler(512), + # UpDecoderBlock2D + ResnetBlock(512, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + UpSampler(256), + # UpDecoderBlock2D + ResnetBlock(256, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1) + + def tiled_forward(self, sample, tile_size=64, tile_stride=32): + hidden_states = TileWorker().tiled_forward( + lambda x: self.forward(x), + sample, + tile_size, + tile_stride, + tile_device=sample.device, + tile_dtype=sample.dtype + ) + return hidden_states 
+ + def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs): + # For VAE Decoder, we do not need to apply the tiler on each layer. + if tiled: + return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride) + + # 1. pre-process + sample = sample / self.scaling_factor + hidden_states = self.post_quant_conv(sample) + hidden_states = self.conv_in(hidden_states) + time_emb = None + text_emb = None + res_stack = None + + # 2. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 3. output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + def state_dict_converter(self): + return SDVAEDecoderStateDictConverter() + + +class SDVAEDecoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + # architecture + block_types = [ + 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler', + 'ResnetBlock', 'ResnetBlock', 'ResnetBlock' + ] + + # Rename each parameter + local_rename_dict = { + "post_quant_conv": "post_quant_conv", + "decoder.conv_in": "conv_in", + "decoder.mid_block.attentions.0.group_norm": "blocks.1.norm", + "decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q", + "decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k", + "decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v", + "decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out", + "decoder.mid_block.resnets.0.norm1": "blocks.0.norm1", + "decoder.mid_block.resnets.0.conv1": "blocks.0.conv1", + "decoder.mid_block.resnets.0.norm2": "blocks.0.norm2", + "decoder.mid_block.resnets.0.conv2": "blocks.0.conv2", + "decoder.mid_block.resnets.1.norm1": "blocks.2.norm1", + "decoder.mid_block.resnets.1.conv1": "blocks.2.conv1", + "decoder.mid_block.resnets.1.norm2": "blocks.2.norm2", + "decoder.mid_block.resnets.1.conv2": "blocks.2.conv2", + "decoder.conv_norm_out": "conv_norm_out", + "decoder.conv_out": "conv_out", + } + name_list = sorted([name for name in state_dict]) + rename_dict = {} + block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2} + last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""} + for name in name_list: + names = name.split(".") + name_prefix = ".".join(names[:-1]) + if name_prefix in local_rename_dict: + rename_dict[name] = local_rename_dict[name_prefix] + "." 
+ names[-1] + elif name.startswith("decoder.up_blocks"): + block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]] + block_type_with_id = ".".join(names[:5]) + if block_type_with_id != last_block_type_with_id[block_type]: + block_id[block_type] += 1 + last_block_type_with_id[block_type] = block_type_with_id + while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: + block_id[block_type] += 1 + block_type_with_id = ".".join(names[:5]) + names = ["blocks", str(block_id[block_type])] + names[5:] + rename_dict[name] = ".".join(names) + + # Convert state_dict + state_dict_ = {} + for name, param in state_dict.items(): + if name in rename_dict: + state_dict_[rename_dict[name]] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "first_stage_model.decoder.conv_in.bias": "conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "conv_out.bias", + "first_stage_model.decoder.conv_out.weight": "conv_out.weight", + "first_stage_model.decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias", + "first_stage_model.decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias", + "first_stage_model.decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias", + "first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias", + "first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias", + "first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight", + "first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight", + "first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias", + "first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight", + "first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias", + 
"first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias", + "first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight", + "first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias", + "first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight", + "first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias", + "first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight", + "first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight", + "first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight", + "first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias", + "first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight", + 
"first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight", + "first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight", + "first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight", + "first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias", + "first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight", + "first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias", + "first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight", + "first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight", + "first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias", + "first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight", + "first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight", + "first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight", + "first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias", + "first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight", + "first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight", + 
"first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight", + "first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias", + "first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias", + "first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight", + "first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias", + "first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight", + "first_stage_model.post_quant_conv.bias": "post_quant_conv.bias", + "first_stage_model.post_quant_conv.weight": "post_quant_conv.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if "transformer_blocks" in rename_dict[name]: + param = param.squeeze() + state_dict_[rename_dict[name]] = param + return state_dict_ diff --git a/diffsynth/models/sd_vae_encoder.py b/diffsynth/models/sd_vae_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ea76cc99889ca1d97007ef9fafe654b55ccac48a --- /dev/null +++ b/diffsynth/models/sd_vae_encoder.py @@ -0,0 +1,278 @@ +import torch +from .sd_unet import ResnetBlock, DownSampler +from .sd_vae_decoder import VAEAttentionBlock +from .tiler import TileWorker +from einops import rearrange + + +class SDVAEEncoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.scaling_factor = 0.18215 + self.quant_conv = torch.nn.Conv2d(8, 8, kernel_size=1) + self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1) + + self.blocks = torch.nn.ModuleList([ + # DownEncoderBlock2D + ResnetBlock(128, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + DownSampler(128, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(128, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + DownSampler(256, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(256, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + DownSampler(512, padding=0, extra_padding=True), + # DownEncoderBlock2D + ResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + # UNetMidBlock2D + ResnetBlock(512, 512, eps=1e-6), + VAEAttentionBlock(1, 512, 512, 1, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1) + + def tiled_forward(self, sample, tile_size=64, tile_stride=32): + hidden_states = TileWorker().tiled_forward( + lambda x: self.forward(x), + sample, + tile_size, + 
tile_stride, + tile_device=sample.device, + tile_dtype=sample.dtype + ) + return hidden_states + + def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs): + # For VAE Encoder, we do not need to apply the tiler on each layer. + if tiled: + return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride) + + # 1. pre-process + hidden_states = self.conv_in(sample) + time_emb = None + text_emb = None + res_stack = None + + # 2. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 3. output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + hidden_states = self.quant_conv(hidden_states) + hidden_states = hidden_states[:, :4] + hidden_states *= self.scaling_factor + + return hidden_states + + def encode_video(self, sample, batch_size=8): + B = sample.shape[0] + hidden_states = [] + + for i in range(0, sample.shape[2], batch_size): + + j = min(i + batch_size, sample.shape[2]) + sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W") + + hidden_states_batch = self(sample_batch) + hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B) + + hidden_states.append(hidden_states_batch) + + hidden_states = torch.concat(hidden_states, dim=2) + return hidden_states + + def state_dict_converter(self): + return SDVAEEncoderStateDictConverter() + + +class SDVAEEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + # architecture + block_types = [ + 'ResnetBlock', 'ResnetBlock', 'DownSampler', + 'ResnetBlock', 'ResnetBlock', 'DownSampler', + 'ResnetBlock', 'ResnetBlock', 'DownSampler', + 'ResnetBlock', 'ResnetBlock', + 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock' + ] + + # Rename each parameter + local_rename_dict = { + "quant_conv": "quant_conv", + "encoder.conv_in": "conv_in", + "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm", + "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q", + "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k", + "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v", + "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out", + "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1", + "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1", + "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2", + "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2", + "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1", + "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1", + "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2", + "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2", + "encoder.conv_norm_out": "conv_norm_out", + "encoder.conv_out": "conv_out", + } + name_list = sorted([name for name in state_dict]) + rename_dict = {} + block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1} + last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""} + for name in name_list: + names = name.split(".") + name_prefix = ".".join(names[:-1]) + if name_prefix in local_rename_dict: + rename_dict[name] = local_rename_dict[name_prefix] + "." 
+ names[-1] + elif name.startswith("encoder.down_blocks"): + block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]] + block_type_with_id = ".".join(names[:5]) + if block_type_with_id != last_block_type_with_id[block_type]: + block_id[block_type] += 1 + last_block_type_with_id[block_type] = block_type_with_id + while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: + block_id[block_type] += 1 + block_type_with_id = ".".join(names[:5]) + names = ["blocks", str(block_id[block_type])] + names[5:] + rename_dict[name] = ".".join(names) + + # Convert state_dict + state_dict_ = {} + for name, param in state_dict.items(): + if name in rename_dict: + state_dict_[rename_dict[name]] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "first_stage_model.encoder.conv_in.bias": "conv_in.bias", + "first_stage_model.encoder.conv_in.weight": "conv_in.weight", + "first_stage_model.encoder.conv_out.bias": "conv_out.bias", + "first_stage_model.encoder.conv_out.weight": "conv_out.weight", + "first_stage_model.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias", + "first_stage_model.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight", + "first_stage_model.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias", + "first_stage_model.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight", + "first_stage_model.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias", + "first_stage_model.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight", + "first_stage_model.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias", + "first_stage_model.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight", + "first_stage_model.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias", + "first_stage_model.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight", + "first_stage_model.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias", + "first_stage_model.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight", + "first_stage_model.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias", + "first_stage_model.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight", + "first_stage_model.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias", + "first_stage_model.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight", + "first_stage_model.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias", + "first_stage_model.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight", + "first_stage_model.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias", + "first_stage_model.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight", + "first_stage_model.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias", + "first_stage_model.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias", + "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight", + "first_stage_model.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias", + "first_stage_model.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight", + "first_stage_model.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias", + "first_stage_model.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight", + "first_stage_model.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias", + 
"first_stage_model.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight", + "first_stage_model.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias", + "first_stage_model.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight", + "first_stage_model.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias", + "first_stage_model.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight", + "first_stage_model.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias", + "first_stage_model.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight", + "first_stage_model.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias", + "first_stage_model.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight", + "first_stage_model.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias", + "first_stage_model.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight", + "first_stage_model.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias", + "first_stage_model.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias", + "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight", + "first_stage_model.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias", + "first_stage_model.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight", + "first_stage_model.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias", + "first_stage_model.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight", + "first_stage_model.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias", + "first_stage_model.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight", + "first_stage_model.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias", + "first_stage_model.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight", + "first_stage_model.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias", + "first_stage_model.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight", + "first_stage_model.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias", + "first_stage_model.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight", + "first_stage_model.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias", + "first_stage_model.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight", + "first_stage_model.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias", + "first_stage_model.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight", + "first_stage_model.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias", + "first_stage_model.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight", + "first_stage_model.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias", + "first_stage_model.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight", + "first_stage_model.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias", + "first_stage_model.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight", + "first_stage_model.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias", + "first_stage_model.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight", + "first_stage_model.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias", + "first_stage_model.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight", + "first_stage_model.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias", + "first_stage_model.encoder.down.3.block.1.norm1.weight": 
"blocks.10.norm1.weight", + "first_stage_model.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias", + "first_stage_model.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight", + "first_stage_model.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias", + "first_stage_model.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight", + "first_stage_model.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias", + "first_stage_model.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight", + "first_stage_model.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias", + "first_stage_model.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight", + "first_stage_model.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias", + "first_stage_model.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight", + "first_stage_model.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias", + "first_stage_model.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight", + "first_stage_model.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias", + "first_stage_model.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight", + "first_stage_model.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias", + "first_stage_model.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight", + "first_stage_model.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias", + "first_stage_model.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight", + "first_stage_model.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias", + "first_stage_model.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight", + "first_stage_model.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias", + "first_stage_model.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight", + "first_stage_model.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias", + "first_stage_model.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight", + "first_stage_model.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias", + "first_stage_model.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight", + "first_stage_model.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias", + "first_stage_model.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight", + "first_stage_model.encoder.norm_out.bias": "conv_norm_out.bias", + "first_stage_model.encoder.norm_out.weight": "conv_norm_out.weight", + "first_stage_model.quant_conv.bias": "quant_conv.bias", + "first_stage_model.quant_conv.weight": "quant_conv.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if "transformer_blocks" in rename_dict[name]: + param = param.squeeze() + state_dict_[rename_dict[name]] = param + return state_dict_ diff --git a/diffsynth/models/sdxl_ipadapter.py b/diffsynth/models/sdxl_ipadapter.py new file mode 100644 index 0000000000000000000000000000000000000000..0b3faca2a5ad988a207242b4112eb1ffe855d3d1 --- /dev/null +++ b/diffsynth/models/sdxl_ipadapter.py @@ -0,0 +1,121 @@ +from .svd_image_encoder import SVDImageEncoder +from transformers import CLIPImageProcessor +import torch + + +class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder): + def __init__(self): + super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104) + self.image_processor = CLIPImageProcessor() + + def 
forward(self, image): + pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values + pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype) + return super().forward(pixel_values) + + +class IpAdapterImageProjModel(torch.nn.Module): + def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4): + super().__init__() + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +class IpAdapterModule(torch.nn.Module): + def __init__(self, input_dim, output_dim): + super().__init__() + self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False) + self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False) + + def forward(self, hidden_states): + ip_k = self.to_k_ip(hidden_states) + ip_v = self.to_v_ip(hidden_states) + return ip_k, ip_v + + +class SDXLIpAdapter(torch.nn.Module): + def __init__(self): + super().__init__() + shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10 + self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list]) + self.image_proj = IpAdapterImageProjModel() + self.set_full_adapter() + + def set_full_adapter(self): + map_list = sum([ + [(7, i) for i in range(2)], + [(10, i) for i in range(2)], + [(15, i) for i in range(10)], + [(18, i) for i in range(10)], + [(25, i) for i in range(10)], + [(28, i) for i in range(10)], + [(31, i) for i in range(10)], + [(35, i) for i in range(2)], + [(38, i) for i in range(2)], + [(41, i) for i in range(2)], + [(21, i) for i in range(10)], + ], []) + self.call_block_id = {i: j for j, i in enumerate(map_list)} + + def set_less_adapter(self): + map_list = sum([ + [(7, i) for i in range(2)], + [(10, i) for i in range(2)], + [(15, i) for i in range(10)], + [(18, i) for i in range(10)], + [(25, i) for i in range(10)], + [(28, i) for i in range(10)], + [(31, i) for i in range(10)], + [(35, i) for i in range(2)], + [(38, i) for i in range(2)], + [(41, i) for i in range(2)], + [(21, i) for i in range(10)], + ], []) + self.call_block_id = {i: j for j, i in enumerate(map_list) if j>=34 and j<44} + + def forward(self, hidden_states, scale=1.0): + hidden_states = self.image_proj(hidden_states) + hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1]) + ip_kv_dict = {} + for (block_id, transformer_id) in self.call_block_id: + ipadapter_id = self.call_block_id[(block_id, transformer_id)] + ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states) + if block_id not in ip_kv_dict: + ip_kv_dict[block_id] = {} + ip_kv_dict[block_id][transformer_id] = { + "ip_k": ip_k, + "ip_v": ip_v, + "scale": scale + } + return ip_kv_dict + + def state_dict_converter(self): + return SDXLIpAdapterStateDictConverter() + + +class SDXLIpAdapterStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + state_dict_ = {} + for name in state_dict["ip_adapter"]: + names = name.split(".") + layer_id = str(int(names[0]) // 2) + name_ = 
".".join(["ipadapter_modules"] + [layer_id] + names[1:]) + state_dict_[name_] = state_dict["ip_adapter"][name] + for name in state_dict["image_proj"]: + name_ = "image_proj." + name + state_dict_[name_] = state_dict["image_proj"][name] + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) + diff --git a/diffsynth/models/sdxl_motion.py b/diffsynth/models/sdxl_motion.py new file mode 100644 index 0000000000000000000000000000000000000000..329b1c65cf11d2aeec4564a273b1282b835dc1c0 --- /dev/null +++ b/diffsynth/models/sdxl_motion.py @@ -0,0 +1,103 @@ +from .sd_motion import TemporalBlock +import torch + + + +class SDXLMotionModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.motion_modules = torch.nn.ModuleList([ + TemporalBlock(8, 320//8, 320, eps=1e-6), + TemporalBlock(8, 320//8, 320, eps=1e-6), + + TemporalBlock(8, 640//8, 640, eps=1e-6), + TemporalBlock(8, 640//8, 640, eps=1e-6), + + TemporalBlock(8, 1280//8, 1280, eps=1e-6), + TemporalBlock(8, 1280//8, 1280, eps=1e-6), + + TemporalBlock(8, 1280//8, 1280, eps=1e-6), + TemporalBlock(8, 1280//8, 1280, eps=1e-6), + TemporalBlock(8, 1280//8, 1280, eps=1e-6), + + TemporalBlock(8, 640//8, 640, eps=1e-6), + TemporalBlock(8, 640//8, 640, eps=1e-6), + TemporalBlock(8, 640//8, 640, eps=1e-6), + + TemporalBlock(8, 320//8, 320, eps=1e-6), + TemporalBlock(8, 320//8, 320, eps=1e-6), + TemporalBlock(8, 320//8, 320, eps=1e-6), + ]) + self.call_block_id = { + 0: 0, + 2: 1, + 7: 2, + 10: 3, + 15: 4, + 18: 5, + 25: 6, + 28: 7, + 31: 8, + 35: 9, + 38: 10, + 41: 11, + 44: 12, + 46: 13, + 48: 14, + } + + def forward(self): + pass + + def state_dict_converter(self): + return SDMotionModelStateDictConverter() + + +class SDMotionModelStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "norm": "norm", + "proj_in": "proj_in", + "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q", + "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k", + "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v", + "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out", + "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1", + "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q", + "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k", + "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v", + "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out", + "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2", + "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1", + "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2", + "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj", + "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff", + "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3", + "proj_out": "proj_out", + } + name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")]) + name_list += sorted([i for i in state_dict if i.startswith("mid_block.")]) + name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")]) + state_dict_ = {} + last_prefix, module_id = "", -1 + for name in name_list: + names = name.split(".") + prefix_index = names.index("temporal_transformer") + 1 + prefix = 
".".join(names[:prefix_index]) + if prefix != last_prefix: + last_prefix = prefix + module_id += 1 + middle_name = ".".join(names[prefix_index:-1]) + suffix = names[-1] + if "pos_encoder" in names: + rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]]) + else: + rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix]) + state_dict_[rename] = state_dict[name] + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/diffsynth/models/sdxl_text_encoder.py b/diffsynth/models/sdxl_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3df6ec6fa36574ffcbae5bf982a29d2fb25f358f --- /dev/null +++ b/diffsynth/models/sdxl_text_encoder.py @@ -0,0 +1,757 @@ +import torch +from .sd_text_encoder import CLIPEncoderLayer + + +class SDXLTextEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=11, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # The text encoder is different to that in Stable Diffusion 1.x. + # It does not include final_layer_norm. + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=1): + embeds = self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + break + return embeds + + def state_dict_converter(self): + return SDXLTextEncoderStateDictConverter() + + +class SDXLTextEncoder2(torch.nn.Module): + def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=32, encoder_intermediate_size=5120): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=20, head_dim=64, use_quick_gelu=False) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + # text_projection + self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=2): + embeds = self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == 
len(self.encoders): + hidden_states = embeds + embeds = self.final_layer_norm(embeds) + pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)] + pooled_embeds = self.text_projection(pooled_embeds) + return pooled_embeds, hidden_states + + def state_dict_converter(self): + return SDXLTextEncoder2StateDictConverter() + + +class SDXLTextEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "position_embeds", + "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight", 
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight", + 
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": 
"encoders.7.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": 
"encoders.9.layer_norm2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias", + "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + return state_dict_ + + +class SDXLTextEncoder2StateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias", + "text_projection.weight": "text_projection.weight" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "conditioner.embedders.1.model.ln_final.bias": "final_layer_norm.bias", + "conditioner.embedders.1.model.ln_final.weight": "final_layer_norm.weight", + "conditioner.embedders.1.model.positional_embedding": "position_embeds", + "conditioner.embedders.1.model.token_embedding.weight": "token_embedding.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": 
['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias", + 
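# NOTE: resblock entries below appear in lexicographic key order (0, 1, 10, 11, ..., 19, 2, 20, ...), not numeric order +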
"conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": 
['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'], + 
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight", + 
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": 
"encoders.27.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias", + 
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 
'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias", + 
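# NOTE: list-valued targets mark OpenCLIP's fused in_proj tensors; the loop after this dict splits each into equal thirds (to_q, to_k, to_v), reshapes the positional embedding to (1, seq_len, dim), and transposes text_projection +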
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'], + "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'], + "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias", + "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight", + "conditioner.embedders.1.model.text_projection": "text_projection.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "conditioner.embedders.1.model.positional_embedding": + param = param.reshape((1, param.shape[0], param.shape[1])) + elif name == "conditioner.embedders.1.model.text_projection": + param = param.T + if isinstance(rename_dict[name], str): + state_dict_[rename_dict[name]] = param + else: + length = param.shape[0] // 3 + for i, rename in enumerate(rename_dict[name]): + state_dict_[rename] = param[i*length: i*length+length] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/models/sdxl_unet.py b/diffsynth/models/sdxl_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..a33625940c01b3a103666d2d8f8cc1e69eea9e02 --- /dev/null +++ b/diffsynth/models/sdxl_unet.py @@ -0,0 +1,1876 @@ +import torch +from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock, DownSampler, UpSampler + + +class SDXLUNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.time_proj = Timesteps(320) + self.time_embedding = torch.nn.Sequential( + torch.nn.Linear(320, 1280), + torch.nn.SiLU(), + torch.nn.Linear(1280, 1280) + ) + self.add_time_proj = Timesteps(256) + self.add_time_embedding = torch.nn.Sequential( + torch.nn.Linear(2816, 1280), + torch.nn.SiLU(), + torch.nn.Linear(1280, 1280) + ) + self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1) + + self.blocks = torch.nn.ModuleList([ + # DownBlock2D + ResnetBlock(320, 320, 1280), + PushBlock(), + ResnetBlock(320, 320, 1280), + PushBlock(), + DownSampler(320), + 
PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(320, 640, 1280), + AttentionBlock(10, 64, 640, 2, 2048), + PushBlock(), + ResnetBlock(640, 640, 1280), + AttentionBlock(10, 64, 640, 2, 2048), + PushBlock(), + DownSampler(640), + PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(640, 1280, 1280), + AttentionBlock(20, 64, 1280, 10, 2048), + PushBlock(), + ResnetBlock(1280, 1280, 1280), + AttentionBlock(20, 64, 1280, 10, 2048), + PushBlock(), + # UNetMidBlock2DCrossAttn + ResnetBlock(1280, 1280, 1280), + AttentionBlock(20, 64, 1280, 10, 2048), + ResnetBlock(1280, 1280, 1280), + # CrossAttnUpBlock2D + PopBlock(), + ResnetBlock(2560, 1280, 1280), + AttentionBlock(20, 64, 1280, 10, 2048), + PopBlock(), + ResnetBlock(2560, 1280, 1280), + AttentionBlock(20, 64, 1280, 10, 2048), + PopBlock(), + ResnetBlock(1920, 1280, 1280), + AttentionBlock(20, 64, 1280, 10, 2048), + UpSampler(1280), + # CrossAttnUpBlock2D + PopBlock(), + ResnetBlock(1920, 640, 1280), + AttentionBlock(10, 64, 640, 2, 2048), + PopBlock(), + ResnetBlock(1280, 640, 1280), + AttentionBlock(10, 64, 640, 2, 2048), + PopBlock(), + ResnetBlock(960, 640, 1280), + AttentionBlock(10, 64, 640, 2, 2048), + UpSampler(640), + # UpBlock2D + PopBlock(), + ResnetBlock(960, 320, 1280), + PopBlock(), + ResnetBlock(640, 320, 1280), + PopBlock(), + ResnetBlock(640, 320, 1280) + ]) + + self.conv_norm_out = torch.nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1) + + def forward( + self, + sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds, + tiled=False, tile_size=64, tile_stride=8, **kwargs + ): + # 1. time + t_emb = self.time_proj(timestep[None]).to(sample.dtype) + t_emb = self.time_embedding(t_emb) + + time_embeds = self.add_time_proj(add_time_id) + time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1)) + add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(sample.dtype) + add_embeds = self.add_time_embedding(add_embeds) + + time_emb = t_emb + add_embeds + + # 2. pre-process + height, width = sample.shape[2], sample.shape[3] + hidden_states = self.conv_in(sample) + text_emb = encoder_hidden_states + res_stack = [hidden_states] + + # 3. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block( + hidden_states, time_emb, text_emb, res_stack, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + ) + + # 4. 
output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + def state_dict_converter(self): + return SDXLUNetStateDictConverter() + + +class SDXLUNetStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + # architecture + block_types = [ + 'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock', 'DownSampler', 'PushBlock', + 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock', + 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', + 'ResnetBlock', 'AttentionBlock', 'ResnetBlock', + 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler', + 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler', + 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock' + ] + + # Rename each parameter + name_list = sorted([name for name in state_dict]) + rename_dict = {} + block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1} + last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""} + for name in name_list: + names = name.split(".") + if names[0] in ["conv_in", "conv_norm_out", "conv_out"]: + pass + elif names[0] in ["time_embedding", "add_embedding"]: + if names[0] == "add_embedding": + names[0] = "add_time_embedding" + names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]] + elif names[0] in ["down_blocks", "mid_block", "up_blocks"]: + if names[0] == "mid_block": + names.insert(1, "0") + block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]] + block_type_with_id = ".".join(names[:4]) + if block_type_with_id != last_block_type_with_id[block_type]: + block_id[block_type] += 1 + last_block_type_with_id[block_type] = block_type_with_id + while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: + block_id[block_type] += 1 + block_type_with_id = ".".join(names[:4]) + names = ["blocks", str(block_id[block_type])] + names[4:] + if "ff" in names: + ff_index = names.index("ff") + component = ".".join(names[ff_index:ff_index+3]) + component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component] + names = names[:ff_index] + [component] + names[ff_index+3:] + if "to_out" in names: + names.pop(names.index("to_out") + 1) + else: + raise ValueError(f"Unknown parameters: {name}") + rename_dict[name] = ".".join(names) + + # Convert state_dict + state_dict_ = {} + for name, param in state_dict.items(): + if ".proj_in." in name or ".proj_out." 
in name: + param = param.squeeze() + state_dict_[rename_dict[name]] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "blocks.2.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "blocks.2.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "blocks.2.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "blocks.2.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "blocks.2.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "blocks.2.conv2.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "blocks.4.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "blocks.4.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "blocks.6.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "blocks.6.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "blocks.6.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "blocks.6.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "blocks.6.conv1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "blocks.6.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "blocks.6.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "blocks.6.norm2.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "blocks.6.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "blocks.6.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "blocks.6.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "blocks.6.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "blocks.7.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "blocks.7.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "blocks.7.proj_in.bias", + 
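# Note: these civitai-style keys ("model.diffusion_model.*") follow the same naming conventions as from_diffusers above: + # attention "to_out.0" maps to "to_out", "ff.net.0.proj" to "act_fn.proj", and "ff.net.2" to "ff", + # and each resnet/attention module maps to the matching entry in self.blocks (e.g. input_blocks.4.0 -> blocks.6, input_blocks.4.1 -> blocks.7). + # Usage sketch (assumption: a diffusers-format SDXL UNet state dict and a no-argument constructor): + # SDXLUNet().load_state_dict(SDXLUNetStateDictConverter().from_diffusers(state_dict)) +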
"model.diffusion_model.input_blocks.4.1.proj_in.weight": "blocks.7.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "blocks.7.proj_out.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "blocks.7.proj_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.7.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.7.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.7.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.7.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.7.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.7.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.7.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.7.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.7.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.7.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.7.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.7.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.7.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.7.transformer_blocks.0.ff.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.7.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.7.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.7.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.7.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.7.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.7.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "blocks.7.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.7.transformer_blocks.1.attn1.to_out.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.7.transformer_blocks.1.attn1.to_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "blocks.7.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "blocks.7.transformer_blocks.1.attn1.to_v.weight", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "blocks.7.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.7.transformer_blocks.1.attn2.to_out.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.7.transformer_blocks.1.attn2.to_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "blocks.7.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "blocks.7.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.7.transformer_blocks.1.act_fn.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.7.transformer_blocks.1.act_fn.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "blocks.7.transformer_blocks.1.ff.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "blocks.7.transformer_blocks.1.ff.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.bias": "blocks.7.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.weight": "blocks.7.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.bias": "blocks.7.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.weight": "blocks.7.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.bias": "blocks.7.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.weight": "blocks.7.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "blocks.9.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "blocks.9.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "blocks.9.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "blocks.9.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "blocks.9.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "blocks.9.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "blocks.9.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "blocks.9.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "blocks.9.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "blocks.9.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "blocks.10.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "blocks.10.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "blocks.10.proj_in.bias", + "model.diffusion_model.input_blocks.5.1.proj_in.weight": "blocks.10.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "blocks.10.proj_out.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "blocks.10.proj_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.10.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": 
"blocks.10.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.10.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.10.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.10.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.10.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.10.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.10.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.10.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.10.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.10.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.10.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.10.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.10.transformer_blocks.0.ff.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.10.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.10.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.10.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.10.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.10.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.10.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "blocks.10.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.10.transformer_blocks.1.attn1.to_out.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.10.transformer_blocks.1.attn1.to_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "blocks.10.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "blocks.10.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "blocks.10.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.10.transformer_blocks.1.attn2.to_out.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.10.transformer_blocks.1.attn2.to_out.weight", + 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "blocks.10.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "blocks.10.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.10.transformer_blocks.1.act_fn.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.10.transformer_blocks.1.act_fn.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "blocks.10.transformer_blocks.1.ff.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "blocks.10.transformer_blocks.1.ff.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.bias": "blocks.10.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.weight": "blocks.10.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.bias": "blocks.10.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.weight": "blocks.10.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.bias": "blocks.10.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.weight": "blocks.10.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "blocks.12.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "blocks.12.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "blocks.14.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "blocks.14.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "blocks.14.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "blocks.14.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "blocks.14.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "blocks.14.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "blocks.14.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "blocks.14.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "blocks.14.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "blocks.14.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "blocks.14.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "blocks.14.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "blocks.15.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "blocks.15.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "blocks.15.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "blocks.15.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "blocks.15.proj_out.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "blocks.15.proj_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.15.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.15.transformer_blocks.0.attn1.to_out.bias", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.15.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.15.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.15.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.15.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.15.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.15.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.15.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.15.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.15.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.15.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.15.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.15.transformer_blocks.0.ff.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.15.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.15.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.15.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.15.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.15.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.15.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "blocks.15.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.15.transformer_blocks.1.attn1.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.15.transformer_blocks.1.attn1.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "blocks.15.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "blocks.15.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "blocks.15.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.15.transformer_blocks.1.attn2.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.15.transformer_blocks.1.attn2.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": 
"blocks.15.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "blocks.15.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.15.transformer_blocks.1.act_fn.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.15.transformer_blocks.1.act_fn.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "blocks.15.transformer_blocks.1.ff.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "blocks.15.transformer_blocks.1.ff.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "blocks.15.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "blocks.15.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "blocks.15.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": "blocks.15.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "blocks.15.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "blocks.15.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "blocks.15.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": "blocks.15.transformer_blocks.2.attn1.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "blocks.15.transformer_blocks.2.attn1.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "blocks.15.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "blocks.15.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "blocks.15.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "blocks.15.transformer_blocks.2.attn2.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "blocks.15.transformer_blocks.2.attn2.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "blocks.15.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "blocks.15.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": "blocks.15.transformer_blocks.2.act_fn.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "blocks.15.transformer_blocks.2.act_fn.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "blocks.15.transformer_blocks.2.ff.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "blocks.15.transformer_blocks.2.ff.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.bias": "blocks.15.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.weight": 
"blocks.15.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.bias": "blocks.15.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.weight": "blocks.15.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.bias": "blocks.15.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.weight": "blocks.15.transformer_blocks.2.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": "blocks.15.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "blocks.15.transformer_blocks.3.attn1.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "blocks.15.transformer_blocks.3.attn1.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "blocks.15.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "blocks.15.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "blocks.15.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "blocks.15.transformer_blocks.3.attn2.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": "blocks.15.transformer_blocks.3.attn2.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "blocks.15.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "blocks.15.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "blocks.15.transformer_blocks.3.act_fn.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": "blocks.15.transformer_blocks.3.act_fn.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "blocks.15.transformer_blocks.3.ff.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.weight": "blocks.15.transformer_blocks.3.ff.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.bias": "blocks.15.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.weight": "blocks.15.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.bias": "blocks.15.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.weight": "blocks.15.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.bias": "blocks.15.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.weight": "blocks.15.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_k.weight": "blocks.15.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_out.0.bias": "blocks.15.transformer_blocks.4.attn1.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_out.0.weight": 
"blocks.15.transformer_blocks.4.attn1.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_q.weight": "blocks.15.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_v.weight": "blocks.15.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_k.weight": "blocks.15.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_out.0.bias": "blocks.15.transformer_blocks.4.attn2.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_out.0.weight": "blocks.15.transformer_blocks.4.attn2.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_q.weight": "blocks.15.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_v.weight": "blocks.15.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.0.proj.bias": "blocks.15.transformer_blocks.4.act_fn.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.0.proj.weight": "blocks.15.transformer_blocks.4.act_fn.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.2.bias": "blocks.15.transformer_blocks.4.ff.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.2.weight": "blocks.15.transformer_blocks.4.ff.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm1.bias": "blocks.15.transformer_blocks.4.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm1.weight": "blocks.15.transformer_blocks.4.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm2.bias": "blocks.15.transformer_blocks.4.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm2.weight": "blocks.15.transformer_blocks.4.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm3.bias": "blocks.15.transformer_blocks.4.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm3.weight": "blocks.15.transformer_blocks.4.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_k.weight": "blocks.15.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_out.0.bias": "blocks.15.transformer_blocks.5.attn1.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_out.0.weight": "blocks.15.transformer_blocks.5.attn1.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_q.weight": "blocks.15.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_v.weight": "blocks.15.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_k.weight": "blocks.15.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_out.0.bias": "blocks.15.transformer_blocks.5.attn2.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_out.0.weight": "blocks.15.transformer_blocks.5.attn2.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_q.weight": "blocks.15.transformer_blocks.5.attn2.to_q.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_v.weight": "blocks.15.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.0.proj.bias": "blocks.15.transformer_blocks.5.act_fn.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.0.proj.weight": "blocks.15.transformer_blocks.5.act_fn.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.2.bias": "blocks.15.transformer_blocks.5.ff.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.2.weight": "blocks.15.transformer_blocks.5.ff.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm1.bias": "blocks.15.transformer_blocks.5.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm1.weight": "blocks.15.transformer_blocks.5.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm2.bias": "blocks.15.transformer_blocks.5.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm2.weight": "blocks.15.transformer_blocks.5.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm3.bias": "blocks.15.transformer_blocks.5.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm3.weight": "blocks.15.transformer_blocks.5.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_k.weight": "blocks.15.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_out.0.bias": "blocks.15.transformer_blocks.6.attn1.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_out.0.weight": "blocks.15.transformer_blocks.6.attn1.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_q.weight": "blocks.15.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_v.weight": "blocks.15.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_k.weight": "blocks.15.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_out.0.bias": "blocks.15.transformer_blocks.6.attn2.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_out.0.weight": "blocks.15.transformer_blocks.6.attn2.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_q.weight": "blocks.15.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_v.weight": "blocks.15.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.0.proj.bias": "blocks.15.transformer_blocks.6.act_fn.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.0.proj.weight": "blocks.15.transformer_blocks.6.act_fn.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.2.bias": "blocks.15.transformer_blocks.6.ff.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.2.weight": "blocks.15.transformer_blocks.6.ff.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm1.bias": "blocks.15.transformer_blocks.6.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm1.weight": "blocks.15.transformer_blocks.6.norm1.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm2.bias": "blocks.15.transformer_blocks.6.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm2.weight": "blocks.15.transformer_blocks.6.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm3.bias": "blocks.15.transformer_blocks.6.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm3.weight": "blocks.15.transformer_blocks.6.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_k.weight": "blocks.15.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_out.0.bias": "blocks.15.transformer_blocks.7.attn1.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_out.0.weight": "blocks.15.transformer_blocks.7.attn1.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_q.weight": "blocks.15.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_v.weight": "blocks.15.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_k.weight": "blocks.15.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_out.0.bias": "blocks.15.transformer_blocks.7.attn2.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_out.0.weight": "blocks.15.transformer_blocks.7.attn2.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_q.weight": "blocks.15.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_v.weight": "blocks.15.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.0.proj.bias": "blocks.15.transformer_blocks.7.act_fn.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.0.proj.weight": "blocks.15.transformer_blocks.7.act_fn.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.2.bias": "blocks.15.transformer_blocks.7.ff.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.2.weight": "blocks.15.transformer_blocks.7.ff.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm1.bias": "blocks.15.transformer_blocks.7.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm1.weight": "blocks.15.transformer_blocks.7.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm2.bias": "blocks.15.transformer_blocks.7.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm2.weight": "blocks.15.transformer_blocks.7.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm3.bias": "blocks.15.transformer_blocks.7.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm3.weight": "blocks.15.transformer_blocks.7.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_k.weight": "blocks.15.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_out.0.bias": "blocks.15.transformer_blocks.8.attn1.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_out.0.weight": "blocks.15.transformer_blocks.8.attn1.to_out.weight", + 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_q.weight": "blocks.15.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_v.weight": "blocks.15.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_k.weight": "blocks.15.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_out.0.bias": "blocks.15.transformer_blocks.8.attn2.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_out.0.weight": "blocks.15.transformer_blocks.8.attn2.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_q.weight": "blocks.15.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_v.weight": "blocks.15.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.0.proj.bias": "blocks.15.transformer_blocks.8.act_fn.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.0.proj.weight": "blocks.15.transformer_blocks.8.act_fn.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.2.bias": "blocks.15.transformer_blocks.8.ff.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.2.weight": "blocks.15.transformer_blocks.8.ff.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm1.bias": "blocks.15.transformer_blocks.8.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm1.weight": "blocks.15.transformer_blocks.8.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm2.bias": "blocks.15.transformer_blocks.8.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm2.weight": "blocks.15.transformer_blocks.8.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm3.bias": "blocks.15.transformer_blocks.8.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm3.weight": "blocks.15.transformer_blocks.8.norm3.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_k.weight": "blocks.15.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_out.0.bias": "blocks.15.transformer_blocks.9.attn1.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_out.0.weight": "blocks.15.transformer_blocks.9.attn1.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_q.weight": "blocks.15.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_v.weight": "blocks.15.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_k.weight": "blocks.15.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_out.0.bias": "blocks.15.transformer_blocks.9.attn2.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_out.0.weight": "blocks.15.transformer_blocks.9.attn2.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_q.weight": "blocks.15.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_v.weight": 
"blocks.15.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.0.proj.bias": "blocks.15.transformer_blocks.9.act_fn.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.0.proj.weight": "blocks.15.transformer_blocks.9.act_fn.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.2.bias": "blocks.15.transformer_blocks.9.ff.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.2.weight": "blocks.15.transformer_blocks.9.ff.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm1.bias": "blocks.15.transformer_blocks.9.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm1.weight": "blocks.15.transformer_blocks.9.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm2.bias": "blocks.15.transformer_blocks.9.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm2.weight": "blocks.15.transformer_blocks.9.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm3.bias": "blocks.15.transformer_blocks.9.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm3.weight": "blocks.15.transformer_blocks.9.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "blocks.17.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "blocks.17.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "blocks.17.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "blocks.17.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "blocks.17.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "blocks.17.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "blocks.17.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "blocks.17.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "blocks.17.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "blocks.17.conv2.weight", + "model.diffusion_model.input_blocks.8.1.norm.bias": "blocks.18.norm.bias", + "model.diffusion_model.input_blocks.8.1.norm.weight": "blocks.18.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "blocks.18.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "blocks.18.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "blocks.18.proj_out.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "blocks.18.proj_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.18.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.18.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.18.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.18.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.18.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.18.transformer_blocks.0.attn2.to_k.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.18.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.18.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.18.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.18.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.18.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.18.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.18.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.18.transformer_blocks.0.ff.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.18.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.18.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.18.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.18.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.18.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.18.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "blocks.18.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.18.transformer_blocks.1.attn1.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.18.transformer_blocks.1.attn1.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "blocks.18.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "blocks.18.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "blocks.18.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.18.transformer_blocks.1.attn2.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.18.transformer_blocks.1.attn2.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "blocks.18.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "blocks.18.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.18.transformer_blocks.1.act_fn.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.18.transformer_blocks.1.act_fn.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": 
"blocks.18.transformer_blocks.1.ff.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "blocks.18.transformer_blocks.1.ff.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": "blocks.18.transformer_blocks.1.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "blocks.18.transformer_blocks.1.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "blocks.18.transformer_blocks.1.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "blocks.18.transformer_blocks.1.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "blocks.18.transformer_blocks.1.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "blocks.18.transformer_blocks.1.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": "blocks.18.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "blocks.18.transformer_blocks.2.attn1.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "blocks.18.transformer_blocks.2.attn1.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "blocks.18.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": "blocks.18.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "blocks.18.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": "blocks.18.transformer_blocks.2.attn2.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": "blocks.18.transformer_blocks.2.attn2.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": "blocks.18.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "blocks.18.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "blocks.18.transformer_blocks.2.act_fn.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "blocks.18.transformer_blocks.2.act_fn.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "blocks.18.transformer_blocks.2.ff.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "blocks.18.transformer_blocks.2.ff.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.bias": "blocks.18.transformer_blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.weight": "blocks.18.transformer_blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.bias": "blocks.18.transformer_blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.weight": "blocks.18.transformer_blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.bias": "blocks.18.transformer_blocks.2.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.weight": "blocks.18.transformer_blocks.2.norm3.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": "blocks.18.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": "blocks.18.transformer_blocks.3.attn1.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "blocks.18.transformer_blocks.3.attn1.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "blocks.18.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "blocks.18.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "blocks.18.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "blocks.18.transformer_blocks.3.attn2.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "blocks.18.transformer_blocks.3.attn2.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "blocks.18.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": "blocks.18.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "blocks.18.transformer_blocks.3.act_fn.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "blocks.18.transformer_blocks.3.act_fn.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "blocks.18.transformer_blocks.3.ff.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "blocks.18.transformer_blocks.3.ff.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.bias": "blocks.18.transformer_blocks.3.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.weight": "blocks.18.transformer_blocks.3.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.bias": "blocks.18.transformer_blocks.3.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.weight": "blocks.18.transformer_blocks.3.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.bias": "blocks.18.transformer_blocks.3.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.weight": "blocks.18.transformer_blocks.3.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_k.weight": "blocks.18.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_out.0.bias": "blocks.18.transformer_blocks.4.attn1.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_out.0.weight": "blocks.18.transformer_blocks.4.attn1.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_q.weight": "blocks.18.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_v.weight": "blocks.18.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_k.weight": "blocks.18.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_out.0.bias": 
"blocks.18.transformer_blocks.4.attn2.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_out.0.weight": "blocks.18.transformer_blocks.4.attn2.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_q.weight": "blocks.18.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_v.weight": "blocks.18.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.0.proj.bias": "blocks.18.transformer_blocks.4.act_fn.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.0.proj.weight": "blocks.18.transformer_blocks.4.act_fn.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.2.bias": "blocks.18.transformer_blocks.4.ff.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.2.weight": "blocks.18.transformer_blocks.4.ff.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm1.bias": "blocks.18.transformer_blocks.4.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm1.weight": "blocks.18.transformer_blocks.4.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm2.bias": "blocks.18.transformer_blocks.4.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm2.weight": "blocks.18.transformer_blocks.4.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm3.bias": "blocks.18.transformer_blocks.4.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm3.weight": "blocks.18.transformer_blocks.4.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_k.weight": "blocks.18.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_out.0.bias": "blocks.18.transformer_blocks.5.attn1.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_out.0.weight": "blocks.18.transformer_blocks.5.attn1.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_q.weight": "blocks.18.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_v.weight": "blocks.18.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_k.weight": "blocks.18.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_out.0.bias": "blocks.18.transformer_blocks.5.attn2.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_out.0.weight": "blocks.18.transformer_blocks.5.attn2.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_q.weight": "blocks.18.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_v.weight": "blocks.18.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.0.proj.bias": "blocks.18.transformer_blocks.5.act_fn.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.0.proj.weight": "blocks.18.transformer_blocks.5.act_fn.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.2.bias": "blocks.18.transformer_blocks.5.ff.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.2.weight": "blocks.18.transformer_blocks.5.ff.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm1.bias": "blocks.18.transformer_blocks.5.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm1.weight": "blocks.18.transformer_blocks.5.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm2.bias": "blocks.18.transformer_blocks.5.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm2.weight": "blocks.18.transformer_blocks.5.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm3.bias": "blocks.18.transformer_blocks.5.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm3.weight": "blocks.18.transformer_blocks.5.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_k.weight": "blocks.18.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_out.0.bias": "blocks.18.transformer_blocks.6.attn1.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_out.0.weight": "blocks.18.transformer_blocks.6.attn1.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_q.weight": "blocks.18.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_v.weight": "blocks.18.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_k.weight": "blocks.18.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_out.0.bias": "blocks.18.transformer_blocks.6.attn2.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_out.0.weight": "blocks.18.transformer_blocks.6.attn2.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_q.weight": "blocks.18.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_v.weight": "blocks.18.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.0.proj.bias": "blocks.18.transformer_blocks.6.act_fn.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.0.proj.weight": "blocks.18.transformer_blocks.6.act_fn.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.2.bias": "blocks.18.transformer_blocks.6.ff.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.2.weight": "blocks.18.transformer_blocks.6.ff.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm1.bias": "blocks.18.transformer_blocks.6.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm1.weight": "blocks.18.transformer_blocks.6.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm2.bias": "blocks.18.transformer_blocks.6.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm2.weight": "blocks.18.transformer_blocks.6.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm3.bias": "blocks.18.transformer_blocks.6.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm3.weight": "blocks.18.transformer_blocks.6.norm3.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_k.weight": "blocks.18.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_out.0.bias": "blocks.18.transformer_blocks.7.attn1.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_out.0.weight": "blocks.18.transformer_blocks.7.attn1.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_q.weight": "blocks.18.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_v.weight": "blocks.18.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_k.weight": "blocks.18.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_out.0.bias": "blocks.18.transformer_blocks.7.attn2.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_out.0.weight": "blocks.18.transformer_blocks.7.attn2.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_q.weight": "blocks.18.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_v.weight": "blocks.18.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.0.proj.bias": "blocks.18.transformer_blocks.7.act_fn.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.0.proj.weight": "blocks.18.transformer_blocks.7.act_fn.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.2.bias": "blocks.18.transformer_blocks.7.ff.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.2.weight": "blocks.18.transformer_blocks.7.ff.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm1.bias": "blocks.18.transformer_blocks.7.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm1.weight": "blocks.18.transformer_blocks.7.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm2.bias": "blocks.18.transformer_blocks.7.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm2.weight": "blocks.18.transformer_blocks.7.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm3.bias": "blocks.18.transformer_blocks.7.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm3.weight": "blocks.18.transformer_blocks.7.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_k.weight": "blocks.18.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_out.0.bias": "blocks.18.transformer_blocks.8.attn1.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_out.0.weight": "blocks.18.transformer_blocks.8.attn1.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_q.weight": "blocks.18.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_v.weight": "blocks.18.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_k.weight": "blocks.18.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_out.0.bias": 
"blocks.18.transformer_blocks.8.attn2.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_out.0.weight": "blocks.18.transformer_blocks.8.attn2.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_q.weight": "blocks.18.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_v.weight": "blocks.18.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.0.proj.bias": "blocks.18.transformer_blocks.8.act_fn.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.0.proj.weight": "blocks.18.transformer_blocks.8.act_fn.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.2.bias": "blocks.18.transformer_blocks.8.ff.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.2.weight": "blocks.18.transformer_blocks.8.ff.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm1.bias": "blocks.18.transformer_blocks.8.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm1.weight": "blocks.18.transformer_blocks.8.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm2.bias": "blocks.18.transformer_blocks.8.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm2.weight": "blocks.18.transformer_blocks.8.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm3.bias": "blocks.18.transformer_blocks.8.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm3.weight": "blocks.18.transformer_blocks.8.norm3.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_k.weight": "blocks.18.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_out.0.bias": "blocks.18.transformer_blocks.9.attn1.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_out.0.weight": "blocks.18.transformer_blocks.9.attn1.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_q.weight": "blocks.18.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_v.weight": "blocks.18.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_k.weight": "blocks.18.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_out.0.bias": "blocks.18.transformer_blocks.9.attn2.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_out.0.weight": "blocks.18.transformer_blocks.9.attn2.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_q.weight": "blocks.18.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_v.weight": "blocks.18.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.0.proj.bias": "blocks.18.transformer_blocks.9.act_fn.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.0.proj.weight": "blocks.18.transformer_blocks.9.act_fn.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.2.bias": "blocks.18.transformer_blocks.9.ff.bias", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.2.weight": "blocks.18.transformer_blocks.9.ff.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm1.bias": "blocks.18.transformer_blocks.9.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm1.weight": "blocks.18.transformer_blocks.9.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm2.bias": "blocks.18.transformer_blocks.9.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm2.weight": "blocks.18.transformer_blocks.9.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm3.bias": "blocks.18.transformer_blocks.9.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm3.weight": "blocks.18.transformer_blocks.9.norm3.weight", + "model.diffusion_model.label_emb.0.0.bias": "add_time_embedding.0.bias", + "model.diffusion_model.label_emb.0.0.weight": "add_time_embedding.0.weight", + "model.diffusion_model.label_emb.0.2.bias": "add_time_embedding.2.bias", + "model.diffusion_model.label_emb.0.2.weight": "add_time_embedding.2.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "blocks.20.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "blocks.20.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "blocks.20.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "blocks.20.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "blocks.20.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "blocks.20.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "blocks.20.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "blocks.20.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "blocks.20.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "blocks.20.conv2.weight", + "model.diffusion_model.middle_block.1.norm.bias": "blocks.21.norm.bias", + "model.diffusion_model.middle_block.1.norm.weight": "blocks.21.norm.weight", + "model.diffusion_model.middle_block.1.proj_in.bias": "blocks.21.proj_in.bias", + "model.diffusion_model.middle_block.1.proj_in.weight": "blocks.21.proj_in.weight", + "model.diffusion_model.middle_block.1.proj_out.bias": "blocks.21.proj_out.bias", + "model.diffusion_model.middle_block.1.proj_out.weight": "blocks.21.proj_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.21.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.21.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.21.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.21.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.21.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.21.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.21.transformer_blocks.0.attn2.to_out.bias", + 
"model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.21.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.21.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.21.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.21.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.21.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.21.transformer_blocks.0.ff.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.21.transformer_blocks.0.ff.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.21.transformer_blocks.0.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.21.transformer_blocks.0.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.21.transformer_blocks.0.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.21.transformer_blocks.0.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.21.transformer_blocks.0.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.21.transformer_blocks.0.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_k.weight": "blocks.21.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.21.transformer_blocks.1.attn1.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.21.transformer_blocks.1.attn1.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_q.weight": "blocks.21.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_v.weight": "blocks.21.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_k.weight": "blocks.21.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.21.transformer_blocks.1.attn2.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.21.transformer_blocks.1.attn2.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_q.weight": "blocks.21.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_v.weight": "blocks.21.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.21.transformer_blocks.1.act_fn.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.21.transformer_blocks.1.act_fn.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.bias": "blocks.21.transformer_blocks.1.ff.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.weight": "blocks.21.transformer_blocks.1.ff.weight", + 
"model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.bias": "blocks.21.transformer_blocks.1.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.weight": "blocks.21.transformer_blocks.1.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.bias": "blocks.21.transformer_blocks.1.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.weight": "blocks.21.transformer_blocks.1.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.bias": "blocks.21.transformer_blocks.1.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.weight": "blocks.21.transformer_blocks.1.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_k.weight": "blocks.21.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.bias": "blocks.21.transformer_blocks.2.attn1.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.weight": "blocks.21.transformer_blocks.2.attn1.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_q.weight": "blocks.21.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_v.weight": "blocks.21.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_k.weight": "blocks.21.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.bias": "blocks.21.transformer_blocks.2.attn2.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.weight": "blocks.21.transformer_blocks.2.attn2.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_q.weight": "blocks.21.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_v.weight": "blocks.21.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.bias": "blocks.21.transformer_blocks.2.act_fn.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.weight": "blocks.21.transformer_blocks.2.act_fn.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.bias": "blocks.21.transformer_blocks.2.ff.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.weight": "blocks.21.transformer_blocks.2.ff.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.bias": "blocks.21.transformer_blocks.2.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.weight": "blocks.21.transformer_blocks.2.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.bias": "blocks.21.transformer_blocks.2.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.weight": "blocks.21.transformer_blocks.2.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.bias": "blocks.21.transformer_blocks.2.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.weight": "blocks.21.transformer_blocks.2.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_k.weight": "blocks.21.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.bias": 
"blocks.21.transformer_blocks.3.attn1.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.weight": "blocks.21.transformer_blocks.3.attn1.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_q.weight": "blocks.21.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_v.weight": "blocks.21.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_k.weight": "blocks.21.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.bias": "blocks.21.transformer_blocks.3.attn2.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.weight": "blocks.21.transformer_blocks.3.attn2.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_q.weight": "blocks.21.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_v.weight": "blocks.21.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.bias": "blocks.21.transformer_blocks.3.act_fn.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.weight": "blocks.21.transformer_blocks.3.act_fn.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.bias": "blocks.21.transformer_blocks.3.ff.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.weight": "blocks.21.transformer_blocks.3.ff.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.bias": "blocks.21.transformer_blocks.3.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.weight": "blocks.21.transformer_blocks.3.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.bias": "blocks.21.transformer_blocks.3.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.weight": "blocks.21.transformer_blocks.3.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.bias": "blocks.21.transformer_blocks.3.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.weight": "blocks.21.transformer_blocks.3.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_k.weight": "blocks.21.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_out.0.bias": "blocks.21.transformer_blocks.4.attn1.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_out.0.weight": "blocks.21.transformer_blocks.4.attn1.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_q.weight": "blocks.21.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_v.weight": "blocks.21.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_k.weight": "blocks.21.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_out.0.bias": "blocks.21.transformer_blocks.4.attn2.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_out.0.weight": "blocks.21.transformer_blocks.4.attn2.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_q.weight": 
"blocks.21.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_v.weight": "blocks.21.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.0.proj.bias": "blocks.21.transformer_blocks.4.act_fn.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.0.proj.weight": "blocks.21.transformer_blocks.4.act_fn.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.2.bias": "blocks.21.transformer_blocks.4.ff.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.2.weight": "blocks.21.transformer_blocks.4.ff.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm1.bias": "blocks.21.transformer_blocks.4.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm1.weight": "blocks.21.transformer_blocks.4.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm2.bias": "blocks.21.transformer_blocks.4.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm2.weight": "blocks.21.transformer_blocks.4.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm3.bias": "blocks.21.transformer_blocks.4.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.4.norm3.weight": "blocks.21.transformer_blocks.4.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_k.weight": "blocks.21.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_out.0.bias": "blocks.21.transformer_blocks.5.attn1.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_out.0.weight": "blocks.21.transformer_blocks.5.attn1.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_q.weight": "blocks.21.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_v.weight": "blocks.21.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_k.weight": "blocks.21.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_out.0.bias": "blocks.21.transformer_blocks.5.attn2.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_out.0.weight": "blocks.21.transformer_blocks.5.attn2.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_q.weight": "blocks.21.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_v.weight": "blocks.21.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.0.proj.bias": "blocks.21.transformer_blocks.5.act_fn.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.0.proj.weight": "blocks.21.transformer_blocks.5.act_fn.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.2.bias": "blocks.21.transformer_blocks.5.ff.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.2.weight": "blocks.21.transformer_blocks.5.ff.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm1.bias": "blocks.21.transformer_blocks.5.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm1.weight": "blocks.21.transformer_blocks.5.norm1.weight", + 
"model.diffusion_model.middle_block.1.transformer_blocks.5.norm2.bias": "blocks.21.transformer_blocks.5.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm2.weight": "blocks.21.transformer_blocks.5.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm3.bias": "blocks.21.transformer_blocks.5.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.5.norm3.weight": "blocks.21.transformer_blocks.5.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_k.weight": "blocks.21.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_out.0.bias": "blocks.21.transformer_blocks.6.attn1.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_out.0.weight": "blocks.21.transformer_blocks.6.attn1.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_q.weight": "blocks.21.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_v.weight": "blocks.21.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_k.weight": "blocks.21.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_out.0.bias": "blocks.21.transformer_blocks.6.attn2.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_out.0.weight": "blocks.21.transformer_blocks.6.attn2.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_q.weight": "blocks.21.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_v.weight": "blocks.21.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.0.proj.bias": "blocks.21.transformer_blocks.6.act_fn.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.0.proj.weight": "blocks.21.transformer_blocks.6.act_fn.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.2.bias": "blocks.21.transformer_blocks.6.ff.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.2.weight": "blocks.21.transformer_blocks.6.ff.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm1.bias": "blocks.21.transformer_blocks.6.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm1.weight": "blocks.21.transformer_blocks.6.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm2.bias": "blocks.21.transformer_blocks.6.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm2.weight": "blocks.21.transformer_blocks.6.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm3.bias": "blocks.21.transformer_blocks.6.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.6.norm3.weight": "blocks.21.transformer_blocks.6.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_k.weight": "blocks.21.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_out.0.bias": "blocks.21.transformer_blocks.7.attn1.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_out.0.weight": "blocks.21.transformer_blocks.7.attn1.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_q.weight": 
"blocks.21.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_v.weight": "blocks.21.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_k.weight": "blocks.21.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_out.0.bias": "blocks.21.transformer_blocks.7.attn2.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_out.0.weight": "blocks.21.transformer_blocks.7.attn2.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_q.weight": "blocks.21.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_v.weight": "blocks.21.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.0.proj.bias": "blocks.21.transformer_blocks.7.act_fn.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.0.proj.weight": "blocks.21.transformer_blocks.7.act_fn.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.2.bias": "blocks.21.transformer_blocks.7.ff.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.2.weight": "blocks.21.transformer_blocks.7.ff.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm1.bias": "blocks.21.transformer_blocks.7.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm1.weight": "blocks.21.transformer_blocks.7.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm2.bias": "blocks.21.transformer_blocks.7.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm2.weight": "blocks.21.transformer_blocks.7.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm3.bias": "blocks.21.transformer_blocks.7.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.7.norm3.weight": "blocks.21.transformer_blocks.7.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_k.weight": "blocks.21.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_out.0.bias": "blocks.21.transformer_blocks.8.attn1.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_out.0.weight": "blocks.21.transformer_blocks.8.attn1.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_q.weight": "blocks.21.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_v.weight": "blocks.21.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_k.weight": "blocks.21.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_out.0.bias": "blocks.21.transformer_blocks.8.attn2.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_out.0.weight": "blocks.21.transformer_blocks.8.attn2.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_q.weight": "blocks.21.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_v.weight": "blocks.21.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.0.proj.bias": 
"blocks.21.transformer_blocks.8.act_fn.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.0.proj.weight": "blocks.21.transformer_blocks.8.act_fn.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.2.bias": "blocks.21.transformer_blocks.8.ff.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.2.weight": "blocks.21.transformer_blocks.8.ff.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm1.bias": "blocks.21.transformer_blocks.8.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm1.weight": "blocks.21.transformer_blocks.8.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm2.bias": "blocks.21.transformer_blocks.8.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm2.weight": "blocks.21.transformer_blocks.8.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm3.bias": "blocks.21.transformer_blocks.8.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.8.norm3.weight": "blocks.21.transformer_blocks.8.norm3.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_k.weight": "blocks.21.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_out.0.bias": "blocks.21.transformer_blocks.9.attn1.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_out.0.weight": "blocks.21.transformer_blocks.9.attn1.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_q.weight": "blocks.21.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_v.weight": "blocks.21.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_k.weight": "blocks.21.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_out.0.bias": "blocks.21.transformer_blocks.9.attn2.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_out.0.weight": "blocks.21.transformer_blocks.9.attn2.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_q.weight": "blocks.21.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_v.weight": "blocks.21.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.0.proj.bias": "blocks.21.transformer_blocks.9.act_fn.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.0.proj.weight": "blocks.21.transformer_blocks.9.act_fn.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.2.bias": "blocks.21.transformer_blocks.9.ff.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.2.weight": "blocks.21.transformer_blocks.9.ff.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm1.bias": "blocks.21.transformer_blocks.9.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm1.weight": "blocks.21.transformer_blocks.9.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm2.bias": "blocks.21.transformer_blocks.9.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm2.weight": "blocks.21.transformer_blocks.9.norm2.weight", + 
"model.diffusion_model.middle_block.1.transformer_blocks.9.norm3.bias": "blocks.21.transformer_blocks.9.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.9.norm3.weight": "blocks.21.transformer_blocks.9.norm3.weight", + "model.diffusion_model.middle_block.2.emb_layers.1.bias": "blocks.22.time_emb_proj.bias", + "model.diffusion_model.middle_block.2.emb_layers.1.weight": "blocks.22.time_emb_proj.weight", + "model.diffusion_model.middle_block.2.in_layers.0.bias": "blocks.22.norm1.bias", + "model.diffusion_model.middle_block.2.in_layers.0.weight": "blocks.22.norm1.weight", + "model.diffusion_model.middle_block.2.in_layers.2.bias": "blocks.22.conv1.bias", + "model.diffusion_model.middle_block.2.in_layers.2.weight": "blocks.22.conv1.weight", + "model.diffusion_model.middle_block.2.out_layers.0.bias": "blocks.22.norm2.bias", + "model.diffusion_model.middle_block.2.out_layers.0.weight": "blocks.22.norm2.weight", + "model.diffusion_model.middle_block.2.out_layers.3.bias": "blocks.22.conv2.bias", + "model.diffusion_model.middle_block.2.out_layers.3.weight": "blocks.22.conv2.weight", + "model.diffusion_model.out.0.bias": "conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "conv_out.bias", + "model.diffusion_model.out.2.weight": "conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "blocks.24.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "blocks.24.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "blocks.24.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "blocks.24.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "blocks.24.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "blocks.24.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "blocks.24.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "blocks.24.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "blocks.24.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "blocks.24.conv_shortcut.weight", + "model.diffusion_model.output_blocks.0.1.norm.bias": "blocks.25.norm.bias", + "model.diffusion_model.output_blocks.0.1.norm.weight": "blocks.25.norm.weight", + "model.diffusion_model.output_blocks.0.1.proj_in.bias": "blocks.25.proj_in.bias", + "model.diffusion_model.output_blocks.0.1.proj_in.weight": "blocks.25.proj_in.weight", + "model.diffusion_model.output_blocks.0.1.proj_out.bias": "blocks.25.proj_out.bias", + "model.diffusion_model.output_blocks.0.1.proj_out.weight": "blocks.25.proj_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_k.weight": "blocks.25.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.25.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.25.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_q.weight": "blocks.25.transformer_blocks.0.attn1.to_q.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_v.weight": "blocks.25.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_k.weight": "blocks.25.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.25.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.25.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_q.weight": "blocks.25.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_v.weight": "blocks.25.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.25.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.25.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.bias": "blocks.25.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.weight": "blocks.25.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.bias": "blocks.25.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.weight": "blocks.25.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.bias": "blocks.25.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.weight": "blocks.25.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.bias": "blocks.25.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.weight": "blocks.25.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_k.weight": "blocks.25.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.25.transformer_blocks.1.attn1.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.25.transformer_blocks.1.attn1.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_q.weight": "blocks.25.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_v.weight": "blocks.25.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_k.weight": "blocks.25.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.25.transformer_blocks.1.attn2.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.25.transformer_blocks.1.attn2.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_q.weight": "blocks.25.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_v.weight": "blocks.25.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.bias": 
"blocks.25.transformer_blocks.1.act_fn.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.25.transformer_blocks.1.act_fn.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.bias": "blocks.25.transformer_blocks.1.ff.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.weight": "blocks.25.transformer_blocks.1.ff.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.bias": "blocks.25.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.weight": "blocks.25.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.bias": "blocks.25.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.weight": "blocks.25.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.bias": "blocks.25.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.weight": "blocks.25.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_k.weight": "blocks.25.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.bias": "blocks.25.transformer_blocks.2.attn1.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.weight": "blocks.25.transformer_blocks.2.attn1.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_q.weight": "blocks.25.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_v.weight": "blocks.25.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_k.weight": "blocks.25.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.bias": "blocks.25.transformer_blocks.2.attn2.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.weight": "blocks.25.transformer_blocks.2.attn2.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_q.weight": "blocks.25.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_v.weight": "blocks.25.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.bias": "blocks.25.transformer_blocks.2.act_fn.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.weight": "blocks.25.transformer_blocks.2.act_fn.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.bias": "blocks.25.transformer_blocks.2.ff.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.weight": "blocks.25.transformer_blocks.2.ff.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.bias": "blocks.25.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.weight": "blocks.25.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.bias": "blocks.25.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.weight": 
"blocks.25.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.bias": "blocks.25.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.weight": "blocks.25.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_k.weight": "blocks.25.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.bias": "blocks.25.transformer_blocks.3.attn1.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.weight": "blocks.25.transformer_blocks.3.attn1.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_q.weight": "blocks.25.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_v.weight": "blocks.25.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_k.weight": "blocks.25.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.bias": "blocks.25.transformer_blocks.3.attn2.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.weight": "blocks.25.transformer_blocks.3.attn2.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_q.weight": "blocks.25.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_v.weight": "blocks.25.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.bias": "blocks.25.transformer_blocks.3.act_fn.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.weight": "blocks.25.transformer_blocks.3.act_fn.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.bias": "blocks.25.transformer_blocks.3.ff.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.weight": "blocks.25.transformer_blocks.3.ff.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.bias": "blocks.25.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.weight": "blocks.25.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.bias": "blocks.25.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.weight": "blocks.25.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.bias": "blocks.25.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.weight": "blocks.25.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_k.weight": "blocks.25.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_out.0.bias": "blocks.25.transformer_blocks.4.attn1.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_out.0.weight": "blocks.25.transformer_blocks.4.attn1.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_q.weight": "blocks.25.transformer_blocks.4.attn1.to_q.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_v.weight": "blocks.25.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_k.weight": "blocks.25.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_out.0.bias": "blocks.25.transformer_blocks.4.attn2.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_out.0.weight": "blocks.25.transformer_blocks.4.attn2.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_q.weight": "blocks.25.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_v.weight": "blocks.25.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.0.proj.bias": "blocks.25.transformer_blocks.4.act_fn.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.0.proj.weight": "blocks.25.transformer_blocks.4.act_fn.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.2.bias": "blocks.25.transformer_blocks.4.ff.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.2.weight": "blocks.25.transformer_blocks.4.ff.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm1.bias": "blocks.25.transformer_blocks.4.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm1.weight": "blocks.25.transformer_blocks.4.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm2.bias": "blocks.25.transformer_blocks.4.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm2.weight": "blocks.25.transformer_blocks.4.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm3.bias": "blocks.25.transformer_blocks.4.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm3.weight": "blocks.25.transformer_blocks.4.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_k.weight": "blocks.25.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_out.0.bias": "blocks.25.transformer_blocks.5.attn1.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_out.0.weight": "blocks.25.transformer_blocks.5.attn1.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_q.weight": "blocks.25.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_v.weight": "blocks.25.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_k.weight": "blocks.25.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_out.0.bias": "blocks.25.transformer_blocks.5.attn2.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_out.0.weight": "blocks.25.transformer_blocks.5.attn2.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_q.weight": "blocks.25.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_v.weight": "blocks.25.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.0.proj.bias": 
"blocks.25.transformer_blocks.5.act_fn.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.0.proj.weight": "blocks.25.transformer_blocks.5.act_fn.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.2.bias": "blocks.25.transformer_blocks.5.ff.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.2.weight": "blocks.25.transformer_blocks.5.ff.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm1.bias": "blocks.25.transformer_blocks.5.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm1.weight": "blocks.25.transformer_blocks.5.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm2.bias": "blocks.25.transformer_blocks.5.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm2.weight": "blocks.25.transformer_blocks.5.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm3.bias": "blocks.25.transformer_blocks.5.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm3.weight": "blocks.25.transformer_blocks.5.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_k.weight": "blocks.25.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_out.0.bias": "blocks.25.transformer_blocks.6.attn1.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_out.0.weight": "blocks.25.transformer_blocks.6.attn1.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_q.weight": "blocks.25.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_v.weight": "blocks.25.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_k.weight": "blocks.25.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_out.0.bias": "blocks.25.transformer_blocks.6.attn2.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_out.0.weight": "blocks.25.transformer_blocks.6.attn2.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_q.weight": "blocks.25.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_v.weight": "blocks.25.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.0.proj.bias": "blocks.25.transformer_blocks.6.act_fn.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.0.proj.weight": "blocks.25.transformer_blocks.6.act_fn.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.2.bias": "blocks.25.transformer_blocks.6.ff.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.2.weight": "blocks.25.transformer_blocks.6.ff.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm1.bias": "blocks.25.transformer_blocks.6.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm1.weight": "blocks.25.transformer_blocks.6.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm2.bias": "blocks.25.transformer_blocks.6.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm2.weight": 
"blocks.25.transformer_blocks.6.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm3.bias": "blocks.25.transformer_blocks.6.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm3.weight": "blocks.25.transformer_blocks.6.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_k.weight": "blocks.25.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_out.0.bias": "blocks.25.transformer_blocks.7.attn1.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_out.0.weight": "blocks.25.transformer_blocks.7.attn1.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_q.weight": "blocks.25.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_v.weight": "blocks.25.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_k.weight": "blocks.25.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_out.0.bias": "blocks.25.transformer_blocks.7.attn2.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_out.0.weight": "blocks.25.transformer_blocks.7.attn2.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_q.weight": "blocks.25.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_v.weight": "blocks.25.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.0.proj.bias": "blocks.25.transformer_blocks.7.act_fn.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.0.proj.weight": "blocks.25.transformer_blocks.7.act_fn.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.2.bias": "blocks.25.transformer_blocks.7.ff.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.2.weight": "blocks.25.transformer_blocks.7.ff.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm1.bias": "blocks.25.transformer_blocks.7.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm1.weight": "blocks.25.transformer_blocks.7.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm2.bias": "blocks.25.transformer_blocks.7.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm2.weight": "blocks.25.transformer_blocks.7.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm3.bias": "blocks.25.transformer_blocks.7.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm3.weight": "blocks.25.transformer_blocks.7.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_k.weight": "blocks.25.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_out.0.bias": "blocks.25.transformer_blocks.8.attn1.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_out.0.weight": "blocks.25.transformer_blocks.8.attn1.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_q.weight": "blocks.25.transformer_blocks.8.attn1.to_q.weight", + 
"model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_v.weight": "blocks.25.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_k.weight": "blocks.25.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_out.0.bias": "blocks.25.transformer_blocks.8.attn2.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_out.0.weight": "blocks.25.transformer_blocks.8.attn2.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_q.weight": "blocks.25.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_v.weight": "blocks.25.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.0.proj.bias": "blocks.25.transformer_blocks.8.act_fn.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.0.proj.weight": "blocks.25.transformer_blocks.8.act_fn.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.2.bias": "blocks.25.transformer_blocks.8.ff.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.2.weight": "blocks.25.transformer_blocks.8.ff.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm1.bias": "blocks.25.transformer_blocks.8.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm1.weight": "blocks.25.transformer_blocks.8.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm2.bias": "blocks.25.transformer_blocks.8.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm2.weight": "blocks.25.transformer_blocks.8.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm3.bias": "blocks.25.transformer_blocks.8.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm3.weight": "blocks.25.transformer_blocks.8.norm3.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_k.weight": "blocks.25.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_out.0.bias": "blocks.25.transformer_blocks.9.attn1.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_out.0.weight": "blocks.25.transformer_blocks.9.attn1.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_q.weight": "blocks.25.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_v.weight": "blocks.25.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_k.weight": "blocks.25.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_out.0.bias": "blocks.25.transformer_blocks.9.attn2.to_out.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_out.0.weight": "blocks.25.transformer_blocks.9.attn2.to_out.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_q.weight": "blocks.25.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_v.weight": "blocks.25.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.0.proj.bias": 
"blocks.25.transformer_blocks.9.act_fn.proj.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.0.proj.weight": "blocks.25.transformer_blocks.9.act_fn.proj.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.2.bias": "blocks.25.transformer_blocks.9.ff.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.2.weight": "blocks.25.transformer_blocks.9.ff.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm1.bias": "blocks.25.transformer_blocks.9.norm1.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm1.weight": "blocks.25.transformer_blocks.9.norm1.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm2.bias": "blocks.25.transformer_blocks.9.norm2.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm2.weight": "blocks.25.transformer_blocks.9.norm2.weight", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm3.bias": "blocks.25.transformer_blocks.9.norm3.bias", + "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm3.weight": "blocks.25.transformer_blocks.9.norm3.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "blocks.27.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "blocks.27.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "blocks.27.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "blocks.27.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "blocks.27.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "blocks.27.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "blocks.27.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "blocks.27.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "blocks.27.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "blocks.27.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "blocks.27.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "blocks.27.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.1.norm.bias": "blocks.28.norm.bias", + "model.diffusion_model.output_blocks.1.1.norm.weight": "blocks.28.norm.weight", + "model.diffusion_model.output_blocks.1.1.proj_in.bias": "blocks.28.proj_in.bias", + "model.diffusion_model.output_blocks.1.1.proj_in.weight": "blocks.28.proj_in.weight", + "model.diffusion_model.output_blocks.1.1.proj_out.bias": "blocks.28.proj_out.bias", + "model.diffusion_model.output_blocks.1.1.proj_out.weight": "blocks.28.proj_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.28.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.28.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.28.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.28.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.28.transformer_blocks.0.attn1.to_v.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.28.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.28.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.28.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.28.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.28.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.28.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.28.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.28.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.28.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.28.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.28.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.28.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.28.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.28.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.28.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_k.weight": "blocks.28.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.28.transformer_blocks.1.attn1.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.28.transformer_blocks.1.attn1.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_q.weight": "blocks.28.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_v.weight": "blocks.28.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_k.weight": "blocks.28.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.28.transformer_blocks.1.attn2.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.28.transformer_blocks.1.attn2.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_q.weight": "blocks.28.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_v.weight": "blocks.28.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.28.transformer_blocks.1.act_fn.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.weight": 
"blocks.28.transformer_blocks.1.act_fn.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.bias": "blocks.28.transformer_blocks.1.ff.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.weight": "blocks.28.transformer_blocks.1.ff.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.bias": "blocks.28.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.weight": "blocks.28.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.bias": "blocks.28.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.weight": "blocks.28.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.bias": "blocks.28.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.weight": "blocks.28.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_k.weight": "blocks.28.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.bias": "blocks.28.transformer_blocks.2.attn1.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.weight": "blocks.28.transformer_blocks.2.attn1.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_q.weight": "blocks.28.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_v.weight": "blocks.28.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_k.weight": "blocks.28.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.bias": "blocks.28.transformer_blocks.2.attn2.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.weight": "blocks.28.transformer_blocks.2.attn2.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_q.weight": "blocks.28.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_v.weight": "blocks.28.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.bias": "blocks.28.transformer_blocks.2.act_fn.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.weight": "blocks.28.transformer_blocks.2.act_fn.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.bias": "blocks.28.transformer_blocks.2.ff.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.weight": "blocks.28.transformer_blocks.2.ff.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.bias": "blocks.28.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.weight": "blocks.28.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.bias": "blocks.28.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.weight": "blocks.28.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.bias": 
"blocks.28.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.weight": "blocks.28.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_k.weight": "blocks.28.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.bias": "blocks.28.transformer_blocks.3.attn1.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.weight": "blocks.28.transformer_blocks.3.attn1.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_q.weight": "blocks.28.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_v.weight": "blocks.28.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_k.weight": "blocks.28.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.bias": "blocks.28.transformer_blocks.3.attn2.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.weight": "blocks.28.transformer_blocks.3.attn2.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_q.weight": "blocks.28.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_v.weight": "blocks.28.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.bias": "blocks.28.transformer_blocks.3.act_fn.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.weight": "blocks.28.transformer_blocks.3.act_fn.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.bias": "blocks.28.transformer_blocks.3.ff.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.weight": "blocks.28.transformer_blocks.3.ff.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.bias": "blocks.28.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.weight": "blocks.28.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.bias": "blocks.28.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.weight": "blocks.28.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.bias": "blocks.28.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.weight": "blocks.28.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_k.weight": "blocks.28.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_out.0.bias": "blocks.28.transformer_blocks.4.attn1.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_out.0.weight": "blocks.28.transformer_blocks.4.attn1.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_q.weight": "blocks.28.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_v.weight": "blocks.28.transformer_blocks.4.attn1.to_v.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_k.weight": "blocks.28.transformer_blocks.4.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_out.0.bias": "blocks.28.transformer_blocks.4.attn2.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_out.0.weight": "blocks.28.transformer_blocks.4.attn2.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_q.weight": "blocks.28.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_v.weight": "blocks.28.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.0.proj.bias": "blocks.28.transformer_blocks.4.act_fn.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.0.proj.weight": "blocks.28.transformer_blocks.4.act_fn.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.2.bias": "blocks.28.transformer_blocks.4.ff.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.2.weight": "blocks.28.transformer_blocks.4.ff.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm1.bias": "blocks.28.transformer_blocks.4.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm1.weight": "blocks.28.transformer_blocks.4.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm2.bias": "blocks.28.transformer_blocks.4.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm2.weight": "blocks.28.transformer_blocks.4.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm3.bias": "blocks.28.transformer_blocks.4.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm3.weight": "blocks.28.transformer_blocks.4.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_k.weight": "blocks.28.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_out.0.bias": "blocks.28.transformer_blocks.5.attn1.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_out.0.weight": "blocks.28.transformer_blocks.5.attn1.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_q.weight": "blocks.28.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_v.weight": "blocks.28.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_k.weight": "blocks.28.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_out.0.bias": "blocks.28.transformer_blocks.5.attn2.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_out.0.weight": "blocks.28.transformer_blocks.5.attn2.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_q.weight": "blocks.28.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_v.weight": "blocks.28.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.0.proj.bias": "blocks.28.transformer_blocks.5.act_fn.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.0.proj.weight": 
"blocks.28.transformer_blocks.5.act_fn.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.2.bias": "blocks.28.transformer_blocks.5.ff.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.2.weight": "blocks.28.transformer_blocks.5.ff.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm1.bias": "blocks.28.transformer_blocks.5.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm1.weight": "blocks.28.transformer_blocks.5.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm2.bias": "blocks.28.transformer_blocks.5.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm2.weight": "blocks.28.transformer_blocks.5.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm3.bias": "blocks.28.transformer_blocks.5.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm3.weight": "blocks.28.transformer_blocks.5.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_k.weight": "blocks.28.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_out.0.bias": "blocks.28.transformer_blocks.6.attn1.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_out.0.weight": "blocks.28.transformer_blocks.6.attn1.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_q.weight": "blocks.28.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_v.weight": "blocks.28.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_k.weight": "blocks.28.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_out.0.bias": "blocks.28.transformer_blocks.6.attn2.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_out.0.weight": "blocks.28.transformer_blocks.6.attn2.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_q.weight": "blocks.28.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_v.weight": "blocks.28.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.0.proj.bias": "blocks.28.transformer_blocks.6.act_fn.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.0.proj.weight": "blocks.28.transformer_blocks.6.act_fn.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.2.bias": "blocks.28.transformer_blocks.6.ff.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.2.weight": "blocks.28.transformer_blocks.6.ff.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm1.bias": "blocks.28.transformer_blocks.6.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm1.weight": "blocks.28.transformer_blocks.6.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm2.bias": "blocks.28.transformer_blocks.6.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm2.weight": "blocks.28.transformer_blocks.6.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm3.bias": 
"blocks.28.transformer_blocks.6.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm3.weight": "blocks.28.transformer_blocks.6.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_k.weight": "blocks.28.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_out.0.bias": "blocks.28.transformer_blocks.7.attn1.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_out.0.weight": "blocks.28.transformer_blocks.7.attn1.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_q.weight": "blocks.28.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_v.weight": "blocks.28.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_k.weight": "blocks.28.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_out.0.bias": "blocks.28.transformer_blocks.7.attn2.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_out.0.weight": "blocks.28.transformer_blocks.7.attn2.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_q.weight": "blocks.28.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_v.weight": "blocks.28.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.0.proj.bias": "blocks.28.transformer_blocks.7.act_fn.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.0.proj.weight": "blocks.28.transformer_blocks.7.act_fn.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.2.bias": "blocks.28.transformer_blocks.7.ff.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.2.weight": "blocks.28.transformer_blocks.7.ff.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm1.bias": "blocks.28.transformer_blocks.7.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm1.weight": "blocks.28.transformer_blocks.7.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm2.bias": "blocks.28.transformer_blocks.7.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm2.weight": "blocks.28.transformer_blocks.7.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm3.bias": "blocks.28.transformer_blocks.7.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm3.weight": "blocks.28.transformer_blocks.7.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_k.weight": "blocks.28.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_out.0.bias": "blocks.28.transformer_blocks.8.attn1.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_out.0.weight": "blocks.28.transformer_blocks.8.attn1.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_q.weight": "blocks.28.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_v.weight": "blocks.28.transformer_blocks.8.attn1.to_v.weight", + 
"model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_k.weight": "blocks.28.transformer_blocks.8.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_out.0.bias": "blocks.28.transformer_blocks.8.attn2.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_out.0.weight": "blocks.28.transformer_blocks.8.attn2.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_q.weight": "blocks.28.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_v.weight": "blocks.28.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.0.proj.bias": "blocks.28.transformer_blocks.8.act_fn.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.0.proj.weight": "blocks.28.transformer_blocks.8.act_fn.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.2.bias": "blocks.28.transformer_blocks.8.ff.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.2.weight": "blocks.28.transformer_blocks.8.ff.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm1.bias": "blocks.28.transformer_blocks.8.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm1.weight": "blocks.28.transformer_blocks.8.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm2.bias": "blocks.28.transformer_blocks.8.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm2.weight": "blocks.28.transformer_blocks.8.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm3.bias": "blocks.28.transformer_blocks.8.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm3.weight": "blocks.28.transformer_blocks.8.norm3.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_k.weight": "blocks.28.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_out.0.bias": "blocks.28.transformer_blocks.9.attn1.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_out.0.weight": "blocks.28.transformer_blocks.9.attn1.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_q.weight": "blocks.28.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_v.weight": "blocks.28.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_k.weight": "blocks.28.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_out.0.bias": "blocks.28.transformer_blocks.9.attn2.to_out.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_out.0.weight": "blocks.28.transformer_blocks.9.attn2.to_out.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_q.weight": "blocks.28.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_v.weight": "blocks.28.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.0.proj.bias": "blocks.28.transformer_blocks.9.act_fn.proj.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.0.proj.weight": 
"blocks.28.transformer_blocks.9.act_fn.proj.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.2.bias": "blocks.28.transformer_blocks.9.ff.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.2.weight": "blocks.28.transformer_blocks.9.ff.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm1.bias": "blocks.28.transformer_blocks.9.norm1.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm1.weight": "blocks.28.transformer_blocks.9.norm1.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm2.bias": "blocks.28.transformer_blocks.9.norm2.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm2.weight": "blocks.28.transformer_blocks.9.norm2.weight", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm3.bias": "blocks.28.transformer_blocks.9.norm3.bias", + "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm3.weight": "blocks.28.transformer_blocks.9.norm3.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "blocks.30.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "blocks.30.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "blocks.30.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "blocks.30.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "blocks.30.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "blocks.30.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "blocks.30.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "blocks.30.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "blocks.30.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "blocks.30.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "blocks.30.conv_shortcut.bias", + "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "blocks.30.conv_shortcut.weight", + "model.diffusion_model.output_blocks.2.1.norm.bias": "blocks.31.norm.bias", + "model.diffusion_model.output_blocks.2.1.norm.weight": "blocks.31.norm.weight", + "model.diffusion_model.output_blocks.2.1.proj_in.bias": "blocks.31.proj_in.bias", + "model.diffusion_model.output_blocks.2.1.proj_in.weight": "blocks.31.proj_in.weight", + "model.diffusion_model.output_blocks.2.1.proj_out.bias": "blocks.31.proj_out.bias", + "model.diffusion_model.output_blocks.2.1.proj_out.weight": "blocks.31.proj_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.31.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.31.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.31.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.31.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.31.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.31.transformer_blocks.0.attn2.to_k.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.31.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.31.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.31.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.31.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.31.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.31.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.31.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.31.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.31.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.31.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.31.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.31.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.31.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.31.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_k.weight": "blocks.31.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.31.transformer_blocks.1.attn1.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.31.transformer_blocks.1.attn1.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_q.weight": "blocks.31.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_v.weight": "blocks.31.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_k.weight": "blocks.31.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.31.transformer_blocks.1.attn2.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.31.transformer_blocks.1.attn2.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_q.weight": "blocks.31.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_v.weight": "blocks.31.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.31.transformer_blocks.1.act_fn.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.31.transformer_blocks.1.act_fn.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.bias": 
"blocks.31.transformer_blocks.1.ff.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.weight": "blocks.31.transformer_blocks.1.ff.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.bias": "blocks.31.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.weight": "blocks.31.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.bias": "blocks.31.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.weight": "blocks.31.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.bias": "blocks.31.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.weight": "blocks.31.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_k.weight": "blocks.31.transformer_blocks.2.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.bias": "blocks.31.transformer_blocks.2.attn1.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.weight": "blocks.31.transformer_blocks.2.attn1.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_q.weight": "blocks.31.transformer_blocks.2.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_v.weight": "blocks.31.transformer_blocks.2.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_k.weight": "blocks.31.transformer_blocks.2.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.bias": "blocks.31.transformer_blocks.2.attn2.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.weight": "blocks.31.transformer_blocks.2.attn2.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_q.weight": "blocks.31.transformer_blocks.2.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_v.weight": "blocks.31.transformer_blocks.2.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.bias": "blocks.31.transformer_blocks.2.act_fn.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.weight": "blocks.31.transformer_blocks.2.act_fn.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.bias": "blocks.31.transformer_blocks.2.ff.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.weight": "blocks.31.transformer_blocks.2.ff.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.bias": "blocks.31.transformer_blocks.2.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.weight": "blocks.31.transformer_blocks.2.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.bias": "blocks.31.transformer_blocks.2.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.weight": "blocks.31.transformer_blocks.2.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.bias": "blocks.31.transformer_blocks.2.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.weight": 
"blocks.31.transformer_blocks.2.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_k.weight": "blocks.31.transformer_blocks.3.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.bias": "blocks.31.transformer_blocks.3.attn1.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.weight": "blocks.31.transformer_blocks.3.attn1.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_q.weight": "blocks.31.transformer_blocks.3.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_v.weight": "blocks.31.transformer_blocks.3.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_k.weight": "blocks.31.transformer_blocks.3.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.bias": "blocks.31.transformer_blocks.3.attn2.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.weight": "blocks.31.transformer_blocks.3.attn2.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_q.weight": "blocks.31.transformer_blocks.3.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_v.weight": "blocks.31.transformer_blocks.3.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.bias": "blocks.31.transformer_blocks.3.act_fn.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.weight": "blocks.31.transformer_blocks.3.act_fn.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.bias": "blocks.31.transformer_blocks.3.ff.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.weight": "blocks.31.transformer_blocks.3.ff.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.bias": "blocks.31.transformer_blocks.3.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.weight": "blocks.31.transformer_blocks.3.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.bias": "blocks.31.transformer_blocks.3.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.weight": "blocks.31.transformer_blocks.3.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.bias": "blocks.31.transformer_blocks.3.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.weight": "blocks.31.transformer_blocks.3.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_k.weight": "blocks.31.transformer_blocks.4.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.bias": "blocks.31.transformer_blocks.4.attn1.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.weight": "blocks.31.transformer_blocks.4.attn1.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_q.weight": "blocks.31.transformer_blocks.4.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_v.weight": "blocks.31.transformer_blocks.4.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_k.weight": "blocks.31.transformer_blocks.4.attn2.to_k.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.bias": "blocks.31.transformer_blocks.4.attn2.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.weight": "blocks.31.transformer_blocks.4.attn2.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_q.weight": "blocks.31.transformer_blocks.4.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_v.weight": "blocks.31.transformer_blocks.4.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.bias": "blocks.31.transformer_blocks.4.act_fn.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.weight": "blocks.31.transformer_blocks.4.act_fn.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.bias": "blocks.31.transformer_blocks.4.ff.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.weight": "blocks.31.transformer_blocks.4.ff.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.bias": "blocks.31.transformer_blocks.4.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.weight": "blocks.31.transformer_blocks.4.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.bias": "blocks.31.transformer_blocks.4.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.weight": "blocks.31.transformer_blocks.4.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.bias": "blocks.31.transformer_blocks.4.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.weight": "blocks.31.transformer_blocks.4.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_k.weight": "blocks.31.transformer_blocks.5.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.bias": "blocks.31.transformer_blocks.5.attn1.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.weight": "blocks.31.transformer_blocks.5.attn1.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_q.weight": "blocks.31.transformer_blocks.5.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_v.weight": "blocks.31.transformer_blocks.5.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_k.weight": "blocks.31.transformer_blocks.5.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.bias": "blocks.31.transformer_blocks.5.attn2.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.weight": "blocks.31.transformer_blocks.5.attn2.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_q.weight": "blocks.31.transformer_blocks.5.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_v.weight": "blocks.31.transformer_blocks.5.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.bias": "blocks.31.transformer_blocks.5.act_fn.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.weight": "blocks.31.transformer_blocks.5.act_fn.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.bias": 
"blocks.31.transformer_blocks.5.ff.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.weight": "blocks.31.transformer_blocks.5.ff.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.bias": "blocks.31.transformer_blocks.5.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.weight": "blocks.31.transformer_blocks.5.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.bias": "blocks.31.transformer_blocks.5.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.weight": "blocks.31.transformer_blocks.5.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.bias": "blocks.31.transformer_blocks.5.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.weight": "blocks.31.transformer_blocks.5.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_k.weight": "blocks.31.transformer_blocks.6.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.bias": "blocks.31.transformer_blocks.6.attn1.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.weight": "blocks.31.transformer_blocks.6.attn1.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_q.weight": "blocks.31.transformer_blocks.6.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_v.weight": "blocks.31.transformer_blocks.6.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_k.weight": "blocks.31.transformer_blocks.6.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.bias": "blocks.31.transformer_blocks.6.attn2.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.weight": "blocks.31.transformer_blocks.6.attn2.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_q.weight": "blocks.31.transformer_blocks.6.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_v.weight": "blocks.31.transformer_blocks.6.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.bias": "blocks.31.transformer_blocks.6.act_fn.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.weight": "blocks.31.transformer_blocks.6.act_fn.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.bias": "blocks.31.transformer_blocks.6.ff.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.weight": "blocks.31.transformer_blocks.6.ff.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.bias": "blocks.31.transformer_blocks.6.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.weight": "blocks.31.transformer_blocks.6.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.bias": "blocks.31.transformer_blocks.6.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.weight": "blocks.31.transformer_blocks.6.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.bias": "blocks.31.transformer_blocks.6.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.weight": 
"blocks.31.transformer_blocks.6.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_k.weight": "blocks.31.transformer_blocks.7.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.bias": "blocks.31.transformer_blocks.7.attn1.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.weight": "blocks.31.transformer_blocks.7.attn1.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_q.weight": "blocks.31.transformer_blocks.7.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_v.weight": "blocks.31.transformer_blocks.7.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_k.weight": "blocks.31.transformer_blocks.7.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.bias": "blocks.31.transformer_blocks.7.attn2.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.weight": "blocks.31.transformer_blocks.7.attn2.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_q.weight": "blocks.31.transformer_blocks.7.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_v.weight": "blocks.31.transformer_blocks.7.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.bias": "blocks.31.transformer_blocks.7.act_fn.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.weight": "blocks.31.transformer_blocks.7.act_fn.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.bias": "blocks.31.transformer_blocks.7.ff.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.weight": "blocks.31.transformer_blocks.7.ff.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.bias": "blocks.31.transformer_blocks.7.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.weight": "blocks.31.transformer_blocks.7.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.bias": "blocks.31.transformer_blocks.7.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.weight": "blocks.31.transformer_blocks.7.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.bias": "blocks.31.transformer_blocks.7.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.weight": "blocks.31.transformer_blocks.7.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_k.weight": "blocks.31.transformer_blocks.8.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.bias": "blocks.31.transformer_blocks.8.attn1.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.weight": "blocks.31.transformer_blocks.8.attn1.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_q.weight": "blocks.31.transformer_blocks.8.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_v.weight": "blocks.31.transformer_blocks.8.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_k.weight": "blocks.31.transformer_blocks.8.attn2.to_k.weight", + 
"model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.bias": "blocks.31.transformer_blocks.8.attn2.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.weight": "blocks.31.transformer_blocks.8.attn2.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_q.weight": "blocks.31.transformer_blocks.8.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_v.weight": "blocks.31.transformer_blocks.8.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.bias": "blocks.31.transformer_blocks.8.act_fn.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.weight": "blocks.31.transformer_blocks.8.act_fn.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.bias": "blocks.31.transformer_blocks.8.ff.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.weight": "blocks.31.transformer_blocks.8.ff.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.bias": "blocks.31.transformer_blocks.8.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.weight": "blocks.31.transformer_blocks.8.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.bias": "blocks.31.transformer_blocks.8.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.weight": "blocks.31.transformer_blocks.8.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.bias": "blocks.31.transformer_blocks.8.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.weight": "blocks.31.transformer_blocks.8.norm3.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_k.weight": "blocks.31.transformer_blocks.9.attn1.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.bias": "blocks.31.transformer_blocks.9.attn1.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.weight": "blocks.31.transformer_blocks.9.attn1.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_q.weight": "blocks.31.transformer_blocks.9.attn1.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_v.weight": "blocks.31.transformer_blocks.9.attn1.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_k.weight": "blocks.31.transformer_blocks.9.attn2.to_k.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.bias": "blocks.31.transformer_blocks.9.attn2.to_out.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.weight": "blocks.31.transformer_blocks.9.attn2.to_out.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_q.weight": "blocks.31.transformer_blocks.9.attn2.to_q.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_v.weight": "blocks.31.transformer_blocks.9.attn2.to_v.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.bias": "blocks.31.transformer_blocks.9.act_fn.proj.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.weight": "blocks.31.transformer_blocks.9.act_fn.proj.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.bias": 
"blocks.31.transformer_blocks.9.ff.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.weight": "blocks.31.transformer_blocks.9.ff.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.bias": "blocks.31.transformer_blocks.9.norm1.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.weight": "blocks.31.transformer_blocks.9.norm1.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.bias": "blocks.31.transformer_blocks.9.norm2.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.weight": "blocks.31.transformer_blocks.9.norm2.weight", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.bias": "blocks.31.transformer_blocks.9.norm3.bias", + "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.weight": "blocks.31.transformer_blocks.9.norm3.weight", + "model.diffusion_model.output_blocks.2.2.conv.bias": "blocks.32.conv.bias", + "model.diffusion_model.output_blocks.2.2.conv.weight": "blocks.32.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "blocks.34.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "blocks.34.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "blocks.34.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "blocks.34.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "blocks.34.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "blocks.34.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "blocks.34.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "blocks.34.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "blocks.34.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "blocks.34.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "blocks.34.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "blocks.34.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.1.norm.bias": "blocks.35.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "blocks.35.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "blocks.35.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "blocks.35.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "blocks.35.proj_out.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "blocks.35.proj_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "blocks.35.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.35.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.35.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "blocks.35.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "blocks.35.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "blocks.35.transformer_blocks.0.attn2.to_k.weight", + 
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.35.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.35.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "blocks.35.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "blocks.35.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.35.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.35.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "blocks.35.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "blocks.35.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "blocks.35.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "blocks.35.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "blocks.35.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "blocks.35.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "blocks.35.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "blocks.35.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_k.weight": "blocks.35.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.35.transformer_blocks.1.attn1.to_out.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.35.transformer_blocks.1.attn1.to_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_q.weight": "blocks.35.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_v.weight": "blocks.35.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_k.weight": "blocks.35.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.35.transformer_blocks.1.attn2.to_out.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.35.transformer_blocks.1.attn2.to_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_q.weight": "blocks.35.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_v.weight": "blocks.35.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.35.transformer_blocks.1.act_fn.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.35.transformer_blocks.1.act_fn.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.bias": 
"blocks.35.transformer_blocks.1.ff.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.weight": "blocks.35.transformer_blocks.1.ff.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.bias": "blocks.35.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.weight": "blocks.35.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.bias": "blocks.35.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.weight": "blocks.35.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.bias": "blocks.35.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.weight": "blocks.35.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "blocks.37.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "blocks.37.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "blocks.37.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "blocks.37.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "blocks.37.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "blocks.37.conv1.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "blocks.37.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "blocks.37.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "blocks.37.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "blocks.37.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "blocks.37.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "blocks.37.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "blocks.38.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "blocks.38.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "blocks.38.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "blocks.38.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": "blocks.38.proj_out.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "blocks.38.proj_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.38.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.38.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.38.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.38.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.38.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.38.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.38.transformer_blocks.0.attn2.to_out.bias", + 
"model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.38.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.38.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.38.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.38.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.38.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.38.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.38.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.38.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.38.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.38.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.38.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.38.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.38.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "blocks.38.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.38.transformer_blocks.1.attn1.to_out.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.38.transformer_blocks.1.attn1.to_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "blocks.38.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "blocks.38.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "blocks.38.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.38.transformer_blocks.1.attn2.to_out.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.38.transformer_blocks.1.attn2.to_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "blocks.38.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "blocks.38.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.38.transformer_blocks.1.act_fn.proj.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.38.transformer_blocks.1.act_fn.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "blocks.38.transformer_blocks.1.ff.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.weight": 
"blocks.38.transformer_blocks.1.ff.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.bias": "blocks.38.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.weight": "blocks.38.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.bias": "blocks.38.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.weight": "blocks.38.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.bias": "blocks.38.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.weight": "blocks.38.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "blocks.40.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "blocks.40.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "blocks.40.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "blocks.40.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "blocks.40.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "blocks.40.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "blocks.40.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "blocks.40.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "blocks.40.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "blocks.40.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "blocks.40.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "blocks.40.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "blocks.41.norm.bias", + "model.diffusion_model.output_blocks.5.1.norm.weight": "blocks.41.norm.weight", + "model.diffusion_model.output_blocks.5.1.proj_in.bias": "blocks.41.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "blocks.41.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "blocks.41.proj_out.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "blocks.41.proj_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.41.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.41.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.41.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.41.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.41.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.41.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.41.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.41.transformer_blocks.0.attn2.to_out.weight", + 
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.41.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.41.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.41.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.41.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.41.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.41.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.41.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.41.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.41.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.41.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.41.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.41.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "blocks.41.transformer_blocks.1.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.41.transformer_blocks.1.attn1.to_out.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.41.transformer_blocks.1.attn1.to_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "blocks.41.transformer_blocks.1.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "blocks.41.transformer_blocks.1.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "blocks.41.transformer_blocks.1.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.41.transformer_blocks.1.attn2.to_out.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.41.transformer_blocks.1.attn2.to_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "blocks.41.transformer_blocks.1.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "blocks.41.transformer_blocks.1.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.41.transformer_blocks.1.act_fn.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.41.transformer_blocks.1.act_fn.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "blocks.41.transformer_blocks.1.ff.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "blocks.41.transformer_blocks.1.ff.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.bias": 
"blocks.41.transformer_blocks.1.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.weight": "blocks.41.transformer_blocks.1.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.bias": "blocks.41.transformer_blocks.1.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.weight": "blocks.41.transformer_blocks.1.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.bias": "blocks.41.transformer_blocks.1.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.weight": "blocks.41.transformer_blocks.1.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "blocks.42.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "blocks.42.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "blocks.44.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "blocks.44.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "blocks.44.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "blocks.44.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "blocks.44.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "blocks.44.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "blocks.44.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "blocks.44.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "blocks.44.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "blocks.44.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "blocks.44.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "blocks.44.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "blocks.46.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "blocks.46.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "blocks.46.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "blocks.46.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "blocks.46.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "blocks.46.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "blocks.46.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "blocks.46.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "blocks.46.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "blocks.46.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "blocks.46.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "blocks.46.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "blocks.48.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "blocks.48.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "blocks.48.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "blocks.48.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "blocks.48.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": 
"blocks.48.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "blocks.48.norm2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "blocks.48.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "blocks.48.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "blocks.48.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "blocks.48.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "blocks.48.conv_shortcut.weight", + "model.diffusion_model.time_embed.0.bias": "time_embedding.0.bias", + "model.diffusion_model.time_embed.0.weight": "time_embedding.0.weight", + "model.diffusion_model.time_embed.2.bias": "time_embedding.2.bias", + "model.diffusion_model.time_embed.2.weight": "time_embedding.2.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if ".proj_in." in name or ".proj_out." in name: + param = param.squeeze() + state_dict_[rename_dict[name]] = param + return state_dict_ diff --git a/diffsynth/models/sdxl_vae_decoder.py b/diffsynth/models/sdxl_vae_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c191f52d4fb7ef5478a396159c6767111d364d2f --- /dev/null +++ b/diffsynth/models/sdxl_vae_decoder.py @@ -0,0 +1,15 @@ +from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter + + +class SDXLVAEDecoder(SDVAEDecoder): + def __init__(self): + super().__init__() + self.scaling_factor = 0.13025 + + def state_dict_converter(self): + return SDXLVAEDecoderStateDictConverter() + + +class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter): + def __init__(self): + super().__init__() diff --git a/diffsynth/models/sdxl_vae_encoder.py b/diffsynth/models/sdxl_vae_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..36b131debb28d157f15584951bdae8416c979b16 --- /dev/null +++ b/diffsynth/models/sdxl_vae_encoder.py @@ -0,0 +1,15 @@ +from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder + + +class SDXLVAEEncoder(SDVAEEncoder): + def __init__(self): + super().__init__() + self.scaling_factor = 0.13025 + + def state_dict_converter(self): + return SDXLVAEEncoderStateDictConverter() + + +class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter): + def __init__(self): + super().__init__() diff --git a/diffsynth/models/svd_image_encoder.py b/diffsynth/models/svd_image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c3aa32fbb00bf55d0bb59e7edae797f24e8fea42 --- /dev/null +++ b/diffsynth/models/svd_image_encoder.py @@ -0,0 +1,504 @@ +import torch +from .sd_text_encoder import CLIPEncoderLayer + + +class CLIPVisionEmbeddings(torch.nn.Module): + def __init__(self, embed_dim=1280, image_size=224, patch_size=14, num_channels=3): + super().__init__() + + # class_embeds (This is a fixed tensor) + self.class_embedding = torch.nn.Parameter(torch.randn(1, 1, embed_dim)) + + # position_embeds + self.patch_embedding = torch.nn.Conv2d(in_channels=num_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias=False) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, (image_size // patch_size) ** 2 + 1, embed_dim)) + + def forward(self, pixel_values): + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding(pixel_values) + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) 
+ class_embeds = self.class_embedding.repeat(batch_size, 1, 1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + self.position_embeds + return embeddings + + +class SVDImageEncoder(torch.nn.Module): + def __init__(self, embed_dim=1280, layer_norm_eps=1e-5, num_encoder_layers=32, encoder_intermediate_size=5120, projection_dim=1024, num_heads=16, head_dim=80): + super().__init__() + self.embeddings = CLIPVisionEmbeddings(embed_dim=embed_dim) + self.pre_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps) + self.encoders = torch.nn.ModuleList([ + CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=num_heads, head_dim=head_dim, use_quick_gelu=False) + for _ in range(num_encoder_layers)]) + self.post_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps) + self.visual_projection = torch.nn.Linear(embed_dim, projection_dim, bias=False) + + def forward(self, pixel_values): + embeds = self.embeddings(pixel_values) + embeds = self.pre_layernorm(embeds) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds) + embeds = self.post_layernorm(embeds[:, 0, :]) + embeds = self.visual_projection(embeds) + return embeds + + def state_dict_converter(self): + return SVDImageEncoderStateDictConverter() + + +class SVDImageEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "vision_model.embeddings.patch_embedding.weight": "embeddings.patch_embedding.weight", + "vision_model.embeddings.class_embedding": "embeddings.class_embedding", + "vision_model.embeddings.position_embedding.weight": "embeddings.position_embeds", + "vision_model.pre_layrnorm.weight": "pre_layernorm.weight", + "vision_model.pre_layrnorm.bias": "pre_layernorm.bias", + "vision_model.post_layernorm.weight": "post_layernorm.weight", + "vision_model.post_layernorm.bias": "post_layernorm.bias", + "visual_projection.weight": "visual_projection.weight" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "vision_model.embeddings.class_embedding": + param = state_dict[name].view(1, 1, -1) + elif name == "vision_model.embeddings.position_embedding.weight": + param = state_dict[name].unsqueeze(0) + state_dict_[rename_dict[name]] = param + elif name.startswith("vision_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "conditioner.embedders.0.open_clip.model.visual.class_embedding": "embeddings.class_embedding", + "conditioner.embedders.0.open_clip.model.visual.conv1.weight": "embeddings.patch_embedding.weight", + "conditioner.embedders.0.open_clip.model.visual.ln_post.bias": "post_layernorm.bias", + "conditioner.embedders.0.open_clip.model.visual.ln_post.weight": "post_layernorm.weight", + "conditioner.embedders.0.open_clip.model.visual.ln_pre.bias": "pre_layernorm.bias", + "conditioner.embedders.0.open_clip.model.visual.ln_pre.weight": "pre_layernorm.weight", + 
"conditioner.embedders.0.open_clip.model.visual.positional_embedding": "embeddings.position_embeds", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_weight": 
['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight", 
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.bias": 
"encoders.14.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias", + 
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight", + 
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'], + 
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.bias": 
"encoders.22.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias", + 
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight", + 
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_bias": 
['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.bias": 
"encoders.30.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.weight": 
"encoders.4.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias", + 
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight", + 
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'], + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias", + "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight", + "conditioner.embedders.0.open_clip.model.visual.proj": "visual_projection.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "conditioner.embedders.0.open_clip.model.visual.class_embedding": + param = param.reshape((1, 1, param.shape[0])) + elif name == "conditioner.embedders.0.open_clip.model.visual.positional_embedding": + param = param.reshape((1, param.shape[0], param.shape[1])) + elif name == "conditioner.embedders.0.open_clip.model.visual.proj": + param = param.T + if isinstance(rename_dict[name], str): + state_dict_[rename_dict[name]] = param + else: + length = param.shape[0] // 3 + for i, rename in enumerate(rename_dict[name]): + state_dict_[rename] = param[i*length: i*length+length] + return state_dict_ diff --git a/diffsynth/models/svd_unet.py b/diffsynth/models/svd_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..4bab0832f13f3eaee8f4fe0e18c7b23866e9fa5b --- /dev/null +++ b/diffsynth/models/svd_unet.py @@ -0,0 +1,2003 @@ +import torch, math +from einops import rearrange, repeat +from .sd_unet import Timesteps, PushBlock, PopBlock, Attention, GEGLU, ResnetBlock, AttentionBlock, DownSampler, UpSampler + + +class TemporalResnetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5): + super().__init__() + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = torch.nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0)) + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + self.conv2 = torch.nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0)) + self.nonlinearity = torch.nn.SiLU() + 
self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + x = rearrange(hidden_states, "f c h w -> 1 c f h w") + x = self.norm1(x) + x = self.nonlinearity(x) + x = self.conv1(x) + if time_emb is not None: + emb = self.nonlinearity(time_emb) + emb = self.time_emb_proj(emb) + emb = repeat(emb, "b c -> b c f 1 1", f=hidden_states.shape[0]) + x = x + emb + x = self.norm2(x) + x = self.nonlinearity(x) + x = self.conv2(x) + if self.conv_shortcut is not None: + hidden_states = self.conv_shortcut(hidden_states) + x = rearrange(x[0], "c f h w -> f c h w") + hidden_states = hidden_states + x + return hidden_states, time_emb, text_emb, res_stack + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TemporalTimesteps(torch.nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class TrainableTemporalTimesteps(torch.nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, num_frames: int): + super().__init__() + timesteps = PositionalID()(num_frames) + embeddings = get_timestep_embedding(timesteps, num_channels, flip_sin_to_cos, downscale_freq_shift) + self.embeddings = torch.nn.Parameter(embeddings) + + def forward(self, timesteps): + t_emb = self.embeddings[timesteps] + return t_emb + + +class PositionalID(torch.nn.Module): + def __init__(self, max_id=25, repeat_length=20): + super().__init__() + self.max_id = max_id + self.repeat_length = repeat_length + + def frame_id_to_position_id(self, frame_id): + if frame_id < self.max_id: + position_id = frame_id + else: + position_id = (frame_id - self.max_id) % (self.repeat_length * 2) + if position_id < self.repeat_length: + 
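# descending half of the repeat cycle: walk the position id back down, so frames beyond max_id appear to sweep the trained positional range in a triangle-wave pattern +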
position_id = self.max_id - 2 - position_id + else: + position_id = self.max_id - 2 * self.repeat_length + position_id + return position_id + + def forward(self, num_frames, pivot_frame_id=0): + position_ids = [self.frame_id_to_position_id(abs(i-pivot_frame_id)) for i in range(num_frames)] + position_ids = torch.IntTensor(position_ids) + return position_ids + + +class TemporalAttentionBlock(torch.nn.Module): + + def __init__(self, num_attention_heads, attention_head_dim, in_channels, cross_attention_dim=None, add_positional_conv=None): + super().__init__() + + self.positional_embedding_proj = torch.nn.Sequential( + torch.nn.Linear(in_channels, in_channels * 4), + torch.nn.SiLU(), + torch.nn.Linear(in_channels * 4, in_channels) + ) + if add_positional_conv is not None: + self.positional_embedding = TrainableTemporalTimesteps(in_channels, True, 0, add_positional_conv) + self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1, padding_mode="reflect") + else: + self.positional_embedding = TemporalTimesteps(in_channels, True, 0) + self.positional_conv = None + + self.norm_in = torch.nn.LayerNorm(in_channels) + self.act_fn_in = GEGLU(in_channels, in_channels * 4) + self.ff_in = torch.nn.Linear(in_channels * 4, in_channels) + + self.norm1 = torch.nn.LayerNorm(in_channels) + self.attn1 = Attention( + q_dim=in_channels, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + bias_out=True + ) + + self.norm2 = torch.nn.LayerNorm(in_channels) + self.attn2 = Attention( + q_dim=in_channels, + kv_dim=cross_attention_dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + bias_out=True + ) + + self.norm_out = torch.nn.LayerNorm(in_channels) + self.act_fn_out = GEGLU(in_channels, in_channels * 4) + self.ff_out = torch.nn.Linear(in_channels * 4, in_channels) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + + batch, inner_dim, height, width = hidden_states.shape + pos_emb = torch.arange(batch) + pos_emb = self.positional_embedding(pos_emb).to(dtype=hidden_states.dtype, device=hidden_states.device) + pos_emb = self.positional_embedding_proj(pos_emb) + + hidden_states = rearrange(hidden_states, "T C H W -> 1 C T H W") + rearrange(pos_emb, "T C -> 1 C T 1 1") + if self.positional_conv is not None: + hidden_states = self.positional_conv(hidden_states) + hidden_states = rearrange(hidden_states[0], "C T H W -> (H W) T C") + + residual = hidden_states + hidden_states = self.norm_in(hidden_states) + hidden_states = self.act_fn_in(hidden_states) + hidden_states = self.ff_in(hidden_states) + hidden_states = hidden_states + residual + + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=text_emb.repeat(height * width, 1)) + hidden_states = attn_output + hidden_states + + residual = hidden_states + hidden_states = self.norm_out(hidden_states) + hidden_states = self.act_fn_out(hidden_states) + hidden_states = self.ff_out(hidden_states) + hidden_states = hidden_states + residual + + hidden_states = hidden_states.reshape(height, width, batch, inner_dim).permute(2, 3, 0, 1) + + return hidden_states, time_emb, text_emb, res_stack + + +class PopMixBlock(torch.nn.Module): + def __init__(self, in_channels=None): + super().__init__() + self.mix_factor = torch.nn.Parameter(torch.Tensor([0.5])) + 
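# mix_factor is a learnable scalar; forward() squashes it with a sigmoid and uses the result to blend the popped residual (the spatial branch in this UNet) with the incoming hidden states +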
self.need_proj = in_channels is not None + if self.need_proj: + self.proj = torch.nn.Linear(in_channels, in_channels) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + res_hidden_states = res_stack.pop() + alpha = torch.sigmoid(self.mix_factor) + hidden_states = alpha * res_hidden_states + (1 - alpha) * hidden_states + if self.need_proj: + hidden_states = hidden_states.permute(0, 2, 3, 1) + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.permute(0, 3, 1, 2) + res_hidden_states = res_stack.pop() + hidden_states = hidden_states + res_hidden_states + return hidden_states, time_emb, text_emb, res_stack + + +class SVDUNet(torch.nn.Module): + def __init__(self, add_positional_conv=None): + super().__init__() + self.time_proj = Timesteps(320) + self.time_embedding = torch.nn.Sequential( + torch.nn.Linear(320, 1280), + torch.nn.SiLU(), + torch.nn.Linear(1280, 1280) + ) + self.add_time_proj = Timesteps(256) + self.add_time_embedding = torch.nn.Sequential( + torch.nn.Linear(768, 1280), + torch.nn.SiLU(), + torch.nn.Linear(1280, 1280) + ) + self.conv_in = torch.nn.Conv2d(8, 320, kernel_size=3, padding=1) + + self.blocks = torch.nn.ModuleList([ + # CrossAttnDownBlockSpatioTemporal + ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(), + ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(), + DownSampler(320), PushBlock(), + # CrossAttnDownBlockSpatioTemporal + ResnetBlock(320, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(), + ResnetBlock(640, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(), + DownSampler(640), PushBlock(), + # CrossAttnDownBlockSpatioTemporal + ResnetBlock(640, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(), + ResnetBlock(1280, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(), + DownSampler(1280), PushBlock(), + # DownBlockSpatioTemporal + ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(), + ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(), + # UNetMidBlockSpatioTemporal + 
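# (same pattern as above: each spatial block is followed by a PushBlock, its temporal twin, and a PopMixBlock that blends the two; plain PushBlock/PopBlock pairs carry the UNet skip connections) +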
ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(), + AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), + ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), + # UpBlockSpatioTemporal + PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), + PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), + PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), + UpSampler(1280), + # CrossAttnUpBlockSpatioTemporal + PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), + PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), + PopBlock(), ResnetBlock(1920, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), + UpSampler(1280), + # CrossAttnUpBlockSpatioTemporal + PopBlock(), ResnetBlock(1920, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), + PopBlock(), ResnetBlock(1280, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), + PopBlock(), ResnetBlock(960, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), + UpSampler(640), + # CrossAttnUpBlockSpatioTemporal + PopBlock(), ResnetBlock(960, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), + PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), + PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(5, 64, 
320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(32, 320, eps=1e-05, affine=True) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + + + def build_mask(self, data, is_bound): + T, C, H, W = data.shape + t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W) + h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W) + w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W) + border_width = (T + H + W) // 6 + pad = torch.ones_like(t) * border_width + mask = torch.stack([ + pad if is_bound[0] else t + 1, + pad if is_bound[1] else T - t, + pad if is_bound[2] else h + 1, + pad if is_bound[3] else H - h, + pad if is_bound[4] else w + 1, + pad if is_bound[5] else W - w + ]).min(dim=0).values + mask = mask.clip(1, border_width) + mask = (mask / border_width).to(dtype=data.dtype, device=data.device) + mask = rearrange(mask, "T H W -> T 1 H W") + return mask + + + def tiled_forward( + self, sample, timestep, encoder_hidden_states, add_time_id, + batch_time=25, batch_height=128, batch_width=128, + stride_time=5, stride_height=64, stride_width=64, + progress_bar=lambda x:x + ): + data_device = sample.device + computation_device = self.conv_in.weight.device + torch_dtype = sample.dtype + T, C, H, W = sample.shape + + weight = torch.zeros((T, 1, H, W), dtype=torch_dtype, device=data_device) + values = torch.zeros((T, 4, H, W), dtype=torch_dtype, device=data_device) + + # Split tasks + tasks = [] + for t in range(0, T, stride_time): + for h in range(0, H, stride_height): + for w in range(0, W, stride_width): + if (t-stride_time >= 0 and t-stride_time+batch_time >= T)\ + or (h-stride_height >= 0 and h-stride_height+batch_height >= H)\ + or (w-stride_width >= 0 and w-stride_width+batch_width >= W): + continue + tasks.append((t, t+batch_time, h, h+batch_height, w, w+batch_width)) + + # Run + for tl, tr, hl, hr, wl, wr in progress_bar(tasks): + sample_batch = sample[tl:tr, :, hl:hr, wl:wr].to(computation_device) + sample_batch = self.forward(sample_batch, timestep, encoder_hidden_states, add_time_id).to(data_device) + mask = self.build_mask(sample_batch, is_bound=(tl==0, tr>=T, hl==0, hr>=H, wl==0, wr>=W)) + values[tl:tr, :, hl:hr, wl:wr] += sample_batch * mask + weight[tl:tr, :, hl:hr, wl:wr] += mask + values /= weight + return values + + + def forward(self, sample, timestep, encoder_hidden_states, add_time_id, use_gradient_checkpointing=False, **kwargs): + # 1. time + timestep = torch.tensor((timestep,)).to(sample.device) + t_emb = self.time_proj(timestep).to(sample.dtype) + t_emb = self.time_embedding(t_emb) + + add_embeds = self.add_time_proj(add_time_id.flatten()).to(sample.dtype) + add_embeds = add_embeds.reshape((-1, 768)) + add_embeds = self.add_time_embedding(add_embeds) + + time_emb = t_emb + add_embeds + + # 2. pre-process + height, width = sample.shape[2], sample.shape[3] + hidden_states = self.conv_in(sample) + text_emb = encoder_hidden_states + res_stack = [hidden_states] + + # 3. 
blocks + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + for i, block in enumerate(self.blocks): + if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock) or isinstance(block, PopMixBlock)): + hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, time_emb, text_emb, res_stack, + use_reentrant=False, + ) + else: + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 4. output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + def state_dict_converter(self): + return SVDUNetStateDictConverter() + + + +class SVDUNetStateDictConverter: + def __init__(self): + pass + + def get_block_name(self, names): + if names[0] in ["down_blocks", "mid_block", "up_blocks"]: + if names[4] in ["norm", "proj_in"]: + return ".".join(names[:4] + ["transformer_blocks"]) + elif names[4] in ["time_pos_embed"]: + return ".".join(names[:4] + ["temporal_transformer_blocks"]) + elif names[4] in ["proj_out"]: + return ".".join(names[:4] + ["time_mixer"]) + else: + return ".".join(names[:5]) + return "" + + def from_diffusers(self, state_dict): + rename_dict = { + "time_embedding.linear_1": "time_embedding.0", + "time_embedding.linear_2": "time_embedding.2", + "add_embedding.linear_1": "add_time_embedding.0", + "add_embedding.linear_2": "add_time_embedding.2", + "conv_in": "conv_in", + "conv_norm_out": "conv_norm_out", + "conv_out": "conv_out", + } + blocks_rename_dict = [ + "down_blocks.0.resnets.0.spatial_res_block", None, "down_blocks.0.resnets.0.temporal_res_block", "down_blocks.0.resnets.0.time_mixer", None, + "down_blocks.0.attentions.0.transformer_blocks", None, "down_blocks.0.attentions.0.temporal_transformer_blocks", "down_blocks.0.attentions.0.time_mixer", None, + "down_blocks.0.resnets.1.spatial_res_block", None, "down_blocks.0.resnets.1.temporal_res_block", "down_blocks.0.resnets.1.time_mixer", None, + "down_blocks.0.attentions.1.transformer_blocks", None, "down_blocks.0.attentions.1.temporal_transformer_blocks", "down_blocks.0.attentions.1.time_mixer", None, + "down_blocks.0.downsamplers.0.conv", None, + "down_blocks.1.resnets.0.spatial_res_block", None, "down_blocks.1.resnets.0.temporal_res_block", "down_blocks.1.resnets.0.time_mixer", None, + "down_blocks.1.attentions.0.transformer_blocks", None, "down_blocks.1.attentions.0.temporal_transformer_blocks", "down_blocks.1.attentions.0.time_mixer", None, + "down_blocks.1.resnets.1.spatial_res_block", None, "down_blocks.1.resnets.1.temporal_res_block", "down_blocks.1.resnets.1.time_mixer", None, + "down_blocks.1.attentions.1.transformer_blocks", None, "down_blocks.1.attentions.1.temporal_transformer_blocks", "down_blocks.1.attentions.1.time_mixer", None, + "down_blocks.1.downsamplers.0.conv", None, + "down_blocks.2.resnets.0.spatial_res_block", None, "down_blocks.2.resnets.0.temporal_res_block", "down_blocks.2.resnets.0.time_mixer", None, + "down_blocks.2.attentions.0.transformer_blocks", None, "down_blocks.2.attentions.0.temporal_transformer_blocks", "down_blocks.2.attentions.0.time_mixer", None, + "down_blocks.2.resnets.1.spatial_res_block", None, "down_blocks.2.resnets.1.temporal_res_block", "down_blocks.2.resnets.1.time_mixer", None, + "down_blocks.2.attentions.1.transformer_blocks", 
None, "down_blocks.2.attentions.1.temporal_transformer_blocks", "down_blocks.2.attentions.1.time_mixer", None, + "down_blocks.2.downsamplers.0.conv", None, + "down_blocks.3.resnets.0.spatial_res_block", None, "down_blocks.3.resnets.0.temporal_res_block", "down_blocks.3.resnets.0.time_mixer", None, + "down_blocks.3.resnets.1.spatial_res_block", None, "down_blocks.3.resnets.1.temporal_res_block", "down_blocks.3.resnets.1.time_mixer", None, + "mid_block.mid_block.resnets.0.spatial_res_block", None, "mid_block.mid_block.resnets.0.temporal_res_block", "mid_block.mid_block.resnets.0.time_mixer", None, + "mid_block.mid_block.attentions.0.transformer_blocks", None, "mid_block.mid_block.attentions.0.temporal_transformer_blocks", "mid_block.mid_block.attentions.0.time_mixer", + "mid_block.mid_block.resnets.1.spatial_res_block", None, "mid_block.mid_block.resnets.1.temporal_res_block", "mid_block.mid_block.resnets.1.time_mixer", + None, "up_blocks.0.resnets.0.spatial_res_block", None, "up_blocks.0.resnets.0.temporal_res_block", "up_blocks.0.resnets.0.time_mixer", + None, "up_blocks.0.resnets.1.spatial_res_block", None, "up_blocks.0.resnets.1.temporal_res_block", "up_blocks.0.resnets.1.time_mixer", + None, "up_blocks.0.resnets.2.spatial_res_block", None, "up_blocks.0.resnets.2.temporal_res_block", "up_blocks.0.resnets.2.time_mixer", + "up_blocks.0.upsamplers.0.conv", + None, "up_blocks.1.resnets.0.spatial_res_block", None, "up_blocks.1.resnets.0.temporal_res_block", "up_blocks.1.resnets.0.time_mixer", None, + "up_blocks.1.attentions.0.transformer_blocks", None, "up_blocks.1.attentions.0.temporal_transformer_blocks", "up_blocks.1.attentions.0.time_mixer", + None, "up_blocks.1.resnets.1.spatial_res_block", None, "up_blocks.1.resnets.1.temporal_res_block", "up_blocks.1.resnets.1.time_mixer", None, + "up_blocks.1.attentions.1.transformer_blocks", None, "up_blocks.1.attentions.1.temporal_transformer_blocks", "up_blocks.1.attentions.1.time_mixer", + None, "up_blocks.1.resnets.2.spatial_res_block", None, "up_blocks.1.resnets.2.temporal_res_block", "up_blocks.1.resnets.2.time_mixer", None, + "up_blocks.1.attentions.2.transformer_blocks", None, "up_blocks.1.attentions.2.temporal_transformer_blocks", "up_blocks.1.attentions.2.time_mixer", + "up_blocks.1.upsamplers.0.conv", + None, "up_blocks.2.resnets.0.spatial_res_block", None, "up_blocks.2.resnets.0.temporal_res_block", "up_blocks.2.resnets.0.time_mixer", None, + "up_blocks.2.attentions.0.transformer_blocks", None, "up_blocks.2.attentions.0.temporal_transformer_blocks", "up_blocks.2.attentions.0.time_mixer", + None, "up_blocks.2.resnets.1.spatial_res_block", None, "up_blocks.2.resnets.1.temporal_res_block", "up_blocks.2.resnets.1.time_mixer", None, + "up_blocks.2.attentions.1.transformer_blocks", None, "up_blocks.2.attentions.1.temporal_transformer_blocks", "up_blocks.2.attentions.1.time_mixer", + None, "up_blocks.2.resnets.2.spatial_res_block", None, "up_blocks.2.resnets.2.temporal_res_block", "up_blocks.2.resnets.2.time_mixer", None, + "up_blocks.2.attentions.2.transformer_blocks", None, "up_blocks.2.attentions.2.temporal_transformer_blocks", "up_blocks.2.attentions.2.time_mixer", + "up_blocks.2.upsamplers.0.conv", + None, "up_blocks.3.resnets.0.spatial_res_block", None, "up_blocks.3.resnets.0.temporal_res_block", "up_blocks.3.resnets.0.time_mixer", None, + "up_blocks.3.attentions.0.transformer_blocks", None, "up_blocks.3.attentions.0.temporal_transformer_blocks", "up_blocks.3.attentions.0.time_mixer", + None, "up_blocks.3.resnets.1.spatial_res_block", 
None, "up_blocks.3.resnets.1.temporal_res_block", "up_blocks.3.resnets.1.time_mixer", None, + "up_blocks.3.attentions.1.transformer_blocks", None, "up_blocks.3.attentions.1.temporal_transformer_blocks", "up_blocks.3.attentions.1.time_mixer", + None, "up_blocks.3.resnets.2.spatial_res_block", None, "up_blocks.3.resnets.2.temporal_res_block", "up_blocks.3.resnets.2.time_mixer", None, + "up_blocks.3.attentions.2.transformer_blocks", None, "up_blocks.3.attentions.2.temporal_transformer_blocks", "up_blocks.3.attentions.2.time_mixer", + ] + blocks_rename_dict = {i:j for j,i in enumerate(blocks_rename_dict) if i is not None} + state_dict_ = {} + for name, param in sorted(state_dict.items()): + names = name.split(".") + if names[0] == "mid_block": + names = ["mid_block"] + names + if names[-1] in ["weight", "bias"]: + name_prefix = ".".join(names[:-1]) + if name_prefix in rename_dict: + state_dict_[rename_dict[name_prefix] + "." + names[-1]] = param + else: + block_name = self.get_block_name(names) + if "resnets" in block_name and block_name in blocks_rename_dict: + rename = ".".join(["blocks", str(blocks_rename_dict[block_name])] + names[5:]) + state_dict_[rename] = param + elif ("downsamplers" in block_name or "upsamplers" in block_name) and block_name in blocks_rename_dict: + rename = ".".join(["blocks", str(blocks_rename_dict[block_name])] + names[-2:]) + state_dict_[rename] = param + elif "attentions" in block_name and block_name in blocks_rename_dict: + attention_id = names[5] + if "transformer_blocks" in names: + suffix_dict = { + "attn1.to_out.0": "attn1.to_out", + "attn2.to_out.0": "attn2.to_out", + "ff.net.0.proj": "act_fn.proj", + "ff.net.2": "ff", + } + suffix = ".".join(names[6:-1]) + suffix = suffix_dict.get(suffix, suffix) + rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), "transformer_blocks", attention_id, suffix, names[-1]]) + elif "temporal_transformer_blocks" in names: + suffix_dict = { + "attn1.to_out.0": "attn1.to_out", + "attn2.to_out.0": "attn2.to_out", + "ff_in.net.0.proj": "act_fn_in.proj", + "ff_in.net.2": "ff_in", + "ff.net.0.proj": "act_fn_out.proj", + "ff.net.2": "ff_out", + "norm3": "norm_out", + } + suffix = ".".join(names[6:-1]) + suffix = suffix_dict.get(suffix, suffix) + rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), suffix, names[-1]]) + elif "time_mixer" in block_name: + rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), "proj", names[-1]]) + else: + suffix_dict = { + "linear_1": "positional_embedding_proj.0", + "linear_2": "positional_embedding_proj.2", + } + suffix = names[-2] + suffix = suffix_dict.get(suffix, suffix) + rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), suffix, names[-1]]) + state_dict_[rename] = param + else: + print(name) + else: + block_name = self.get_block_name(names) + if len(block_name)>0 and block_name in blocks_rename_dict: + rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), names[-1]]) + state_dict_[rename] = param + return state_dict_ + + + def from_civitai(self, state_dict, add_positional_conv=None): + rename_dict = { + "model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias", + "model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias", + 
"model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias", + "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias", + "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight", + "model.diffusion_model.input_blocks.1.0.time_mixer.mix_factor": "blocks.3.mix_factor", + "model.diffusion_model.input_blocks.1.0.time_stack.emb_layers.1.bias": "blocks.2.time_emb_proj.bias", + "model.diffusion_model.input_blocks.1.0.time_stack.emb_layers.1.weight": "blocks.2.time_emb_proj.weight", + "model.diffusion_model.input_blocks.1.0.time_stack.in_layers.0.bias": "blocks.2.norm1.bias", + "model.diffusion_model.input_blocks.1.0.time_stack.in_layers.0.weight": "blocks.2.norm1.weight", + "model.diffusion_model.input_blocks.1.0.time_stack.in_layers.2.bias": "blocks.2.conv1.bias", + "model.diffusion_model.input_blocks.1.0.time_stack.in_layers.2.weight": "blocks.2.conv1.weight", + "model.diffusion_model.input_blocks.1.0.time_stack.out_layers.0.bias": "blocks.2.norm2.bias", + "model.diffusion_model.input_blocks.1.0.time_stack.out_layers.0.weight": "blocks.2.norm2.weight", + "model.diffusion_model.input_blocks.1.0.time_stack.out_layers.3.bias": "blocks.2.conv2.bias", + "model.diffusion_model.input_blocks.1.0.time_stack.out_layers.3.weight": "blocks.2.conv2.weight", + "model.diffusion_model.input_blocks.1.1.norm.bias": "blocks.5.norm.bias", + "model.diffusion_model.input_blocks.1.1.norm.weight": "blocks.5.norm.weight", + "model.diffusion_model.input_blocks.1.1.proj_in.bias": "blocks.5.proj_in.bias", + "model.diffusion_model.input_blocks.1.1.proj_in.weight": "blocks.5.proj_in.weight", + "model.diffusion_model.input_blocks.1.1.proj_out.bias": "blocks.8.proj.bias", + "model.diffusion_model.input_blocks.1.1.proj_out.weight": "blocks.8.proj.weight", + "model.diffusion_model.input_blocks.1.1.time_mixer.mix_factor": "blocks.8.mix_factor", + "model.diffusion_model.input_blocks.1.1.time_pos_embed.0.bias": "blocks.7.positional_embedding_proj.0.bias", + "model.diffusion_model.input_blocks.1.1.time_pos_embed.0.weight": "blocks.7.positional_embedding_proj.0.weight", + "model.diffusion_model.input_blocks.1.1.time_pos_embed.2.bias": "blocks.7.positional_embedding_proj.2.bias", + "model.diffusion_model.input_blocks.1.1.time_pos_embed.2.weight": "blocks.7.positional_embedding_proj.2.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.attn1.to_k.weight": "blocks.7.attn1.to_k.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.attn1.to_out.0.bias": "blocks.7.attn1.to_out.bias", + "model.diffusion_model.input_blocks.1.1.time_stack.0.attn1.to_out.0.weight": "blocks.7.attn1.to_out.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.attn1.to_q.weight": "blocks.7.attn1.to_q.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.attn1.to_v.weight": "blocks.7.attn1.to_v.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.attn2.to_k.weight": "blocks.7.attn2.to_k.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.attn2.to_out.0.bias": "blocks.7.attn2.to_out.bias", + "model.diffusion_model.input_blocks.1.1.time_stack.0.attn2.to_out.0.weight": 
"blocks.7.attn2.to_out.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.attn2.to_q.weight": "blocks.7.attn2.to_q.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.attn2.to_v.weight": "blocks.7.attn2.to_v.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.ff.net.0.proj.bias": "blocks.7.act_fn_out.proj.bias", + "model.diffusion_model.input_blocks.1.1.time_stack.0.ff.net.0.proj.weight": "blocks.7.act_fn_out.proj.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.ff.net.2.bias": "blocks.7.ff_out.bias", + "model.diffusion_model.input_blocks.1.1.time_stack.0.ff.net.2.weight": "blocks.7.ff_out.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.7.act_fn_in.proj.bias", + "model.diffusion_model.input_blocks.1.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.7.act_fn_in.proj.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.ff_in.net.2.bias": "blocks.7.ff_in.bias", + "model.diffusion_model.input_blocks.1.1.time_stack.0.ff_in.net.2.weight": "blocks.7.ff_in.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.norm1.bias": "blocks.7.norm1.bias", + "model.diffusion_model.input_blocks.1.1.time_stack.0.norm1.weight": "blocks.7.norm1.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.norm2.bias": "blocks.7.norm2.bias", + "model.diffusion_model.input_blocks.1.1.time_stack.0.norm2.weight": "blocks.7.norm2.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.norm3.bias": "blocks.7.norm_out.bias", + "model.diffusion_model.input_blocks.1.1.time_stack.0.norm3.weight": "blocks.7.norm_out.weight", + "model.diffusion_model.input_blocks.1.1.time_stack.0.norm_in.bias": "blocks.7.norm_in.bias", + "model.diffusion_model.input_blocks.1.1.time_stack.0.norm_in.weight": "blocks.7.norm_in.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.5.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.5.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.5.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.5.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.5.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.5.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.5.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.5.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.5.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.5.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.5.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.5.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": 
"blocks.5.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.5.transformer_blocks.0.ff.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.5.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.5.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.5.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.5.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.5.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.5.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "blocks.66.time_emb_proj.bias", + "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "blocks.66.time_emb_proj.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "blocks.66.norm1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "blocks.66.norm1.weight", + "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "blocks.66.conv1.bias", + "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "blocks.66.conv1.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "blocks.66.norm2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "blocks.66.norm2.weight", + "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "blocks.66.conv2.bias", + "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "blocks.66.conv2.weight", + "model.diffusion_model.input_blocks.10.0.time_mixer.mix_factor": "blocks.69.mix_factor", + "model.diffusion_model.input_blocks.10.0.time_stack.emb_layers.1.bias": "blocks.68.time_emb_proj.bias", + "model.diffusion_model.input_blocks.10.0.time_stack.emb_layers.1.weight": "blocks.68.time_emb_proj.weight", + "model.diffusion_model.input_blocks.10.0.time_stack.in_layers.0.bias": "blocks.68.norm1.bias", + "model.diffusion_model.input_blocks.10.0.time_stack.in_layers.0.weight": "blocks.68.norm1.weight", + "model.diffusion_model.input_blocks.10.0.time_stack.in_layers.2.bias": "blocks.68.conv1.bias", + "model.diffusion_model.input_blocks.10.0.time_stack.in_layers.2.weight": "blocks.68.conv1.weight", + "model.diffusion_model.input_blocks.10.0.time_stack.out_layers.0.bias": "blocks.68.norm2.bias", + "model.diffusion_model.input_blocks.10.0.time_stack.out_layers.0.weight": "blocks.68.norm2.weight", + "model.diffusion_model.input_blocks.10.0.time_stack.out_layers.3.bias": "blocks.68.conv2.bias", + "model.diffusion_model.input_blocks.10.0.time_stack.out_layers.3.weight": "blocks.68.conv2.weight", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "blocks.71.time_emb_proj.bias", + "model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "blocks.71.time_emb_proj.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "blocks.71.norm1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "blocks.71.norm1.weight", + "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "blocks.71.conv1.bias", + "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "blocks.71.conv1.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "blocks.71.norm2.bias", + 
"model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "blocks.71.norm2.weight", + "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "blocks.71.conv2.bias", + "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "blocks.71.conv2.weight", + "model.diffusion_model.input_blocks.11.0.time_mixer.mix_factor": "blocks.74.mix_factor", + "model.diffusion_model.input_blocks.11.0.time_stack.emb_layers.1.bias": "blocks.73.time_emb_proj.bias", + "model.diffusion_model.input_blocks.11.0.time_stack.emb_layers.1.weight": "blocks.73.time_emb_proj.weight", + "model.diffusion_model.input_blocks.11.0.time_stack.in_layers.0.bias": "blocks.73.norm1.bias", + "model.diffusion_model.input_blocks.11.0.time_stack.in_layers.0.weight": "blocks.73.norm1.weight", + "model.diffusion_model.input_blocks.11.0.time_stack.in_layers.2.bias": "blocks.73.conv1.bias", + "model.diffusion_model.input_blocks.11.0.time_stack.in_layers.2.weight": "blocks.73.conv1.weight", + "model.diffusion_model.input_blocks.11.0.time_stack.out_layers.0.bias": "blocks.73.norm2.bias", + "model.diffusion_model.input_blocks.11.0.time_stack.out_layers.0.weight": "blocks.73.norm2.weight", + "model.diffusion_model.input_blocks.11.0.time_stack.out_layers.3.bias": "blocks.73.conv2.bias", + "model.diffusion_model.input_blocks.11.0.time_stack.out_layers.3.weight": "blocks.73.conv2.weight", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "blocks.10.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "blocks.10.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "blocks.10.norm1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "blocks.10.norm1.weight", + "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "blocks.10.conv1.bias", + "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "blocks.10.conv1.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "blocks.10.norm2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "blocks.10.norm2.weight", + "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "blocks.10.conv2.bias", + "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "blocks.10.conv2.weight", + "model.diffusion_model.input_blocks.2.0.time_mixer.mix_factor": "blocks.13.mix_factor", + "model.diffusion_model.input_blocks.2.0.time_stack.emb_layers.1.bias": "blocks.12.time_emb_proj.bias", + "model.diffusion_model.input_blocks.2.0.time_stack.emb_layers.1.weight": "blocks.12.time_emb_proj.weight", + "model.diffusion_model.input_blocks.2.0.time_stack.in_layers.0.bias": "blocks.12.norm1.bias", + "model.diffusion_model.input_blocks.2.0.time_stack.in_layers.0.weight": "blocks.12.norm1.weight", + "model.diffusion_model.input_blocks.2.0.time_stack.in_layers.2.bias": "blocks.12.conv1.bias", + "model.diffusion_model.input_blocks.2.0.time_stack.in_layers.2.weight": "blocks.12.conv1.weight", + "model.diffusion_model.input_blocks.2.0.time_stack.out_layers.0.bias": "blocks.12.norm2.bias", + "model.diffusion_model.input_blocks.2.0.time_stack.out_layers.0.weight": "blocks.12.norm2.weight", + "model.diffusion_model.input_blocks.2.0.time_stack.out_layers.3.bias": "blocks.12.conv2.bias", + "model.diffusion_model.input_blocks.2.0.time_stack.out_layers.3.weight": "blocks.12.conv2.weight", + "model.diffusion_model.input_blocks.2.1.norm.bias": "blocks.15.norm.bias", + "model.diffusion_model.input_blocks.2.1.norm.weight": "blocks.15.norm.weight", + 
"model.diffusion_model.input_blocks.2.1.proj_in.bias": "blocks.15.proj_in.bias", + "model.diffusion_model.input_blocks.2.1.proj_in.weight": "blocks.15.proj_in.weight", + "model.diffusion_model.input_blocks.2.1.proj_out.bias": "blocks.18.proj.bias", + "model.diffusion_model.input_blocks.2.1.proj_out.weight": "blocks.18.proj.weight", + "model.diffusion_model.input_blocks.2.1.time_mixer.mix_factor": "blocks.18.mix_factor", + "model.diffusion_model.input_blocks.2.1.time_pos_embed.0.bias": "blocks.17.positional_embedding_proj.0.bias", + "model.diffusion_model.input_blocks.2.1.time_pos_embed.0.weight": "blocks.17.positional_embedding_proj.0.weight", + "model.diffusion_model.input_blocks.2.1.time_pos_embed.2.bias": "blocks.17.positional_embedding_proj.2.bias", + "model.diffusion_model.input_blocks.2.1.time_pos_embed.2.weight": "blocks.17.positional_embedding_proj.2.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.attn1.to_k.weight": "blocks.17.attn1.to_k.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.attn1.to_out.0.bias": "blocks.17.attn1.to_out.bias", + "model.diffusion_model.input_blocks.2.1.time_stack.0.attn1.to_out.0.weight": "blocks.17.attn1.to_out.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.attn1.to_q.weight": "blocks.17.attn1.to_q.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.attn1.to_v.weight": "blocks.17.attn1.to_v.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.attn2.to_k.weight": "blocks.17.attn2.to_k.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.attn2.to_out.0.bias": "blocks.17.attn2.to_out.bias", + "model.diffusion_model.input_blocks.2.1.time_stack.0.attn2.to_out.0.weight": "blocks.17.attn2.to_out.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.attn2.to_q.weight": "blocks.17.attn2.to_q.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.attn2.to_v.weight": "blocks.17.attn2.to_v.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.ff.net.0.proj.bias": "blocks.17.act_fn_out.proj.bias", + "model.diffusion_model.input_blocks.2.1.time_stack.0.ff.net.0.proj.weight": "blocks.17.act_fn_out.proj.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.ff.net.2.bias": "blocks.17.ff_out.bias", + "model.diffusion_model.input_blocks.2.1.time_stack.0.ff.net.2.weight": "blocks.17.ff_out.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.17.act_fn_in.proj.bias", + "model.diffusion_model.input_blocks.2.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.17.act_fn_in.proj.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.ff_in.net.2.bias": "blocks.17.ff_in.bias", + "model.diffusion_model.input_blocks.2.1.time_stack.0.ff_in.net.2.weight": "blocks.17.ff_in.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.norm1.bias": "blocks.17.norm1.bias", + "model.diffusion_model.input_blocks.2.1.time_stack.0.norm1.weight": "blocks.17.norm1.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.norm2.bias": "blocks.17.norm2.bias", + "model.diffusion_model.input_blocks.2.1.time_stack.0.norm2.weight": "blocks.17.norm2.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.norm3.bias": "blocks.17.norm_out.bias", + "model.diffusion_model.input_blocks.2.1.time_stack.0.norm3.weight": "blocks.17.norm_out.weight", + "model.diffusion_model.input_blocks.2.1.time_stack.0.norm_in.bias": "blocks.17.norm_in.bias", + "model.diffusion_model.input_blocks.2.1.time_stack.0.norm_in.weight": 
"blocks.17.norm_in.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.15.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.15.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.15.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.15.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.15.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.15.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.15.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.15.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.15.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.15.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.15.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.15.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.15.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.15.transformer_blocks.0.ff.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.15.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.15.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.15.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.15.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.15.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.15.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.3.0.op.bias": "blocks.20.conv.bias", + "model.diffusion_model.input_blocks.3.0.op.weight": "blocks.20.conv.weight", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "blocks.22.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "blocks.22.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "blocks.22.norm1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "blocks.22.norm1.weight", + "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "blocks.22.conv1.bias", + "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "blocks.22.conv1.weight", + "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "blocks.22.norm2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "blocks.22.norm2.weight", + 
"model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "blocks.22.conv2.bias", + "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "blocks.22.conv2.weight", + "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "blocks.22.conv_shortcut.bias", + "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "blocks.22.conv_shortcut.weight", + "model.diffusion_model.input_blocks.4.0.time_mixer.mix_factor": "blocks.25.mix_factor", + "model.diffusion_model.input_blocks.4.0.time_stack.emb_layers.1.bias": "blocks.24.time_emb_proj.bias", + "model.diffusion_model.input_blocks.4.0.time_stack.emb_layers.1.weight": "blocks.24.time_emb_proj.weight", + "model.diffusion_model.input_blocks.4.0.time_stack.in_layers.0.bias": "blocks.24.norm1.bias", + "model.diffusion_model.input_blocks.4.0.time_stack.in_layers.0.weight": "blocks.24.norm1.weight", + "model.diffusion_model.input_blocks.4.0.time_stack.in_layers.2.bias": "blocks.24.conv1.bias", + "model.diffusion_model.input_blocks.4.0.time_stack.in_layers.2.weight": "blocks.24.conv1.weight", + "model.diffusion_model.input_blocks.4.0.time_stack.out_layers.0.bias": "blocks.24.norm2.bias", + "model.diffusion_model.input_blocks.4.0.time_stack.out_layers.0.weight": "blocks.24.norm2.weight", + "model.diffusion_model.input_blocks.4.0.time_stack.out_layers.3.bias": "blocks.24.conv2.bias", + "model.diffusion_model.input_blocks.4.0.time_stack.out_layers.3.weight": "blocks.24.conv2.weight", + "model.diffusion_model.input_blocks.4.1.norm.bias": "blocks.27.norm.bias", + "model.diffusion_model.input_blocks.4.1.norm.weight": "blocks.27.norm.weight", + "model.diffusion_model.input_blocks.4.1.proj_in.bias": "blocks.27.proj_in.bias", + "model.diffusion_model.input_blocks.4.1.proj_in.weight": "blocks.27.proj_in.weight", + "model.diffusion_model.input_blocks.4.1.proj_out.bias": "blocks.30.proj.bias", + "model.diffusion_model.input_blocks.4.1.proj_out.weight": "blocks.30.proj.weight", + "model.diffusion_model.input_blocks.4.1.time_mixer.mix_factor": "blocks.30.mix_factor", + "model.diffusion_model.input_blocks.4.1.time_pos_embed.0.bias": "blocks.29.positional_embedding_proj.0.bias", + "model.diffusion_model.input_blocks.4.1.time_pos_embed.0.weight": "blocks.29.positional_embedding_proj.0.weight", + "model.diffusion_model.input_blocks.4.1.time_pos_embed.2.bias": "blocks.29.positional_embedding_proj.2.bias", + "model.diffusion_model.input_blocks.4.1.time_pos_embed.2.weight": "blocks.29.positional_embedding_proj.2.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.attn1.to_k.weight": "blocks.29.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.attn1.to_out.0.bias": "blocks.29.attn1.to_out.bias", + "model.diffusion_model.input_blocks.4.1.time_stack.0.attn1.to_out.0.weight": "blocks.29.attn1.to_out.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.attn1.to_q.weight": "blocks.29.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.attn1.to_v.weight": "blocks.29.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.attn2.to_k.weight": "blocks.29.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.attn2.to_out.0.bias": "blocks.29.attn2.to_out.bias", + "model.diffusion_model.input_blocks.4.1.time_stack.0.attn2.to_out.0.weight": "blocks.29.attn2.to_out.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.attn2.to_q.weight": "blocks.29.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.attn2.to_v.weight": 
"blocks.29.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.ff.net.0.proj.bias": "blocks.29.act_fn_out.proj.bias", + "model.diffusion_model.input_blocks.4.1.time_stack.0.ff.net.0.proj.weight": "blocks.29.act_fn_out.proj.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.ff.net.2.bias": "blocks.29.ff_out.bias", + "model.diffusion_model.input_blocks.4.1.time_stack.0.ff.net.2.weight": "blocks.29.ff_out.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.29.act_fn_in.proj.bias", + "model.diffusion_model.input_blocks.4.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.29.act_fn_in.proj.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.ff_in.net.2.bias": "blocks.29.ff_in.bias", + "model.diffusion_model.input_blocks.4.1.time_stack.0.ff_in.net.2.weight": "blocks.29.ff_in.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.norm1.bias": "blocks.29.norm1.bias", + "model.diffusion_model.input_blocks.4.1.time_stack.0.norm1.weight": "blocks.29.norm1.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.norm2.bias": "blocks.29.norm2.bias", + "model.diffusion_model.input_blocks.4.1.time_stack.0.norm2.weight": "blocks.29.norm2.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.norm3.bias": "blocks.29.norm_out.bias", + "model.diffusion_model.input_blocks.4.1.time_stack.0.norm3.weight": "blocks.29.norm_out.weight", + "model.diffusion_model.input_blocks.4.1.time_stack.0.norm_in.bias": "blocks.29.norm_in.bias", + "model.diffusion_model.input_blocks.4.1.time_stack.0.norm_in.weight": "blocks.29.norm_in.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.27.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.27.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.27.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.27.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.27.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.27.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.27.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.27.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.27.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.27.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.27.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.27.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.27.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.27.transformer_blocks.0.ff.weight", + 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.27.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.27.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.27.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.27.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.27.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.27.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "blocks.32.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "blocks.32.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "blocks.32.norm1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "blocks.32.norm1.weight", + "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "blocks.32.conv1.bias", + "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "blocks.32.conv1.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "blocks.32.norm2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "blocks.32.norm2.weight", + "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "blocks.32.conv2.bias", + "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "blocks.32.conv2.weight", + "model.diffusion_model.input_blocks.5.0.time_mixer.mix_factor": "blocks.35.mix_factor", + "model.diffusion_model.input_blocks.5.0.time_stack.emb_layers.1.bias": "blocks.34.time_emb_proj.bias", + "model.diffusion_model.input_blocks.5.0.time_stack.emb_layers.1.weight": "blocks.34.time_emb_proj.weight", + "model.diffusion_model.input_blocks.5.0.time_stack.in_layers.0.bias": "blocks.34.norm1.bias", + "model.diffusion_model.input_blocks.5.0.time_stack.in_layers.0.weight": "blocks.34.norm1.weight", + "model.diffusion_model.input_blocks.5.0.time_stack.in_layers.2.bias": "blocks.34.conv1.bias", + "model.diffusion_model.input_blocks.5.0.time_stack.in_layers.2.weight": "blocks.34.conv1.weight", + "model.diffusion_model.input_blocks.5.0.time_stack.out_layers.0.bias": "blocks.34.norm2.bias", + "model.diffusion_model.input_blocks.5.0.time_stack.out_layers.0.weight": "blocks.34.norm2.weight", + "model.diffusion_model.input_blocks.5.0.time_stack.out_layers.3.bias": "blocks.34.conv2.bias", + "model.diffusion_model.input_blocks.5.0.time_stack.out_layers.3.weight": "blocks.34.conv2.weight", + "model.diffusion_model.input_blocks.5.1.norm.bias": "blocks.37.norm.bias", + "model.diffusion_model.input_blocks.5.1.norm.weight": "blocks.37.norm.weight", + "model.diffusion_model.input_blocks.5.1.proj_in.bias": "blocks.37.proj_in.bias", + "model.diffusion_model.input_blocks.5.1.proj_in.weight": "blocks.37.proj_in.weight", + "model.diffusion_model.input_blocks.5.1.proj_out.bias": "blocks.40.proj.bias", + "model.diffusion_model.input_blocks.5.1.proj_out.weight": "blocks.40.proj.weight", + "model.diffusion_model.input_blocks.5.1.time_mixer.mix_factor": "blocks.40.mix_factor", + "model.diffusion_model.input_blocks.5.1.time_pos_embed.0.bias": "blocks.39.positional_embedding_proj.0.bias", + "model.diffusion_model.input_blocks.5.1.time_pos_embed.0.weight": "blocks.39.positional_embedding_proj.0.weight", + 
"model.diffusion_model.input_blocks.5.1.time_pos_embed.2.bias": "blocks.39.positional_embedding_proj.2.bias", + "model.diffusion_model.input_blocks.5.1.time_pos_embed.2.weight": "blocks.39.positional_embedding_proj.2.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.attn1.to_k.weight": "blocks.39.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.attn1.to_out.0.bias": "blocks.39.attn1.to_out.bias", + "model.diffusion_model.input_blocks.5.1.time_stack.0.attn1.to_out.0.weight": "blocks.39.attn1.to_out.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.attn1.to_q.weight": "blocks.39.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.attn1.to_v.weight": "blocks.39.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.attn2.to_k.weight": "blocks.39.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.attn2.to_out.0.bias": "blocks.39.attn2.to_out.bias", + "model.diffusion_model.input_blocks.5.1.time_stack.0.attn2.to_out.0.weight": "blocks.39.attn2.to_out.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.attn2.to_q.weight": "blocks.39.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.attn2.to_v.weight": "blocks.39.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.ff.net.0.proj.bias": "blocks.39.act_fn_out.proj.bias", + "model.diffusion_model.input_blocks.5.1.time_stack.0.ff.net.0.proj.weight": "blocks.39.act_fn_out.proj.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.ff.net.2.bias": "blocks.39.ff_out.bias", + "model.diffusion_model.input_blocks.5.1.time_stack.0.ff.net.2.weight": "blocks.39.ff_out.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.39.act_fn_in.proj.bias", + "model.diffusion_model.input_blocks.5.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.39.act_fn_in.proj.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.ff_in.net.2.bias": "blocks.39.ff_in.bias", + "model.diffusion_model.input_blocks.5.1.time_stack.0.ff_in.net.2.weight": "blocks.39.ff_in.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.norm1.bias": "blocks.39.norm1.bias", + "model.diffusion_model.input_blocks.5.1.time_stack.0.norm1.weight": "blocks.39.norm1.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.norm2.bias": "blocks.39.norm2.bias", + "model.diffusion_model.input_blocks.5.1.time_stack.0.norm2.weight": "blocks.39.norm2.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.norm3.bias": "blocks.39.norm_out.bias", + "model.diffusion_model.input_blocks.5.1.time_stack.0.norm3.weight": "blocks.39.norm_out.weight", + "model.diffusion_model.input_blocks.5.1.time_stack.0.norm_in.bias": "blocks.39.norm_in.bias", + "model.diffusion_model.input_blocks.5.1.time_stack.0.norm_in.weight": "blocks.39.norm_in.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.37.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.37.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.37.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.37.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": 
"blocks.37.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.37.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.37.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.37.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.37.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.37.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.37.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.37.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.37.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.37.transformer_blocks.0.ff.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.37.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.37.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.37.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.37.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.37.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.37.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.6.0.op.bias": "blocks.42.conv.bias", + "model.diffusion_model.input_blocks.6.0.op.weight": "blocks.42.conv.weight", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "blocks.44.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "blocks.44.time_emb_proj.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "blocks.44.norm1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "blocks.44.norm1.weight", + "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "blocks.44.conv1.bias", + "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "blocks.44.conv1.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "blocks.44.norm2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "blocks.44.norm2.weight", + "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "blocks.44.conv2.bias", + "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "blocks.44.conv2.weight", + "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "blocks.44.conv_shortcut.bias", + "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "blocks.44.conv_shortcut.weight", + "model.diffusion_model.input_blocks.7.0.time_mixer.mix_factor": "blocks.47.mix_factor", + "model.diffusion_model.input_blocks.7.0.time_stack.emb_layers.1.bias": "blocks.46.time_emb_proj.bias", + "model.diffusion_model.input_blocks.7.0.time_stack.emb_layers.1.weight": "blocks.46.time_emb_proj.weight", + 
"model.diffusion_model.input_blocks.7.0.time_stack.in_layers.0.bias": "blocks.46.norm1.bias", + "model.diffusion_model.input_blocks.7.0.time_stack.in_layers.0.weight": "blocks.46.norm1.weight", + "model.diffusion_model.input_blocks.7.0.time_stack.in_layers.2.bias": "blocks.46.conv1.bias", + "model.diffusion_model.input_blocks.7.0.time_stack.in_layers.2.weight": "blocks.46.conv1.weight", + "model.diffusion_model.input_blocks.7.0.time_stack.out_layers.0.bias": "blocks.46.norm2.bias", + "model.diffusion_model.input_blocks.7.0.time_stack.out_layers.0.weight": "blocks.46.norm2.weight", + "model.diffusion_model.input_blocks.7.0.time_stack.out_layers.3.bias": "blocks.46.conv2.bias", + "model.diffusion_model.input_blocks.7.0.time_stack.out_layers.3.weight": "blocks.46.conv2.weight", + "model.diffusion_model.input_blocks.7.1.norm.bias": "blocks.49.norm.bias", + "model.diffusion_model.input_blocks.7.1.norm.weight": "blocks.49.norm.weight", + "model.diffusion_model.input_blocks.7.1.proj_in.bias": "blocks.49.proj_in.bias", + "model.diffusion_model.input_blocks.7.1.proj_in.weight": "blocks.49.proj_in.weight", + "model.diffusion_model.input_blocks.7.1.proj_out.bias": "blocks.52.proj.bias", + "model.diffusion_model.input_blocks.7.1.proj_out.weight": "blocks.52.proj.weight", + "model.diffusion_model.input_blocks.7.1.time_mixer.mix_factor": "blocks.52.mix_factor", + "model.diffusion_model.input_blocks.7.1.time_pos_embed.0.bias": "blocks.51.positional_embedding_proj.0.bias", + "model.diffusion_model.input_blocks.7.1.time_pos_embed.0.weight": "blocks.51.positional_embedding_proj.0.weight", + "model.diffusion_model.input_blocks.7.1.time_pos_embed.2.bias": "blocks.51.positional_embedding_proj.2.bias", + "model.diffusion_model.input_blocks.7.1.time_pos_embed.2.weight": "blocks.51.positional_embedding_proj.2.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.attn1.to_k.weight": "blocks.51.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.attn1.to_out.0.bias": "blocks.51.attn1.to_out.bias", + "model.diffusion_model.input_blocks.7.1.time_stack.0.attn1.to_out.0.weight": "blocks.51.attn1.to_out.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.attn1.to_q.weight": "blocks.51.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.attn1.to_v.weight": "blocks.51.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.attn2.to_k.weight": "blocks.51.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.attn2.to_out.0.bias": "blocks.51.attn2.to_out.bias", + "model.diffusion_model.input_blocks.7.1.time_stack.0.attn2.to_out.0.weight": "blocks.51.attn2.to_out.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.attn2.to_q.weight": "blocks.51.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.attn2.to_v.weight": "blocks.51.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.ff.net.0.proj.bias": "blocks.51.act_fn_out.proj.bias", + "model.diffusion_model.input_blocks.7.1.time_stack.0.ff.net.0.proj.weight": "blocks.51.act_fn_out.proj.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.ff.net.2.bias": "blocks.51.ff_out.bias", + "model.diffusion_model.input_blocks.7.1.time_stack.0.ff.net.2.weight": "blocks.51.ff_out.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.51.act_fn_in.proj.bias", + "model.diffusion_model.input_blocks.7.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.51.act_fn_in.proj.weight", + 
"model.diffusion_model.input_blocks.7.1.time_stack.0.ff_in.net.2.bias": "blocks.51.ff_in.bias", + "model.diffusion_model.input_blocks.7.1.time_stack.0.ff_in.net.2.weight": "blocks.51.ff_in.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.norm1.bias": "blocks.51.norm1.bias", + "model.diffusion_model.input_blocks.7.1.time_stack.0.norm1.weight": "blocks.51.norm1.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.norm2.bias": "blocks.51.norm2.bias", + "model.diffusion_model.input_blocks.7.1.time_stack.0.norm2.weight": "blocks.51.norm2.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.norm3.bias": "blocks.51.norm_out.bias", + "model.diffusion_model.input_blocks.7.1.time_stack.0.norm3.weight": "blocks.51.norm_out.weight", + "model.diffusion_model.input_blocks.7.1.time_stack.0.norm_in.bias": "blocks.51.norm_in.bias", + "model.diffusion_model.input_blocks.7.1.time_stack.0.norm_in.weight": "blocks.51.norm_in.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.49.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.49.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.49.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.49.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.49.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.49.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.49.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.49.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.49.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.49.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.49.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.49.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.49.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.49.transformer_blocks.0.ff.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.49.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.49.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.49.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.49.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.49.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": 
"blocks.49.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "blocks.54.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "blocks.54.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "blocks.54.norm1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "blocks.54.norm1.weight", + "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "blocks.54.conv1.bias", + "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "blocks.54.conv1.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "blocks.54.norm2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "blocks.54.norm2.weight", + "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "blocks.54.conv2.bias", + "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "blocks.54.conv2.weight", + "model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor": "blocks.57.mix_factor", + "model.diffusion_model.input_blocks.8.0.time_stack.emb_layers.1.bias": "blocks.56.time_emb_proj.bias", + "model.diffusion_model.input_blocks.8.0.time_stack.emb_layers.1.weight": "blocks.56.time_emb_proj.weight", + "model.diffusion_model.input_blocks.8.0.time_stack.in_layers.0.bias": "blocks.56.norm1.bias", + "model.diffusion_model.input_blocks.8.0.time_stack.in_layers.0.weight": "blocks.56.norm1.weight", + "model.diffusion_model.input_blocks.8.0.time_stack.in_layers.2.bias": "blocks.56.conv1.bias", + "model.diffusion_model.input_blocks.8.0.time_stack.in_layers.2.weight": "blocks.56.conv1.weight", + "model.diffusion_model.input_blocks.8.0.time_stack.out_layers.0.bias": "blocks.56.norm2.bias", + "model.diffusion_model.input_blocks.8.0.time_stack.out_layers.0.weight": "blocks.56.norm2.weight", + "model.diffusion_model.input_blocks.8.0.time_stack.out_layers.3.bias": "blocks.56.conv2.bias", + "model.diffusion_model.input_blocks.8.0.time_stack.out_layers.3.weight": "blocks.56.conv2.weight", + "model.diffusion_model.input_blocks.8.1.norm.bias": "blocks.59.norm.bias", + "model.diffusion_model.input_blocks.8.1.norm.weight": "blocks.59.norm.weight", + "model.diffusion_model.input_blocks.8.1.proj_in.bias": "blocks.59.proj_in.bias", + "model.diffusion_model.input_blocks.8.1.proj_in.weight": "blocks.59.proj_in.weight", + "model.diffusion_model.input_blocks.8.1.proj_out.bias": "blocks.62.proj.bias", + "model.diffusion_model.input_blocks.8.1.proj_out.weight": "blocks.62.proj.weight", + "model.diffusion_model.input_blocks.8.1.time_mixer.mix_factor": "blocks.62.mix_factor", + "model.diffusion_model.input_blocks.8.1.time_pos_embed.0.bias": "blocks.61.positional_embedding_proj.0.bias", + "model.diffusion_model.input_blocks.8.1.time_pos_embed.0.weight": "blocks.61.positional_embedding_proj.0.weight", + "model.diffusion_model.input_blocks.8.1.time_pos_embed.2.bias": "blocks.61.positional_embedding_proj.2.bias", + "model.diffusion_model.input_blocks.8.1.time_pos_embed.2.weight": "blocks.61.positional_embedding_proj.2.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.attn1.to_k.weight": "blocks.61.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.attn1.to_out.0.bias": "blocks.61.attn1.to_out.bias", + "model.diffusion_model.input_blocks.8.1.time_stack.0.attn1.to_out.0.weight": "blocks.61.attn1.to_out.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.attn1.to_q.weight": "blocks.61.attn1.to_q.weight", + 
"model.diffusion_model.input_blocks.8.1.time_stack.0.attn1.to_v.weight": "blocks.61.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.attn2.to_k.weight": "blocks.61.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.attn2.to_out.0.bias": "blocks.61.attn2.to_out.bias", + "model.diffusion_model.input_blocks.8.1.time_stack.0.attn2.to_out.0.weight": "blocks.61.attn2.to_out.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.attn2.to_q.weight": "blocks.61.attn2.to_q.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.attn2.to_v.weight": "blocks.61.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.ff.net.0.proj.bias": "blocks.61.act_fn_out.proj.bias", + "model.diffusion_model.input_blocks.8.1.time_stack.0.ff.net.0.proj.weight": "blocks.61.act_fn_out.proj.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.ff.net.2.bias": "blocks.61.ff_out.bias", + "model.diffusion_model.input_blocks.8.1.time_stack.0.ff.net.2.weight": "blocks.61.ff_out.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.61.act_fn_in.proj.bias", + "model.diffusion_model.input_blocks.8.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.61.act_fn_in.proj.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.ff_in.net.2.bias": "blocks.61.ff_in.bias", + "model.diffusion_model.input_blocks.8.1.time_stack.0.ff_in.net.2.weight": "blocks.61.ff_in.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.norm1.bias": "blocks.61.norm1.bias", + "model.diffusion_model.input_blocks.8.1.time_stack.0.norm1.weight": "blocks.61.norm1.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.norm2.bias": "blocks.61.norm2.bias", + "model.diffusion_model.input_blocks.8.1.time_stack.0.norm2.weight": "blocks.61.norm2.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.norm3.bias": "blocks.61.norm_out.bias", + "model.diffusion_model.input_blocks.8.1.time_stack.0.norm3.weight": "blocks.61.norm_out.weight", + "model.diffusion_model.input_blocks.8.1.time_stack.0.norm_in.bias": "blocks.61.norm_in.bias", + "model.diffusion_model.input_blocks.8.1.time_stack.0.norm_in.weight": "blocks.61.norm_in.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.59.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.59.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.59.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.59.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.59.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.59.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.59.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.59.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.59.transformer_blocks.0.attn2.to_q.weight", + 
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.59.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.59.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.59.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.59.transformer_blocks.0.ff.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.59.transformer_blocks.0.ff.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.59.transformer_blocks.0.norm1.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.59.transformer_blocks.0.norm1.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.59.transformer_blocks.0.norm2.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.59.transformer_blocks.0.norm2.weight", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.59.transformer_blocks.0.norm3.bias", + "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.59.transformer_blocks.0.norm3.weight", + "model.diffusion_model.input_blocks.9.0.op.bias": "blocks.64.conv.bias", + "model.diffusion_model.input_blocks.9.0.op.weight": "blocks.64.conv.weight", + "model.diffusion_model.label_emb.0.0.bias": "add_time_embedding.0.bias", + "model.diffusion_model.label_emb.0.0.weight": "add_time_embedding.0.weight", + "model.diffusion_model.label_emb.0.2.bias": "add_time_embedding.2.bias", + "model.diffusion_model.label_emb.0.2.weight": "add_time_embedding.2.weight", + "model.diffusion_model.middle_block.0.emb_layers.1.bias": "blocks.76.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.emb_layers.1.weight": "blocks.76.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.in_layers.0.bias": "blocks.76.norm1.bias", + "model.diffusion_model.middle_block.0.in_layers.0.weight": "blocks.76.norm1.weight", + "model.diffusion_model.middle_block.0.in_layers.2.bias": "blocks.76.conv1.bias", + "model.diffusion_model.middle_block.0.in_layers.2.weight": "blocks.76.conv1.weight", + "model.diffusion_model.middle_block.0.out_layers.0.bias": "blocks.76.norm2.bias", + "model.diffusion_model.middle_block.0.out_layers.0.weight": "blocks.76.norm2.weight", + "model.diffusion_model.middle_block.0.out_layers.3.bias": "blocks.76.conv2.bias", + "model.diffusion_model.middle_block.0.out_layers.3.weight": "blocks.76.conv2.weight", + "model.diffusion_model.middle_block.0.time_mixer.mix_factor": "blocks.79.mix_factor", + "model.diffusion_model.middle_block.0.time_stack.emb_layers.1.bias": "blocks.78.time_emb_proj.bias", + "model.diffusion_model.middle_block.0.time_stack.emb_layers.1.weight": "blocks.78.time_emb_proj.weight", + "model.diffusion_model.middle_block.0.time_stack.in_layers.0.bias": "blocks.78.norm1.bias", + "model.diffusion_model.middle_block.0.time_stack.in_layers.0.weight": "blocks.78.norm1.weight", + "model.diffusion_model.middle_block.0.time_stack.in_layers.2.bias": "blocks.78.conv1.bias", + "model.diffusion_model.middle_block.0.time_stack.in_layers.2.weight": "blocks.78.conv1.weight", + "model.diffusion_model.middle_block.0.time_stack.out_layers.0.bias": "blocks.78.norm2.bias", + 
"model.diffusion_model.middle_block.0.time_stack.out_layers.0.weight": "blocks.78.norm2.weight", + "model.diffusion_model.middle_block.0.time_stack.out_layers.3.bias": "blocks.78.conv2.bias", + "model.diffusion_model.middle_block.0.time_stack.out_layers.3.weight": "blocks.78.conv2.weight", + "model.diffusion_model.middle_block.1.norm.bias": "blocks.81.norm.bias", + "model.diffusion_model.middle_block.1.norm.weight": "blocks.81.norm.weight", + "model.diffusion_model.middle_block.1.proj_in.bias": "blocks.81.proj_in.bias", + "model.diffusion_model.middle_block.1.proj_in.weight": "blocks.81.proj_in.weight", + "model.diffusion_model.middle_block.1.proj_out.bias": "blocks.84.proj.bias", + "model.diffusion_model.middle_block.1.proj_out.weight": "blocks.84.proj.weight", + "model.diffusion_model.middle_block.1.time_mixer.mix_factor": "blocks.84.mix_factor", + "model.diffusion_model.middle_block.1.time_pos_embed.0.bias": "blocks.83.positional_embedding_proj.0.bias", + "model.diffusion_model.middle_block.1.time_pos_embed.0.weight": "blocks.83.positional_embedding_proj.0.weight", + "model.diffusion_model.middle_block.1.time_pos_embed.2.bias": "blocks.83.positional_embedding_proj.2.bias", + "model.diffusion_model.middle_block.1.time_pos_embed.2.weight": "blocks.83.positional_embedding_proj.2.weight", + "model.diffusion_model.middle_block.1.time_stack.0.attn1.to_k.weight": "blocks.83.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.time_stack.0.attn1.to_out.0.bias": "blocks.83.attn1.to_out.bias", + "model.diffusion_model.middle_block.1.time_stack.0.attn1.to_out.0.weight": "blocks.83.attn1.to_out.weight", + "model.diffusion_model.middle_block.1.time_stack.0.attn1.to_q.weight": "blocks.83.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.time_stack.0.attn1.to_v.weight": "blocks.83.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.time_stack.0.attn2.to_k.weight": "blocks.83.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.time_stack.0.attn2.to_out.0.bias": "blocks.83.attn2.to_out.bias", + "model.diffusion_model.middle_block.1.time_stack.0.attn2.to_out.0.weight": "blocks.83.attn2.to_out.weight", + "model.diffusion_model.middle_block.1.time_stack.0.attn2.to_q.weight": "blocks.83.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.time_stack.0.attn2.to_v.weight": "blocks.83.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.time_stack.0.ff.net.0.proj.bias": "blocks.83.act_fn_out.proj.bias", + "model.diffusion_model.middle_block.1.time_stack.0.ff.net.0.proj.weight": "blocks.83.act_fn_out.proj.weight", + "model.diffusion_model.middle_block.1.time_stack.0.ff.net.2.bias": "blocks.83.ff_out.bias", + "model.diffusion_model.middle_block.1.time_stack.0.ff.net.2.weight": "blocks.83.ff_out.weight", + "model.diffusion_model.middle_block.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.83.act_fn_in.proj.bias", + "model.diffusion_model.middle_block.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.83.act_fn_in.proj.weight", + "model.diffusion_model.middle_block.1.time_stack.0.ff_in.net.2.bias": "blocks.83.ff_in.bias", + "model.diffusion_model.middle_block.1.time_stack.0.ff_in.net.2.weight": "blocks.83.ff_in.weight", + "model.diffusion_model.middle_block.1.time_stack.0.norm1.bias": "blocks.83.norm1.bias", + "model.diffusion_model.middle_block.1.time_stack.0.norm1.weight": "blocks.83.norm1.weight", + "model.diffusion_model.middle_block.1.time_stack.0.norm2.bias": "blocks.83.norm2.bias", + "model.diffusion_model.middle_block.1.time_stack.0.norm2.weight": 
"blocks.83.norm2.weight", + "model.diffusion_model.middle_block.1.time_stack.0.norm3.bias": "blocks.83.norm_out.bias", + "model.diffusion_model.middle_block.1.time_stack.0.norm3.weight": "blocks.83.norm_out.weight", + "model.diffusion_model.middle_block.1.time_stack.0.norm_in.bias": "blocks.83.norm_in.bias", + "model.diffusion_model.middle_block.1.time_stack.0.norm_in.weight": "blocks.83.norm_in.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.81.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.81.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.81.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.81.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.81.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.81.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.81.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.81.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.81.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.81.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.81.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.81.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.81.transformer_blocks.0.ff.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.81.transformer_blocks.0.ff.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.81.transformer_blocks.0.norm1.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.81.transformer_blocks.0.norm1.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.81.transformer_blocks.0.norm2.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.81.transformer_blocks.0.norm2.weight", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.81.transformer_blocks.0.norm3.bias", + "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.81.transformer_blocks.0.norm3.weight", + "model.diffusion_model.middle_block.2.emb_layers.1.bias": "blocks.85.time_emb_proj.bias", + "model.diffusion_model.middle_block.2.emb_layers.1.weight": "blocks.85.time_emb_proj.weight", + "model.diffusion_model.middle_block.2.in_layers.0.bias": "blocks.85.norm1.bias", + "model.diffusion_model.middle_block.2.in_layers.0.weight": "blocks.85.norm1.weight", + "model.diffusion_model.middle_block.2.in_layers.2.bias": "blocks.85.conv1.bias", + "model.diffusion_model.middle_block.2.in_layers.2.weight": "blocks.85.conv1.weight", + 
"model.diffusion_model.middle_block.2.out_layers.0.bias": "blocks.85.norm2.bias", + "model.diffusion_model.middle_block.2.out_layers.0.weight": "blocks.85.norm2.weight", + "model.diffusion_model.middle_block.2.out_layers.3.bias": "blocks.85.conv2.bias", + "model.diffusion_model.middle_block.2.out_layers.3.weight": "blocks.85.conv2.weight", + "model.diffusion_model.middle_block.2.time_mixer.mix_factor": "blocks.88.mix_factor", + "model.diffusion_model.middle_block.2.time_stack.emb_layers.1.bias": "blocks.87.time_emb_proj.bias", + "model.diffusion_model.middle_block.2.time_stack.emb_layers.1.weight": "blocks.87.time_emb_proj.weight", + "model.diffusion_model.middle_block.2.time_stack.in_layers.0.bias": "blocks.87.norm1.bias", + "model.diffusion_model.middle_block.2.time_stack.in_layers.0.weight": "blocks.87.norm1.weight", + "model.diffusion_model.middle_block.2.time_stack.in_layers.2.bias": "blocks.87.conv1.bias", + "model.diffusion_model.middle_block.2.time_stack.in_layers.2.weight": "blocks.87.conv1.weight", + "model.diffusion_model.middle_block.2.time_stack.out_layers.0.bias": "blocks.87.norm2.bias", + "model.diffusion_model.middle_block.2.time_stack.out_layers.0.weight": "blocks.87.norm2.weight", + "model.diffusion_model.middle_block.2.time_stack.out_layers.3.bias": "blocks.87.conv2.bias", + "model.diffusion_model.middle_block.2.time_stack.out_layers.3.weight": "blocks.87.conv2.weight", + "model.diffusion_model.out.0.bias": "conv_norm_out.bias", + "model.diffusion_model.out.0.weight": "conv_norm_out.weight", + "model.diffusion_model.out.2.bias": "conv_out.bias", + "model.diffusion_model.out.2.weight": "conv_out.weight", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "blocks.90.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "blocks.90.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "blocks.90.norm1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "blocks.90.norm1.weight", + "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "blocks.90.conv1.bias", + "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "blocks.90.conv1.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "blocks.90.norm2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "blocks.90.norm2.weight", + "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "blocks.90.conv2.bias", + "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "blocks.90.conv2.weight", + "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "blocks.90.conv_shortcut.bias", + "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "blocks.90.conv_shortcut.weight", + "model.diffusion_model.output_blocks.0.0.time_mixer.mix_factor": "blocks.93.mix_factor", + "model.diffusion_model.output_blocks.0.0.time_stack.emb_layers.1.bias": "blocks.92.time_emb_proj.bias", + "model.diffusion_model.output_blocks.0.0.time_stack.emb_layers.1.weight": "blocks.92.time_emb_proj.weight", + "model.diffusion_model.output_blocks.0.0.time_stack.in_layers.0.bias": "blocks.92.norm1.bias", + "model.diffusion_model.output_blocks.0.0.time_stack.in_layers.0.weight": "blocks.92.norm1.weight", + "model.diffusion_model.output_blocks.0.0.time_stack.in_layers.2.bias": "blocks.92.conv1.bias", + "model.diffusion_model.output_blocks.0.0.time_stack.in_layers.2.weight": "blocks.92.conv1.weight", + "model.diffusion_model.output_blocks.0.0.time_stack.out_layers.0.bias": "blocks.92.norm2.bias", + 
"model.diffusion_model.output_blocks.0.0.time_stack.out_layers.0.weight": "blocks.92.norm2.weight", + "model.diffusion_model.output_blocks.0.0.time_stack.out_layers.3.bias": "blocks.92.conv2.bias", + "model.diffusion_model.output_blocks.0.0.time_stack.out_layers.3.weight": "blocks.92.conv2.weight", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "blocks.95.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "blocks.95.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "blocks.95.norm1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "blocks.95.norm1.weight", + "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "blocks.95.conv1.bias", + "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "blocks.95.conv1.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "blocks.95.norm2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "blocks.95.norm2.weight", + "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "blocks.95.conv2.bias", + "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "blocks.95.conv2.weight", + "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "blocks.95.conv_shortcut.bias", + "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "blocks.95.conv_shortcut.weight", + "model.diffusion_model.output_blocks.1.0.time_mixer.mix_factor": "blocks.98.mix_factor", + "model.diffusion_model.output_blocks.1.0.time_stack.emb_layers.1.bias": "blocks.97.time_emb_proj.bias", + "model.diffusion_model.output_blocks.1.0.time_stack.emb_layers.1.weight": "blocks.97.time_emb_proj.weight", + "model.diffusion_model.output_blocks.1.0.time_stack.in_layers.0.bias": "blocks.97.norm1.bias", + "model.diffusion_model.output_blocks.1.0.time_stack.in_layers.0.weight": "blocks.97.norm1.weight", + "model.diffusion_model.output_blocks.1.0.time_stack.in_layers.2.bias": "blocks.97.conv1.bias", + "model.diffusion_model.output_blocks.1.0.time_stack.in_layers.2.weight": "blocks.97.conv1.weight", + "model.diffusion_model.output_blocks.1.0.time_stack.out_layers.0.bias": "blocks.97.norm2.bias", + "model.diffusion_model.output_blocks.1.0.time_stack.out_layers.0.weight": "blocks.97.norm2.weight", + "model.diffusion_model.output_blocks.1.0.time_stack.out_layers.3.bias": "blocks.97.conv2.bias", + "model.diffusion_model.output_blocks.1.0.time_stack.out_layers.3.weight": "blocks.97.conv2.weight", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "blocks.178.time_emb_proj.bias", + "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "blocks.178.time_emb_proj.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "blocks.178.norm1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "blocks.178.norm1.weight", + "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "blocks.178.conv1.bias", + "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "blocks.178.conv1.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "blocks.178.norm2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "blocks.178.norm2.weight", + "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "blocks.178.conv2.bias", + "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "blocks.178.conv2.weight", + "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "blocks.178.conv_shortcut.bias", + 
"model.diffusion_model.output_blocks.10.0.skip_connection.weight": "blocks.178.conv_shortcut.weight", + "model.diffusion_model.output_blocks.10.0.time_mixer.mix_factor": "blocks.181.mix_factor", + "model.diffusion_model.output_blocks.10.0.time_stack.emb_layers.1.bias": "blocks.180.time_emb_proj.bias", + "model.diffusion_model.output_blocks.10.0.time_stack.emb_layers.1.weight": "blocks.180.time_emb_proj.weight", + "model.diffusion_model.output_blocks.10.0.time_stack.in_layers.0.bias": "blocks.180.norm1.bias", + "model.diffusion_model.output_blocks.10.0.time_stack.in_layers.0.weight": "blocks.180.norm1.weight", + "model.diffusion_model.output_blocks.10.0.time_stack.in_layers.2.bias": "blocks.180.conv1.bias", + "model.diffusion_model.output_blocks.10.0.time_stack.in_layers.2.weight": "blocks.180.conv1.weight", + "model.diffusion_model.output_blocks.10.0.time_stack.out_layers.0.bias": "blocks.180.norm2.bias", + "model.diffusion_model.output_blocks.10.0.time_stack.out_layers.0.weight": "blocks.180.norm2.weight", + "model.diffusion_model.output_blocks.10.0.time_stack.out_layers.3.bias": "blocks.180.conv2.bias", + "model.diffusion_model.output_blocks.10.0.time_stack.out_layers.3.weight": "blocks.180.conv2.weight", + "model.diffusion_model.output_blocks.10.1.norm.bias": "blocks.183.norm.bias", + "model.diffusion_model.output_blocks.10.1.norm.weight": "blocks.183.norm.weight", + "model.diffusion_model.output_blocks.10.1.proj_in.bias": "blocks.183.proj_in.bias", + "model.diffusion_model.output_blocks.10.1.proj_in.weight": "blocks.183.proj_in.weight", + "model.diffusion_model.output_blocks.10.1.proj_out.bias": "blocks.186.proj.bias", + "model.diffusion_model.output_blocks.10.1.proj_out.weight": "blocks.186.proj.weight", + "model.diffusion_model.output_blocks.10.1.time_mixer.mix_factor": "blocks.186.mix_factor", + "model.diffusion_model.output_blocks.10.1.time_pos_embed.0.bias": "blocks.185.positional_embedding_proj.0.bias", + "model.diffusion_model.output_blocks.10.1.time_pos_embed.0.weight": "blocks.185.positional_embedding_proj.0.weight", + "model.diffusion_model.output_blocks.10.1.time_pos_embed.2.bias": "blocks.185.positional_embedding_proj.2.bias", + "model.diffusion_model.output_blocks.10.1.time_pos_embed.2.weight": "blocks.185.positional_embedding_proj.2.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.attn1.to_k.weight": "blocks.185.attn1.to_k.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.attn1.to_out.0.bias": "blocks.185.attn1.to_out.bias", + "model.diffusion_model.output_blocks.10.1.time_stack.0.attn1.to_out.0.weight": "blocks.185.attn1.to_out.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.attn1.to_q.weight": "blocks.185.attn1.to_q.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.attn1.to_v.weight": "blocks.185.attn1.to_v.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.attn2.to_k.weight": "blocks.185.attn2.to_k.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.attn2.to_out.0.bias": "blocks.185.attn2.to_out.bias", + "model.diffusion_model.output_blocks.10.1.time_stack.0.attn2.to_out.0.weight": "blocks.185.attn2.to_out.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.attn2.to_q.weight": "blocks.185.attn2.to_q.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.attn2.to_v.weight": "blocks.185.attn2.to_v.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.ff.net.0.proj.bias": "blocks.185.act_fn_out.proj.bias", + 
"model.diffusion_model.output_blocks.10.1.time_stack.0.ff.net.0.proj.weight": "blocks.185.act_fn_out.proj.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.ff.net.2.bias": "blocks.185.ff_out.bias", + "model.diffusion_model.output_blocks.10.1.time_stack.0.ff.net.2.weight": "blocks.185.ff_out.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.185.act_fn_in.proj.bias", + "model.diffusion_model.output_blocks.10.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.185.act_fn_in.proj.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.ff_in.net.2.bias": "blocks.185.ff_in.bias", + "model.diffusion_model.output_blocks.10.1.time_stack.0.ff_in.net.2.weight": "blocks.185.ff_in.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.norm1.bias": "blocks.185.norm1.bias", + "model.diffusion_model.output_blocks.10.1.time_stack.0.norm1.weight": "blocks.185.norm1.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.norm2.bias": "blocks.185.norm2.bias", + "model.diffusion_model.output_blocks.10.1.time_stack.0.norm2.weight": "blocks.185.norm2.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.norm3.bias": "blocks.185.norm_out.bias", + "model.diffusion_model.output_blocks.10.1.time_stack.0.norm3.weight": "blocks.185.norm_out.weight", + "model.diffusion_model.output_blocks.10.1.time_stack.0.norm_in.bias": "blocks.185.norm_in.bias", + "model.diffusion_model.output_blocks.10.1.time_stack.0.norm_in.weight": "blocks.185.norm_in.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight": "blocks.183.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.183.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.183.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight": "blocks.183.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight": "blocks.183.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight": "blocks.183.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.183.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.183.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight": "blocks.183.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight": "blocks.183.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.183.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.183.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias": "blocks.183.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight": "blocks.183.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias": 
"blocks.183.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight": "blocks.183.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias": "blocks.183.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight": "blocks.183.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias": "blocks.183.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight": "blocks.183.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "blocks.188.time_emb_proj.bias", + "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "blocks.188.time_emb_proj.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "blocks.188.norm1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "blocks.188.norm1.weight", + "model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "blocks.188.conv1.bias", + "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "blocks.188.conv1.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "blocks.188.norm2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "blocks.188.norm2.weight", + "model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "blocks.188.conv2.bias", + "model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "blocks.188.conv2.weight", + "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "blocks.188.conv_shortcut.bias", + "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "blocks.188.conv_shortcut.weight", + "model.diffusion_model.output_blocks.11.0.time_mixer.mix_factor": "blocks.191.mix_factor", + "model.diffusion_model.output_blocks.11.0.time_stack.emb_layers.1.bias": "blocks.190.time_emb_proj.bias", + "model.diffusion_model.output_blocks.11.0.time_stack.emb_layers.1.weight": "blocks.190.time_emb_proj.weight", + "model.diffusion_model.output_blocks.11.0.time_stack.in_layers.0.bias": "blocks.190.norm1.bias", + "model.diffusion_model.output_blocks.11.0.time_stack.in_layers.0.weight": "blocks.190.norm1.weight", + "model.diffusion_model.output_blocks.11.0.time_stack.in_layers.2.bias": "blocks.190.conv1.bias", + "model.diffusion_model.output_blocks.11.0.time_stack.in_layers.2.weight": "blocks.190.conv1.weight", + "model.diffusion_model.output_blocks.11.0.time_stack.out_layers.0.bias": "blocks.190.norm2.bias", + "model.diffusion_model.output_blocks.11.0.time_stack.out_layers.0.weight": "blocks.190.norm2.weight", + "model.diffusion_model.output_blocks.11.0.time_stack.out_layers.3.bias": "blocks.190.conv2.bias", + "model.diffusion_model.output_blocks.11.0.time_stack.out_layers.3.weight": "blocks.190.conv2.weight", + "model.diffusion_model.output_blocks.11.1.norm.bias": "blocks.193.norm.bias", + "model.diffusion_model.output_blocks.11.1.norm.weight": "blocks.193.norm.weight", + "model.diffusion_model.output_blocks.11.1.proj_in.bias": "blocks.193.proj_in.bias", + "model.diffusion_model.output_blocks.11.1.proj_in.weight": "blocks.193.proj_in.weight", + "model.diffusion_model.output_blocks.11.1.proj_out.bias": "blocks.196.proj.bias", + "model.diffusion_model.output_blocks.11.1.proj_out.weight": "blocks.196.proj.weight", + "model.diffusion_model.output_blocks.11.1.time_mixer.mix_factor": "blocks.196.mix_factor", + 
"model.diffusion_model.output_blocks.11.1.time_pos_embed.0.bias": "blocks.195.positional_embedding_proj.0.bias", + "model.diffusion_model.output_blocks.11.1.time_pos_embed.0.weight": "blocks.195.positional_embedding_proj.0.weight", + "model.diffusion_model.output_blocks.11.1.time_pos_embed.2.bias": "blocks.195.positional_embedding_proj.2.bias", + "model.diffusion_model.output_blocks.11.1.time_pos_embed.2.weight": "blocks.195.positional_embedding_proj.2.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.attn1.to_k.weight": "blocks.195.attn1.to_k.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.attn1.to_out.0.bias": "blocks.195.attn1.to_out.bias", + "model.diffusion_model.output_blocks.11.1.time_stack.0.attn1.to_out.0.weight": "blocks.195.attn1.to_out.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.attn1.to_q.weight": "blocks.195.attn1.to_q.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.attn1.to_v.weight": "blocks.195.attn1.to_v.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.attn2.to_k.weight": "blocks.195.attn2.to_k.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.attn2.to_out.0.bias": "blocks.195.attn2.to_out.bias", + "model.diffusion_model.output_blocks.11.1.time_stack.0.attn2.to_out.0.weight": "blocks.195.attn2.to_out.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.attn2.to_q.weight": "blocks.195.attn2.to_q.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.attn2.to_v.weight": "blocks.195.attn2.to_v.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.ff.net.0.proj.bias": "blocks.195.act_fn_out.proj.bias", + "model.diffusion_model.output_blocks.11.1.time_stack.0.ff.net.0.proj.weight": "blocks.195.act_fn_out.proj.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.ff.net.2.bias": "blocks.195.ff_out.bias", + "model.diffusion_model.output_blocks.11.1.time_stack.0.ff.net.2.weight": "blocks.195.ff_out.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.195.act_fn_in.proj.bias", + "model.diffusion_model.output_blocks.11.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.195.act_fn_in.proj.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.ff_in.net.2.bias": "blocks.195.ff_in.bias", + "model.diffusion_model.output_blocks.11.1.time_stack.0.ff_in.net.2.weight": "blocks.195.ff_in.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.norm1.bias": "blocks.195.norm1.bias", + "model.diffusion_model.output_blocks.11.1.time_stack.0.norm1.weight": "blocks.195.norm1.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.norm2.bias": "blocks.195.norm2.bias", + "model.diffusion_model.output_blocks.11.1.time_stack.0.norm2.weight": "blocks.195.norm2.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.norm3.bias": "blocks.195.norm_out.bias", + "model.diffusion_model.output_blocks.11.1.time_stack.0.norm3.weight": "blocks.195.norm_out.weight", + "model.diffusion_model.output_blocks.11.1.time_stack.0.norm_in.bias": "blocks.195.norm_in.bias", + "model.diffusion_model.output_blocks.11.1.time_stack.0.norm_in.weight": "blocks.195.norm_in.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight": "blocks.193.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.193.transformer_blocks.0.attn1.to_out.bias", + 
"model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.193.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight": "blocks.193.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight": "blocks.193.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight": "blocks.193.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.193.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.193.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight": "blocks.193.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight": "blocks.193.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.193.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.193.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias": "blocks.193.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight": "blocks.193.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias": "blocks.193.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight": "blocks.193.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias": "blocks.193.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight": "blocks.193.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias": "blocks.193.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight": "blocks.193.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "blocks.100.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "blocks.100.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "blocks.100.norm1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "blocks.100.norm1.weight", + "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "blocks.100.conv1.bias", + "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "blocks.100.conv1.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "blocks.100.norm2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "blocks.100.norm2.weight", + "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "blocks.100.conv2.bias", + "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "blocks.100.conv2.weight", + "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "blocks.100.conv_shortcut.bias", + "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "blocks.100.conv_shortcut.weight", + 
"model.diffusion_model.output_blocks.2.0.time_mixer.mix_factor": "blocks.103.mix_factor", + "model.diffusion_model.output_blocks.2.0.time_stack.emb_layers.1.bias": "blocks.102.time_emb_proj.bias", + "model.diffusion_model.output_blocks.2.0.time_stack.emb_layers.1.weight": "blocks.102.time_emb_proj.weight", + "model.diffusion_model.output_blocks.2.0.time_stack.in_layers.0.bias": "blocks.102.norm1.bias", + "model.diffusion_model.output_blocks.2.0.time_stack.in_layers.0.weight": "blocks.102.norm1.weight", + "model.diffusion_model.output_blocks.2.0.time_stack.in_layers.2.bias": "blocks.102.conv1.bias", + "model.diffusion_model.output_blocks.2.0.time_stack.in_layers.2.weight": "blocks.102.conv1.weight", + "model.diffusion_model.output_blocks.2.0.time_stack.out_layers.0.bias": "blocks.102.norm2.bias", + "model.diffusion_model.output_blocks.2.0.time_stack.out_layers.0.weight": "blocks.102.norm2.weight", + "model.diffusion_model.output_blocks.2.0.time_stack.out_layers.3.bias": "blocks.102.conv2.bias", + "model.diffusion_model.output_blocks.2.0.time_stack.out_layers.3.weight": "blocks.102.conv2.weight", + "model.diffusion_model.output_blocks.2.1.conv.bias": "blocks.104.conv.bias", + "model.diffusion_model.output_blocks.2.1.conv.weight": "blocks.104.conv.weight", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "blocks.106.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "blocks.106.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "blocks.106.norm1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "blocks.106.norm1.weight", + "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "blocks.106.conv1.bias", + "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "blocks.106.conv1.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "blocks.106.norm2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "blocks.106.norm2.weight", + "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "blocks.106.conv2.bias", + "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "blocks.106.conv2.weight", + "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "blocks.106.conv_shortcut.bias", + "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "blocks.106.conv_shortcut.weight", + "model.diffusion_model.output_blocks.3.0.time_mixer.mix_factor": "blocks.109.mix_factor", + "model.diffusion_model.output_blocks.3.0.time_stack.emb_layers.1.bias": "blocks.108.time_emb_proj.bias", + "model.diffusion_model.output_blocks.3.0.time_stack.emb_layers.1.weight": "blocks.108.time_emb_proj.weight", + "model.diffusion_model.output_blocks.3.0.time_stack.in_layers.0.bias": "blocks.108.norm1.bias", + "model.diffusion_model.output_blocks.3.0.time_stack.in_layers.0.weight": "blocks.108.norm1.weight", + "model.diffusion_model.output_blocks.3.0.time_stack.in_layers.2.bias": "blocks.108.conv1.bias", + "model.diffusion_model.output_blocks.3.0.time_stack.in_layers.2.weight": "blocks.108.conv1.weight", + "model.diffusion_model.output_blocks.3.0.time_stack.out_layers.0.bias": "blocks.108.norm2.bias", + "model.diffusion_model.output_blocks.3.0.time_stack.out_layers.0.weight": "blocks.108.norm2.weight", + "model.diffusion_model.output_blocks.3.0.time_stack.out_layers.3.bias": "blocks.108.conv2.bias", + "model.diffusion_model.output_blocks.3.0.time_stack.out_layers.3.weight": "blocks.108.conv2.weight", + 
"model.diffusion_model.output_blocks.3.1.norm.bias": "blocks.111.norm.bias", + "model.diffusion_model.output_blocks.3.1.norm.weight": "blocks.111.norm.weight", + "model.diffusion_model.output_blocks.3.1.proj_in.bias": "blocks.111.proj_in.bias", + "model.diffusion_model.output_blocks.3.1.proj_in.weight": "blocks.111.proj_in.weight", + "model.diffusion_model.output_blocks.3.1.proj_out.bias": "blocks.114.proj.bias", + "model.diffusion_model.output_blocks.3.1.proj_out.weight": "blocks.114.proj.weight", + "model.diffusion_model.output_blocks.3.1.time_mixer.mix_factor": "blocks.114.mix_factor", + "model.diffusion_model.output_blocks.3.1.time_pos_embed.0.bias": "blocks.113.positional_embedding_proj.0.bias", + "model.diffusion_model.output_blocks.3.1.time_pos_embed.0.weight": "blocks.113.positional_embedding_proj.0.weight", + "model.diffusion_model.output_blocks.3.1.time_pos_embed.2.bias": "blocks.113.positional_embedding_proj.2.bias", + "model.diffusion_model.output_blocks.3.1.time_pos_embed.2.weight": "blocks.113.positional_embedding_proj.2.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.attn1.to_k.weight": "blocks.113.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.attn1.to_out.0.bias": "blocks.113.attn1.to_out.bias", + "model.diffusion_model.output_blocks.3.1.time_stack.0.attn1.to_out.0.weight": "blocks.113.attn1.to_out.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.attn1.to_q.weight": "blocks.113.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.attn1.to_v.weight": "blocks.113.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.attn2.to_k.weight": "blocks.113.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.attn2.to_out.0.bias": "blocks.113.attn2.to_out.bias", + "model.diffusion_model.output_blocks.3.1.time_stack.0.attn2.to_out.0.weight": "blocks.113.attn2.to_out.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.attn2.to_q.weight": "blocks.113.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.attn2.to_v.weight": "blocks.113.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.ff.net.0.proj.bias": "blocks.113.act_fn_out.proj.bias", + "model.diffusion_model.output_blocks.3.1.time_stack.0.ff.net.0.proj.weight": "blocks.113.act_fn_out.proj.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.ff.net.2.bias": "blocks.113.ff_out.bias", + "model.diffusion_model.output_blocks.3.1.time_stack.0.ff.net.2.weight": "blocks.113.ff_out.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.113.act_fn_in.proj.bias", + "model.diffusion_model.output_blocks.3.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.113.act_fn_in.proj.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.ff_in.net.2.bias": "blocks.113.ff_in.bias", + "model.diffusion_model.output_blocks.3.1.time_stack.0.ff_in.net.2.weight": "blocks.113.ff_in.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.norm1.bias": "blocks.113.norm1.bias", + "model.diffusion_model.output_blocks.3.1.time_stack.0.norm1.weight": "blocks.113.norm1.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.norm2.bias": "blocks.113.norm2.bias", + "model.diffusion_model.output_blocks.3.1.time_stack.0.norm2.weight": "blocks.113.norm2.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.norm3.bias": "blocks.113.norm_out.bias", + 
"model.diffusion_model.output_blocks.3.1.time_stack.0.norm3.weight": "blocks.113.norm_out.weight", + "model.diffusion_model.output_blocks.3.1.time_stack.0.norm_in.bias": "blocks.113.norm_in.bias", + "model.diffusion_model.output_blocks.3.1.time_stack.0.norm_in.weight": "blocks.113.norm_in.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "blocks.111.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.111.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.111.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "blocks.111.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "blocks.111.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "blocks.111.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.111.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.111.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "blocks.111.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "blocks.111.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.111.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.111.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "blocks.111.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "blocks.111.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "blocks.111.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "blocks.111.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "blocks.111.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "blocks.111.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "blocks.111.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "blocks.111.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "blocks.116.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "blocks.116.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "blocks.116.norm1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "blocks.116.norm1.weight", + "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "blocks.116.conv1.bias", + "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "blocks.116.conv1.weight", + 
"model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "blocks.116.norm2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "blocks.116.norm2.weight", + "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "blocks.116.conv2.bias", + "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "blocks.116.conv2.weight", + "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "blocks.116.conv_shortcut.bias", + "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "blocks.116.conv_shortcut.weight", + "model.diffusion_model.output_blocks.4.0.time_mixer.mix_factor": "blocks.119.mix_factor", + "model.diffusion_model.output_blocks.4.0.time_stack.emb_layers.1.bias": "blocks.118.time_emb_proj.bias", + "model.diffusion_model.output_blocks.4.0.time_stack.emb_layers.1.weight": "blocks.118.time_emb_proj.weight", + "model.diffusion_model.output_blocks.4.0.time_stack.in_layers.0.bias": "blocks.118.norm1.bias", + "model.diffusion_model.output_blocks.4.0.time_stack.in_layers.0.weight": "blocks.118.norm1.weight", + "model.diffusion_model.output_blocks.4.0.time_stack.in_layers.2.bias": "blocks.118.conv1.bias", + "model.diffusion_model.output_blocks.4.0.time_stack.in_layers.2.weight": "blocks.118.conv1.weight", + "model.diffusion_model.output_blocks.4.0.time_stack.out_layers.0.bias": "blocks.118.norm2.bias", + "model.diffusion_model.output_blocks.4.0.time_stack.out_layers.0.weight": "blocks.118.norm2.weight", + "model.diffusion_model.output_blocks.4.0.time_stack.out_layers.3.bias": "blocks.118.conv2.bias", + "model.diffusion_model.output_blocks.4.0.time_stack.out_layers.3.weight": "blocks.118.conv2.weight", + "model.diffusion_model.output_blocks.4.1.norm.bias": "blocks.121.norm.bias", + "model.diffusion_model.output_blocks.4.1.norm.weight": "blocks.121.norm.weight", + "model.diffusion_model.output_blocks.4.1.proj_in.bias": "blocks.121.proj_in.bias", + "model.diffusion_model.output_blocks.4.1.proj_in.weight": "blocks.121.proj_in.weight", + "model.diffusion_model.output_blocks.4.1.proj_out.bias": "blocks.124.proj.bias", + "model.diffusion_model.output_blocks.4.1.proj_out.weight": "blocks.124.proj.weight", + "model.diffusion_model.output_blocks.4.1.time_mixer.mix_factor": "blocks.124.mix_factor", + "model.diffusion_model.output_blocks.4.1.time_pos_embed.0.bias": "blocks.123.positional_embedding_proj.0.bias", + "model.diffusion_model.output_blocks.4.1.time_pos_embed.0.weight": "blocks.123.positional_embedding_proj.0.weight", + "model.diffusion_model.output_blocks.4.1.time_pos_embed.2.bias": "blocks.123.positional_embedding_proj.2.bias", + "model.diffusion_model.output_blocks.4.1.time_pos_embed.2.weight": "blocks.123.positional_embedding_proj.2.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.attn1.to_k.weight": "blocks.123.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.attn1.to_out.0.bias": "blocks.123.attn1.to_out.bias", + "model.diffusion_model.output_blocks.4.1.time_stack.0.attn1.to_out.0.weight": "blocks.123.attn1.to_out.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.attn1.to_q.weight": "blocks.123.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.attn1.to_v.weight": "blocks.123.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.attn2.to_k.weight": "blocks.123.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.attn2.to_out.0.bias": "blocks.123.attn2.to_out.bias", + 
"model.diffusion_model.output_blocks.4.1.time_stack.0.attn2.to_out.0.weight": "blocks.123.attn2.to_out.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.attn2.to_q.weight": "blocks.123.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.attn2.to_v.weight": "blocks.123.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.ff.net.0.proj.bias": "blocks.123.act_fn_out.proj.bias", + "model.diffusion_model.output_blocks.4.1.time_stack.0.ff.net.0.proj.weight": "blocks.123.act_fn_out.proj.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.ff.net.2.bias": "blocks.123.ff_out.bias", + "model.diffusion_model.output_blocks.4.1.time_stack.0.ff.net.2.weight": "blocks.123.ff_out.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.123.act_fn_in.proj.bias", + "model.diffusion_model.output_blocks.4.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.123.act_fn_in.proj.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.ff_in.net.2.bias": "blocks.123.ff_in.bias", + "model.diffusion_model.output_blocks.4.1.time_stack.0.ff_in.net.2.weight": "blocks.123.ff_in.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.norm1.bias": "blocks.123.norm1.bias", + "model.diffusion_model.output_blocks.4.1.time_stack.0.norm1.weight": "blocks.123.norm1.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.norm2.bias": "blocks.123.norm2.bias", + "model.diffusion_model.output_blocks.4.1.time_stack.0.norm2.weight": "blocks.123.norm2.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.norm3.bias": "blocks.123.norm_out.bias", + "model.diffusion_model.output_blocks.4.1.time_stack.0.norm3.weight": "blocks.123.norm_out.weight", + "model.diffusion_model.output_blocks.4.1.time_stack.0.norm_in.bias": "blocks.123.norm_in.bias", + "model.diffusion_model.output_blocks.4.1.time_stack.0.norm_in.weight": "blocks.123.norm_in.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.121.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.121.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.121.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.121.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.121.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.121.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.121.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.121.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.121.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.121.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.121.transformer_blocks.0.act_fn.proj.bias", + 
"model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.121.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.121.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.121.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.121.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.121.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.121.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.121.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.121.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.121.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "blocks.126.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "blocks.126.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "blocks.126.norm1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "blocks.126.norm1.weight", + "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "blocks.126.conv1.bias", + "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "blocks.126.conv1.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "blocks.126.norm2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "blocks.126.norm2.weight", + "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "blocks.126.conv2.bias", + "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "blocks.126.conv2.weight", + "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "blocks.126.conv_shortcut.bias", + "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "blocks.126.conv_shortcut.weight", + "model.diffusion_model.output_blocks.5.0.time_mixer.mix_factor": "blocks.129.mix_factor", + "model.diffusion_model.output_blocks.5.0.time_stack.emb_layers.1.bias": "blocks.128.time_emb_proj.bias", + "model.diffusion_model.output_blocks.5.0.time_stack.emb_layers.1.weight": "blocks.128.time_emb_proj.weight", + "model.diffusion_model.output_blocks.5.0.time_stack.in_layers.0.bias": "blocks.128.norm1.bias", + "model.diffusion_model.output_blocks.5.0.time_stack.in_layers.0.weight": "blocks.128.norm1.weight", + "model.diffusion_model.output_blocks.5.0.time_stack.in_layers.2.bias": "blocks.128.conv1.bias", + "model.diffusion_model.output_blocks.5.0.time_stack.in_layers.2.weight": "blocks.128.conv1.weight", + "model.diffusion_model.output_blocks.5.0.time_stack.out_layers.0.bias": "blocks.128.norm2.bias", + "model.diffusion_model.output_blocks.5.0.time_stack.out_layers.0.weight": "blocks.128.norm2.weight", + "model.diffusion_model.output_blocks.5.0.time_stack.out_layers.3.bias": "blocks.128.conv2.bias", + "model.diffusion_model.output_blocks.5.0.time_stack.out_layers.3.weight": "blocks.128.conv2.weight", + "model.diffusion_model.output_blocks.5.1.norm.bias": "blocks.131.norm.bias", + "model.diffusion_model.output_blocks.5.1.norm.weight": "blocks.131.norm.weight", + 
"model.diffusion_model.output_blocks.5.1.proj_in.bias": "blocks.131.proj_in.bias", + "model.diffusion_model.output_blocks.5.1.proj_in.weight": "blocks.131.proj_in.weight", + "model.diffusion_model.output_blocks.5.1.proj_out.bias": "blocks.134.proj.bias", + "model.diffusion_model.output_blocks.5.1.proj_out.weight": "blocks.134.proj.weight", + "model.diffusion_model.output_blocks.5.1.time_mixer.mix_factor": "blocks.134.mix_factor", + "model.diffusion_model.output_blocks.5.1.time_pos_embed.0.bias": "blocks.133.positional_embedding_proj.0.bias", + "model.diffusion_model.output_blocks.5.1.time_pos_embed.0.weight": "blocks.133.positional_embedding_proj.0.weight", + "model.diffusion_model.output_blocks.5.1.time_pos_embed.2.bias": "blocks.133.positional_embedding_proj.2.bias", + "model.diffusion_model.output_blocks.5.1.time_pos_embed.2.weight": "blocks.133.positional_embedding_proj.2.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.attn1.to_k.weight": "blocks.133.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.attn1.to_out.0.bias": "blocks.133.attn1.to_out.bias", + "model.diffusion_model.output_blocks.5.1.time_stack.0.attn1.to_out.0.weight": "blocks.133.attn1.to_out.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.attn1.to_q.weight": "blocks.133.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.attn1.to_v.weight": "blocks.133.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.attn2.to_k.weight": "blocks.133.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.attn2.to_out.0.bias": "blocks.133.attn2.to_out.bias", + "model.diffusion_model.output_blocks.5.1.time_stack.0.attn2.to_out.0.weight": "blocks.133.attn2.to_out.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.attn2.to_q.weight": "blocks.133.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.attn2.to_v.weight": "blocks.133.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.ff.net.0.proj.bias": "blocks.133.act_fn_out.proj.bias", + "model.diffusion_model.output_blocks.5.1.time_stack.0.ff.net.0.proj.weight": "blocks.133.act_fn_out.proj.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.ff.net.2.bias": "blocks.133.ff_out.bias", + "model.diffusion_model.output_blocks.5.1.time_stack.0.ff.net.2.weight": "blocks.133.ff_out.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.133.act_fn_in.proj.bias", + "model.diffusion_model.output_blocks.5.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.133.act_fn_in.proj.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.ff_in.net.2.bias": "blocks.133.ff_in.bias", + "model.diffusion_model.output_blocks.5.1.time_stack.0.ff_in.net.2.weight": "blocks.133.ff_in.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.norm1.bias": "blocks.133.norm1.bias", + "model.diffusion_model.output_blocks.5.1.time_stack.0.norm1.weight": "blocks.133.norm1.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.norm2.bias": "blocks.133.norm2.bias", + "model.diffusion_model.output_blocks.5.1.time_stack.0.norm2.weight": "blocks.133.norm2.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.norm3.bias": "blocks.133.norm_out.bias", + "model.diffusion_model.output_blocks.5.1.time_stack.0.norm3.weight": "blocks.133.norm_out.weight", + "model.diffusion_model.output_blocks.5.1.time_stack.0.norm_in.bias": "blocks.133.norm_in.bias", + 
"model.diffusion_model.output_blocks.5.1.time_stack.0.norm_in.weight": "blocks.133.norm_in.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.131.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.131.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.131.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.131.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.131.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.131.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.131.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.131.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.131.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.131.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.131.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.131.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.131.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.131.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.131.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.131.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.131.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.131.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.131.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.131.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.5.2.conv.bias": "blocks.135.conv.bias", + "model.diffusion_model.output_blocks.5.2.conv.weight": "blocks.135.conv.weight", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "blocks.137.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "blocks.137.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "blocks.137.norm1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "blocks.137.norm1.weight", + "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "blocks.137.conv1.bias", + "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "blocks.137.conv1.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": 
"blocks.137.norm2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "blocks.137.norm2.weight", + "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "blocks.137.conv2.bias", + "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "blocks.137.conv2.weight", + "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "blocks.137.conv_shortcut.bias", + "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "blocks.137.conv_shortcut.weight", + "model.diffusion_model.output_blocks.6.0.time_mixer.mix_factor": "blocks.140.mix_factor", + "model.diffusion_model.output_blocks.6.0.time_stack.emb_layers.1.bias": "blocks.139.time_emb_proj.bias", + "model.diffusion_model.output_blocks.6.0.time_stack.emb_layers.1.weight": "blocks.139.time_emb_proj.weight", + "model.diffusion_model.output_blocks.6.0.time_stack.in_layers.0.bias": "blocks.139.norm1.bias", + "model.diffusion_model.output_blocks.6.0.time_stack.in_layers.0.weight": "blocks.139.norm1.weight", + "model.diffusion_model.output_blocks.6.0.time_stack.in_layers.2.bias": "blocks.139.conv1.bias", + "model.diffusion_model.output_blocks.6.0.time_stack.in_layers.2.weight": "blocks.139.conv1.weight", + "model.diffusion_model.output_blocks.6.0.time_stack.out_layers.0.bias": "blocks.139.norm2.bias", + "model.diffusion_model.output_blocks.6.0.time_stack.out_layers.0.weight": "blocks.139.norm2.weight", + "model.diffusion_model.output_blocks.6.0.time_stack.out_layers.3.bias": "blocks.139.conv2.bias", + "model.diffusion_model.output_blocks.6.0.time_stack.out_layers.3.weight": "blocks.139.conv2.weight", + "model.diffusion_model.output_blocks.6.1.norm.bias": "blocks.142.norm.bias", + "model.diffusion_model.output_blocks.6.1.norm.weight": "blocks.142.norm.weight", + "model.diffusion_model.output_blocks.6.1.proj_in.bias": "blocks.142.proj_in.bias", + "model.diffusion_model.output_blocks.6.1.proj_in.weight": "blocks.142.proj_in.weight", + "model.diffusion_model.output_blocks.6.1.proj_out.bias": "blocks.145.proj.bias", + "model.diffusion_model.output_blocks.6.1.proj_out.weight": "blocks.145.proj.weight", + "model.diffusion_model.output_blocks.6.1.time_mixer.mix_factor": "blocks.145.mix_factor", + "model.diffusion_model.output_blocks.6.1.time_pos_embed.0.bias": "blocks.144.positional_embedding_proj.0.bias", + "model.diffusion_model.output_blocks.6.1.time_pos_embed.0.weight": "blocks.144.positional_embedding_proj.0.weight", + "model.diffusion_model.output_blocks.6.1.time_pos_embed.2.bias": "blocks.144.positional_embedding_proj.2.bias", + "model.diffusion_model.output_blocks.6.1.time_pos_embed.2.weight": "blocks.144.positional_embedding_proj.2.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.attn1.to_k.weight": "blocks.144.attn1.to_k.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.attn1.to_out.0.bias": "blocks.144.attn1.to_out.bias", + "model.diffusion_model.output_blocks.6.1.time_stack.0.attn1.to_out.0.weight": "blocks.144.attn1.to_out.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.attn1.to_q.weight": "blocks.144.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.attn1.to_v.weight": "blocks.144.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.attn2.to_k.weight": "blocks.144.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.attn2.to_out.0.bias": "blocks.144.attn2.to_out.bias", + "model.diffusion_model.output_blocks.6.1.time_stack.0.attn2.to_out.0.weight": "blocks.144.attn2.to_out.weight", 
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.attn2.to_q.weight": "blocks.144.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.attn2.to_v.weight": "blocks.144.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.ff.net.0.proj.bias": "blocks.144.act_fn_out.proj.bias", + "model.diffusion_model.output_blocks.6.1.time_stack.0.ff.net.0.proj.weight": "blocks.144.act_fn_out.proj.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.ff.net.2.bias": "blocks.144.ff_out.bias", + "model.diffusion_model.output_blocks.6.1.time_stack.0.ff.net.2.weight": "blocks.144.ff_out.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.144.act_fn_in.proj.bias", + "model.diffusion_model.output_blocks.6.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.144.act_fn_in.proj.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.ff_in.net.2.bias": "blocks.144.ff_in.bias", + "model.diffusion_model.output_blocks.6.1.time_stack.0.ff_in.net.2.weight": "blocks.144.ff_in.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.norm1.bias": "blocks.144.norm1.bias", + "model.diffusion_model.output_blocks.6.1.time_stack.0.norm1.weight": "blocks.144.norm1.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.norm2.bias": "blocks.144.norm2.bias", + "model.diffusion_model.output_blocks.6.1.time_stack.0.norm2.weight": "blocks.144.norm2.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.norm3.bias": "blocks.144.norm_out.bias", + "model.diffusion_model.output_blocks.6.1.time_stack.0.norm3.weight": "blocks.144.norm_out.weight", + "model.diffusion_model.output_blocks.6.1.time_stack.0.norm_in.bias": "blocks.144.norm_in.bias", + "model.diffusion_model.output_blocks.6.1.time_stack.0.norm_in.weight": "blocks.144.norm_in.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "blocks.142.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.142.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.142.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": "blocks.142.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "blocks.142.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": "blocks.142.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.142.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.142.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "blocks.142.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "blocks.142.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.142.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.142.transformer_blocks.0.act_fn.proj.weight", + 
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "blocks.142.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": "blocks.142.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "blocks.142.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "blocks.142.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "blocks.142.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "blocks.142.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "blocks.142.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "blocks.142.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "blocks.147.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "blocks.147.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "blocks.147.norm1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "blocks.147.norm1.weight", + "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "blocks.147.conv1.bias", + "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "blocks.147.conv1.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "blocks.147.norm2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "blocks.147.norm2.weight", + "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "blocks.147.conv2.bias", + "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "blocks.147.conv2.weight", + "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "blocks.147.conv_shortcut.bias", + "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "blocks.147.conv_shortcut.weight", + "model.diffusion_model.output_blocks.7.0.time_mixer.mix_factor": "blocks.150.mix_factor", + "model.diffusion_model.output_blocks.7.0.time_stack.emb_layers.1.bias": "blocks.149.time_emb_proj.bias", + "model.diffusion_model.output_blocks.7.0.time_stack.emb_layers.1.weight": "blocks.149.time_emb_proj.weight", + "model.diffusion_model.output_blocks.7.0.time_stack.in_layers.0.bias": "blocks.149.norm1.bias", + "model.diffusion_model.output_blocks.7.0.time_stack.in_layers.0.weight": "blocks.149.norm1.weight", + "model.diffusion_model.output_blocks.7.0.time_stack.in_layers.2.bias": "blocks.149.conv1.bias", + "model.diffusion_model.output_blocks.7.0.time_stack.in_layers.2.weight": "blocks.149.conv1.weight", + "model.diffusion_model.output_blocks.7.0.time_stack.out_layers.0.bias": "blocks.149.norm2.bias", + "model.diffusion_model.output_blocks.7.0.time_stack.out_layers.0.weight": "blocks.149.norm2.weight", + "model.diffusion_model.output_blocks.7.0.time_stack.out_layers.3.bias": "blocks.149.conv2.bias", + "model.diffusion_model.output_blocks.7.0.time_stack.out_layers.3.weight": "blocks.149.conv2.weight", + "model.diffusion_model.output_blocks.7.1.norm.bias": "blocks.152.norm.bias", + "model.diffusion_model.output_blocks.7.1.norm.weight": "blocks.152.norm.weight", + "model.diffusion_model.output_blocks.7.1.proj_in.bias": "blocks.152.proj_in.bias", + "model.diffusion_model.output_blocks.7.1.proj_in.weight": 
"blocks.152.proj_in.weight", + "model.diffusion_model.output_blocks.7.1.proj_out.bias": "blocks.155.proj.bias", + "model.diffusion_model.output_blocks.7.1.proj_out.weight": "blocks.155.proj.weight", + "model.diffusion_model.output_blocks.7.1.time_mixer.mix_factor": "blocks.155.mix_factor", + "model.diffusion_model.output_blocks.7.1.time_pos_embed.0.bias": "blocks.154.positional_embedding_proj.0.bias", + "model.diffusion_model.output_blocks.7.1.time_pos_embed.0.weight": "blocks.154.positional_embedding_proj.0.weight", + "model.diffusion_model.output_blocks.7.1.time_pos_embed.2.bias": "blocks.154.positional_embedding_proj.2.bias", + "model.diffusion_model.output_blocks.7.1.time_pos_embed.2.weight": "blocks.154.positional_embedding_proj.2.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.attn1.to_k.weight": "blocks.154.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.attn1.to_out.0.bias": "blocks.154.attn1.to_out.bias", + "model.diffusion_model.output_blocks.7.1.time_stack.0.attn1.to_out.0.weight": "blocks.154.attn1.to_out.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.attn1.to_q.weight": "blocks.154.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.attn1.to_v.weight": "blocks.154.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.attn2.to_k.weight": "blocks.154.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.attn2.to_out.0.bias": "blocks.154.attn2.to_out.bias", + "model.diffusion_model.output_blocks.7.1.time_stack.0.attn2.to_out.0.weight": "blocks.154.attn2.to_out.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.attn2.to_q.weight": "blocks.154.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.attn2.to_v.weight": "blocks.154.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.ff.net.0.proj.bias": "blocks.154.act_fn_out.proj.bias", + "model.diffusion_model.output_blocks.7.1.time_stack.0.ff.net.0.proj.weight": "blocks.154.act_fn_out.proj.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.ff.net.2.bias": "blocks.154.ff_out.bias", + "model.diffusion_model.output_blocks.7.1.time_stack.0.ff.net.2.weight": "blocks.154.ff_out.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.154.act_fn_in.proj.bias", + "model.diffusion_model.output_blocks.7.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.154.act_fn_in.proj.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.ff_in.net.2.bias": "blocks.154.ff_in.bias", + "model.diffusion_model.output_blocks.7.1.time_stack.0.ff_in.net.2.weight": "blocks.154.ff_in.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.norm1.bias": "blocks.154.norm1.bias", + "model.diffusion_model.output_blocks.7.1.time_stack.0.norm1.weight": "blocks.154.norm1.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.norm2.bias": "blocks.154.norm2.bias", + "model.diffusion_model.output_blocks.7.1.time_stack.0.norm2.weight": "blocks.154.norm2.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.norm3.bias": "blocks.154.norm_out.bias", + "model.diffusion_model.output_blocks.7.1.time_stack.0.norm3.weight": "blocks.154.norm_out.weight", + "model.diffusion_model.output_blocks.7.1.time_stack.0.norm_in.bias": "blocks.154.norm_in.bias", + "model.diffusion_model.output_blocks.7.1.time_stack.0.norm_in.weight": "blocks.154.norm_in.weight", + 
"model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.152.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.152.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.152.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.152.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.152.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.152.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.152.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.152.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.152.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.152.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.152.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.152.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.152.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.152.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.152.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.152.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.152.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.152.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.152.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.152.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "blocks.157.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "blocks.157.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "blocks.157.norm1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "blocks.157.norm1.weight", + "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "blocks.157.conv1.bias", + "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "blocks.157.conv1.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "blocks.157.norm2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "blocks.157.norm2.weight", + "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "blocks.157.conv2.bias", + "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": 
"blocks.157.conv2.weight", + "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "blocks.157.conv_shortcut.bias", + "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "blocks.157.conv_shortcut.weight", + "model.diffusion_model.output_blocks.8.0.time_mixer.mix_factor": "blocks.160.mix_factor", + "model.diffusion_model.output_blocks.8.0.time_stack.emb_layers.1.bias": "blocks.159.time_emb_proj.bias", + "model.diffusion_model.output_blocks.8.0.time_stack.emb_layers.1.weight": "blocks.159.time_emb_proj.weight", + "model.diffusion_model.output_blocks.8.0.time_stack.in_layers.0.bias": "blocks.159.norm1.bias", + "model.diffusion_model.output_blocks.8.0.time_stack.in_layers.0.weight": "blocks.159.norm1.weight", + "model.diffusion_model.output_blocks.8.0.time_stack.in_layers.2.bias": "blocks.159.conv1.bias", + "model.diffusion_model.output_blocks.8.0.time_stack.in_layers.2.weight": "blocks.159.conv1.weight", + "model.diffusion_model.output_blocks.8.0.time_stack.out_layers.0.bias": "blocks.159.norm2.bias", + "model.diffusion_model.output_blocks.8.0.time_stack.out_layers.0.weight": "blocks.159.norm2.weight", + "model.diffusion_model.output_blocks.8.0.time_stack.out_layers.3.bias": "blocks.159.conv2.bias", + "model.diffusion_model.output_blocks.8.0.time_stack.out_layers.3.weight": "blocks.159.conv2.weight", + "model.diffusion_model.output_blocks.8.1.norm.bias": "blocks.162.norm.bias", + "model.diffusion_model.output_blocks.8.1.norm.weight": "blocks.162.norm.weight", + "model.diffusion_model.output_blocks.8.1.proj_in.bias": "blocks.162.proj_in.bias", + "model.diffusion_model.output_blocks.8.1.proj_in.weight": "blocks.162.proj_in.weight", + "model.diffusion_model.output_blocks.8.1.proj_out.bias": "blocks.165.proj.bias", + "model.diffusion_model.output_blocks.8.1.proj_out.weight": "blocks.165.proj.weight", + "model.diffusion_model.output_blocks.8.1.time_mixer.mix_factor": "blocks.165.mix_factor", + "model.diffusion_model.output_blocks.8.1.time_pos_embed.0.bias": "blocks.164.positional_embedding_proj.0.bias", + "model.diffusion_model.output_blocks.8.1.time_pos_embed.0.weight": "blocks.164.positional_embedding_proj.0.weight", + "model.diffusion_model.output_blocks.8.1.time_pos_embed.2.bias": "blocks.164.positional_embedding_proj.2.bias", + "model.diffusion_model.output_blocks.8.1.time_pos_embed.2.weight": "blocks.164.positional_embedding_proj.2.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.attn1.to_k.weight": "blocks.164.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.attn1.to_out.0.bias": "blocks.164.attn1.to_out.bias", + "model.diffusion_model.output_blocks.8.1.time_stack.0.attn1.to_out.0.weight": "blocks.164.attn1.to_out.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.attn1.to_q.weight": "blocks.164.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.attn1.to_v.weight": "blocks.164.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.attn2.to_k.weight": "blocks.164.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.attn2.to_out.0.bias": "blocks.164.attn2.to_out.bias", + "model.diffusion_model.output_blocks.8.1.time_stack.0.attn2.to_out.0.weight": "blocks.164.attn2.to_out.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.attn2.to_q.weight": "blocks.164.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.attn2.to_v.weight": "blocks.164.attn2.to_v.weight", + 
"model.diffusion_model.output_blocks.8.1.time_stack.0.ff.net.0.proj.bias": "blocks.164.act_fn_out.proj.bias", + "model.diffusion_model.output_blocks.8.1.time_stack.0.ff.net.0.proj.weight": "blocks.164.act_fn_out.proj.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.ff.net.2.bias": "blocks.164.ff_out.bias", + "model.diffusion_model.output_blocks.8.1.time_stack.0.ff.net.2.weight": "blocks.164.ff_out.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.164.act_fn_in.proj.bias", + "model.diffusion_model.output_blocks.8.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.164.act_fn_in.proj.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.ff_in.net.2.bias": "blocks.164.ff_in.bias", + "model.diffusion_model.output_blocks.8.1.time_stack.0.ff_in.net.2.weight": "blocks.164.ff_in.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.norm1.bias": "blocks.164.norm1.bias", + "model.diffusion_model.output_blocks.8.1.time_stack.0.norm1.weight": "blocks.164.norm1.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.norm2.bias": "blocks.164.norm2.bias", + "model.diffusion_model.output_blocks.8.1.time_stack.0.norm2.weight": "blocks.164.norm2.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.norm3.bias": "blocks.164.norm_out.bias", + "model.diffusion_model.output_blocks.8.1.time_stack.0.norm3.weight": "blocks.164.norm_out.weight", + "model.diffusion_model.output_blocks.8.1.time_stack.0.norm_in.bias": "blocks.164.norm_in.bias", + "model.diffusion_model.output_blocks.8.1.time_stack.0.norm_in.weight": "blocks.164.norm_in.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.162.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.162.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.162.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.162.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.162.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.162.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.162.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.162.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.162.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.162.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.162.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.162.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.162.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": 
"blocks.162.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.162.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.162.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.162.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.162.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.162.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.162.transformer_blocks.0.norm3.weight", + "model.diffusion_model.output_blocks.8.2.conv.bias": "blocks.166.conv.bias", + "model.diffusion_model.output_blocks.8.2.conv.weight": "blocks.166.conv.weight", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "blocks.168.time_emb_proj.bias", + "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "blocks.168.time_emb_proj.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "blocks.168.norm1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "blocks.168.norm1.weight", + "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "blocks.168.conv1.bias", + "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "blocks.168.conv1.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "blocks.168.norm2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "blocks.168.norm2.weight", + "model.diffusion_model.output_blocks.9.0.out_layers.3.bias": "blocks.168.conv2.bias", + "model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "blocks.168.conv2.weight", + "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "blocks.168.conv_shortcut.bias", + "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "blocks.168.conv_shortcut.weight", + "model.diffusion_model.output_blocks.9.0.time_mixer.mix_factor": "blocks.171.mix_factor", + "model.diffusion_model.output_blocks.9.0.time_stack.emb_layers.1.bias": "blocks.170.time_emb_proj.bias", + "model.diffusion_model.output_blocks.9.0.time_stack.emb_layers.1.weight": "blocks.170.time_emb_proj.weight", + "model.diffusion_model.output_blocks.9.0.time_stack.in_layers.0.bias": "blocks.170.norm1.bias", + "model.diffusion_model.output_blocks.9.0.time_stack.in_layers.0.weight": "blocks.170.norm1.weight", + "model.diffusion_model.output_blocks.9.0.time_stack.in_layers.2.bias": "blocks.170.conv1.bias", + "model.diffusion_model.output_blocks.9.0.time_stack.in_layers.2.weight": "blocks.170.conv1.weight", + "model.diffusion_model.output_blocks.9.0.time_stack.out_layers.0.bias": "blocks.170.norm2.bias", + "model.diffusion_model.output_blocks.9.0.time_stack.out_layers.0.weight": "blocks.170.norm2.weight", + "model.diffusion_model.output_blocks.9.0.time_stack.out_layers.3.bias": "blocks.170.conv2.bias", + "model.diffusion_model.output_blocks.9.0.time_stack.out_layers.3.weight": "blocks.170.conv2.weight", + "model.diffusion_model.output_blocks.9.1.norm.bias": "blocks.173.norm.bias", + "model.diffusion_model.output_blocks.9.1.norm.weight": "blocks.173.norm.weight", + "model.diffusion_model.output_blocks.9.1.proj_in.bias": "blocks.173.proj_in.bias", + "model.diffusion_model.output_blocks.9.1.proj_in.weight": "blocks.173.proj_in.weight", + 
"model.diffusion_model.output_blocks.9.1.proj_out.bias": "blocks.176.proj.bias", + "model.diffusion_model.output_blocks.9.1.proj_out.weight": "blocks.176.proj.weight", + "model.diffusion_model.output_blocks.9.1.time_mixer.mix_factor": "blocks.176.mix_factor", + "model.diffusion_model.output_blocks.9.1.time_pos_embed.0.bias": "blocks.175.positional_embedding_proj.0.bias", + "model.diffusion_model.output_blocks.9.1.time_pos_embed.0.weight": "blocks.175.positional_embedding_proj.0.weight", + "model.diffusion_model.output_blocks.9.1.time_pos_embed.2.bias": "blocks.175.positional_embedding_proj.2.bias", + "model.diffusion_model.output_blocks.9.1.time_pos_embed.2.weight": "blocks.175.positional_embedding_proj.2.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.attn1.to_k.weight": "blocks.175.attn1.to_k.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.attn1.to_out.0.bias": "blocks.175.attn1.to_out.bias", + "model.diffusion_model.output_blocks.9.1.time_stack.0.attn1.to_out.0.weight": "blocks.175.attn1.to_out.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.attn1.to_q.weight": "blocks.175.attn1.to_q.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.attn1.to_v.weight": "blocks.175.attn1.to_v.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.attn2.to_k.weight": "blocks.175.attn2.to_k.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.attn2.to_out.0.bias": "blocks.175.attn2.to_out.bias", + "model.diffusion_model.output_blocks.9.1.time_stack.0.attn2.to_out.0.weight": "blocks.175.attn2.to_out.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.attn2.to_q.weight": "blocks.175.attn2.to_q.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.attn2.to_v.weight": "blocks.175.attn2.to_v.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.ff.net.0.proj.bias": "blocks.175.act_fn_out.proj.bias", + "model.diffusion_model.output_blocks.9.1.time_stack.0.ff.net.0.proj.weight": "blocks.175.act_fn_out.proj.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.ff.net.2.bias": "blocks.175.ff_out.bias", + "model.diffusion_model.output_blocks.9.1.time_stack.0.ff.net.2.weight": "blocks.175.ff_out.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.175.act_fn_in.proj.bias", + "model.diffusion_model.output_blocks.9.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.175.act_fn_in.proj.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.ff_in.net.2.bias": "blocks.175.ff_in.bias", + "model.diffusion_model.output_blocks.9.1.time_stack.0.ff_in.net.2.weight": "blocks.175.ff_in.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.norm1.bias": "blocks.175.norm1.bias", + "model.diffusion_model.output_blocks.9.1.time_stack.0.norm1.weight": "blocks.175.norm1.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.norm2.bias": "blocks.175.norm2.bias", + "model.diffusion_model.output_blocks.9.1.time_stack.0.norm2.weight": "blocks.175.norm2.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.norm3.bias": "blocks.175.norm_out.bias", + "model.diffusion_model.output_blocks.9.1.time_stack.0.norm3.weight": "blocks.175.norm_out.weight", + "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.bias": "blocks.175.norm_in.bias", + "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight": "blocks.175.norm_in.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight": 
"blocks.173.transformer_blocks.0.attn1.to_k.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.173.transformer_blocks.0.attn1.to_out.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.173.transformer_blocks.0.attn1.to_out.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight": "blocks.173.transformer_blocks.0.attn1.to_q.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight": "blocks.173.transformer_blocks.0.attn1.to_v.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight": "blocks.173.transformer_blocks.0.attn2.to_k.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.173.transformer_blocks.0.attn2.to_out.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.173.transformer_blocks.0.attn2.to_out.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight": "blocks.173.transformer_blocks.0.attn2.to_q.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight": "blocks.173.transformer_blocks.0.attn2.to_v.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.173.transformer_blocks.0.act_fn.proj.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.173.transformer_blocks.0.act_fn.proj.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias": "blocks.173.transformer_blocks.0.ff.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight": "blocks.173.transformer_blocks.0.ff.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias": "blocks.173.transformer_blocks.0.norm1.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight": "blocks.173.transformer_blocks.0.norm1.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias": "blocks.173.transformer_blocks.0.norm2.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight": "blocks.173.transformer_blocks.0.norm2.weight", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias": "blocks.173.transformer_blocks.0.norm3.bias", + "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight": "blocks.173.transformer_blocks.0.norm3.weight", + "model.diffusion_model.time_embed.0.bias": "time_embedding.0.bias", + "model.diffusion_model.time_embed.0.weight": "time_embedding.0.weight", + "model.diffusion_model.time_embed.2.bias": "time_embedding.2.bias", + "model.diffusion_model.time_embed.2.weight": "time_embedding.2.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if ".proj_in." in name or ".proj_out." 
in name: + param = param.squeeze() + state_dict_[rename_dict[name]] = param + if add_positional_conv is not None: + extra_names = [ + "blocks.7.positional_conv", "blocks.17.positional_conv", "blocks.29.positional_conv", "blocks.39.positional_conv", + "blocks.51.positional_conv", "blocks.61.positional_conv", "blocks.83.positional_conv", "blocks.113.positional_conv", + "blocks.123.positional_conv", "blocks.133.positional_conv", "blocks.144.positional_conv", "blocks.154.positional_conv", + "blocks.164.positional_conv", "blocks.175.positional_conv", "blocks.185.positional_conv", "blocks.195.positional_conv", + ] + extra_channels = [320, 320, 640, 640, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320, 320] + for name, channels in zip(extra_names, extra_channels): + weight = torch.zeros((channels, channels, 3, 3, 3)) + weight[:,:,1,1,1] = torch.eye(channels, channels) + bias = torch.zeros((channels,)) + state_dict_[name + ".weight"] = weight + state_dict_[name + ".bias"] = bias + return state_dict_ diff --git a/diffsynth/models/svd_vae_decoder.py b/diffsynth/models/svd_vae_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3fe6446854cfb3252096c564daedaabc3d201068 --- /dev/null +++ b/diffsynth/models/svd_vae_decoder.py @@ -0,0 +1,577 @@ +import torch +from .attention import Attention +from .sd_unet import ResnetBlock, UpSampler +from .tiler import TileWorker +from einops import rearrange, repeat + + +class VAEAttentionBlock(torch.nn.Module): + + def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) + + self.transformer_blocks = torch.nn.ModuleList([ + Attention( + inner_dim, + num_attention_heads, + attention_head_dim, + bias_q=True, + bias_kv=True, + bias_out=True + ) + for d in range(num_layers) + ]) + + def forward(self, hidden_states, time_emb, text_emb, res_stack): + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states) + + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = hidden_states + residual + + return hidden_states, time_emb, text_emb, res_stack + + +class TemporalResnetBlock(torch.nn.Module): + + def __init__(self, in_channels, out_channels, groups=32, eps=1e-5): + super().__init__() + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = torch.nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0)) + self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + self.conv2 = torch.nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0)) + self.nonlinearity = torch.nn.SiLU() + self.mix_factor = torch.nn.Parameter(torch.Tensor([0.5])) + + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): + x_spatial = hidden_states + x = rearrange(hidden_states, "T C H W -> 1 C T H W") + x = self.norm1(x) + x = self.nonlinearity(x) + x = self.conv1(x) + x = self.norm2(x) + x = 
self.nonlinearity(x) + x = self.conv2(x) + x_temporal = hidden_states + x[0].permute(1, 0, 2, 3) + alpha = torch.sigmoid(self.mix_factor) + hidden_states = alpha * x_temporal + (1 - alpha) * x_spatial + return hidden_states, time_emb, text_emb, res_stack + + +class SVDVAEDecoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.scaling_factor = 0.18215 + self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1) + + self.blocks = torch.nn.ModuleList([ + # UNetMidBlock + ResnetBlock(512, 512, eps=1e-6), + TemporalResnetBlock(512, 512, eps=1e-6), + VAEAttentionBlock(1, 512, 512, 1, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + TemporalResnetBlock(512, 512, eps=1e-6), + # UpDecoderBlock + ResnetBlock(512, 512, eps=1e-6), + TemporalResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + TemporalResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + TemporalResnetBlock(512, 512, eps=1e-6), + UpSampler(512), + # UpDecoderBlock + ResnetBlock(512, 512, eps=1e-6), + TemporalResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + TemporalResnetBlock(512, 512, eps=1e-6), + ResnetBlock(512, 512, eps=1e-6), + TemporalResnetBlock(512, 512, eps=1e-6), + UpSampler(512), + # UpDecoderBlock + ResnetBlock(512, 256, eps=1e-6), + TemporalResnetBlock(256, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + TemporalResnetBlock(256, 256, eps=1e-6), + ResnetBlock(256, 256, eps=1e-6), + TemporalResnetBlock(256, 256, eps=1e-6), + UpSampler(256), + # UpDecoderBlock + ResnetBlock(256, 128, eps=1e-6), + TemporalResnetBlock(128, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + TemporalResnetBlock(128, 128, eps=1e-6), + ResnetBlock(128, 128, eps=1e-6), + TemporalResnetBlock(128, 128, eps=1e-6), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1) + self.time_conv_out = torch.nn.Conv3d(3, 3, kernel_size=(3, 1, 1), padding=(1, 0, 0)) + + + def forward(self, sample): + # 1. pre-process + hidden_states = rearrange(sample, "C T H W -> T C H W") + hidden_states = hidden_states / self.scaling_factor + hidden_states = self.conv_in(hidden_states) + time_emb, text_emb, res_stack = None, None, None + + # 2. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 3. 
output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + hidden_states = rearrange(hidden_states, "T C H W -> C T H W") + hidden_states = self.time_conv_out(hidden_states) + + return hidden_states + + + def build_mask(self, data, is_bound): + _, T, H, W = data.shape + t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W) + h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W) + w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W) + border_width = (T + H + W) // 6 + pad = torch.ones_like(t) * border_width + mask = torch.stack([ + pad if is_bound[0] else t + 1, + pad if is_bound[1] else T - t, + pad if is_bound[2] else h + 1, + pad if is_bound[3] else H - h, + pad if is_bound[4] else w + 1, + pad if is_bound[5] else W - w + ]).min(dim=0).values + mask = mask.clip(1, border_width) + mask = (mask / border_width).to(dtype=data.dtype, device=data.device) + mask = rearrange(mask, "T H W -> 1 T H W") + return mask + + + def decode_video( + self, sample, + batch_time=8, batch_height=128, batch_width=128, + stride_time=4, stride_height=32, stride_width=32, + progress_bar=lambda x:x + ): + sample = sample.permute(1, 0, 2, 3) + data_device = sample.device + computation_device = self.conv_in.weight.device + torch_dtype = sample.dtype + _, T, H, W = sample.shape + + weight = torch.zeros((1, T, H*8, W*8), dtype=torch_dtype, device=data_device) + values = torch.zeros((3, T, H*8, W*8), dtype=torch_dtype, device=data_device) + + # Split tasks + tasks = [] + for t in range(0, T, stride_time): + for h in range(0, H, stride_height): + for w in range(0, W, stride_width): + if (t-stride_time >= 0 and t-stride_time+batch_time >= T)\ + or (h-stride_height >= 0 and h-stride_height+batch_height >= H)\ + or (w-stride_width >= 0 and w-stride_width+batch_width >= W): + continue + tasks.append((t, t+batch_time, h, h+batch_height, w, w+batch_width)) + + # Run + for tl, tr, hl, hr, wl, wr in progress_bar(tasks): + sample_batch = sample[:, tl:tr, hl:hr, wl:wr].to(computation_device) + sample_batch = self.forward(sample_batch).to(data_device) + mask = self.build_mask(sample_batch, is_bound=(tl==0, tr>=T, hl==0, hr>=H, wl==0, wr>=W)) + values[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += sample_batch * mask + weight[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += mask + values /= weight + return values + + + def state_dict_converter(self): + return SVDVAEDecoderStateDictConverter() + + +class SVDVAEDecoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + static_rename_dict = { + "decoder.conv_in": "conv_in", + "decoder.mid_block.attentions.0.group_norm": "blocks.2.norm", + "decoder.mid_block.attentions.0.to_q": "blocks.2.transformer_blocks.0.to_q", + "decoder.mid_block.attentions.0.to_k": "blocks.2.transformer_blocks.0.to_k", + "decoder.mid_block.attentions.0.to_v": "blocks.2.transformer_blocks.0.to_v", + "decoder.mid_block.attentions.0.to_out.0": "blocks.2.transformer_blocks.0.to_out", + "decoder.up_blocks.0.upsamplers.0.conv": "blocks.11.conv", + "decoder.up_blocks.1.upsamplers.0.conv": "blocks.18.conv", + "decoder.up_blocks.2.upsamplers.0.conv": "blocks.25.conv", + "decoder.conv_norm_out": "conv_norm_out", + "decoder.conv_out": "conv_out", + "decoder.time_conv_out": "time_conv_out" + } + prefix_rename_dict = { + "decoder.mid_block.resnets.0.spatial_res_block": "blocks.0", + "decoder.mid_block.resnets.0.temporal_res_block": "blocks.1", + "decoder.mid_block.resnets.0.time_mixer": "blocks.1", + 
"decoder.mid_block.resnets.1.spatial_res_block": "blocks.3", + "decoder.mid_block.resnets.1.temporal_res_block": "blocks.4", + "decoder.mid_block.resnets.1.time_mixer": "blocks.4", + + "decoder.up_blocks.0.resnets.0.spatial_res_block": "blocks.5", + "decoder.up_blocks.0.resnets.0.temporal_res_block": "blocks.6", + "decoder.up_blocks.0.resnets.0.time_mixer": "blocks.6", + "decoder.up_blocks.0.resnets.1.spatial_res_block": "blocks.7", + "decoder.up_blocks.0.resnets.1.temporal_res_block": "blocks.8", + "decoder.up_blocks.0.resnets.1.time_mixer": "blocks.8", + "decoder.up_blocks.0.resnets.2.spatial_res_block": "blocks.9", + "decoder.up_blocks.0.resnets.2.temporal_res_block": "blocks.10", + "decoder.up_blocks.0.resnets.2.time_mixer": "blocks.10", + + "decoder.up_blocks.1.resnets.0.spatial_res_block": "blocks.12", + "decoder.up_blocks.1.resnets.0.temporal_res_block": "blocks.13", + "decoder.up_blocks.1.resnets.0.time_mixer": "blocks.13", + "decoder.up_blocks.1.resnets.1.spatial_res_block": "blocks.14", + "decoder.up_blocks.1.resnets.1.temporal_res_block": "blocks.15", + "decoder.up_blocks.1.resnets.1.time_mixer": "blocks.15", + "decoder.up_blocks.1.resnets.2.spatial_res_block": "blocks.16", + "decoder.up_blocks.1.resnets.2.temporal_res_block": "blocks.17", + "decoder.up_blocks.1.resnets.2.time_mixer": "blocks.17", + + "decoder.up_blocks.2.resnets.0.spatial_res_block": "blocks.19", + "decoder.up_blocks.2.resnets.0.temporal_res_block": "blocks.20", + "decoder.up_blocks.2.resnets.0.time_mixer": "blocks.20", + "decoder.up_blocks.2.resnets.1.spatial_res_block": "blocks.21", + "decoder.up_blocks.2.resnets.1.temporal_res_block": "blocks.22", + "decoder.up_blocks.2.resnets.1.time_mixer": "blocks.22", + "decoder.up_blocks.2.resnets.2.spatial_res_block": "blocks.23", + "decoder.up_blocks.2.resnets.2.temporal_res_block": "blocks.24", + "decoder.up_blocks.2.resnets.2.time_mixer": "blocks.24", + + "decoder.up_blocks.3.resnets.0.spatial_res_block": "blocks.26", + "decoder.up_blocks.3.resnets.0.temporal_res_block": "blocks.27", + "decoder.up_blocks.3.resnets.0.time_mixer": "blocks.27", + "decoder.up_blocks.3.resnets.1.spatial_res_block": "blocks.28", + "decoder.up_blocks.3.resnets.1.temporal_res_block": "blocks.29", + "decoder.up_blocks.3.resnets.1.time_mixer": "blocks.29", + "decoder.up_blocks.3.resnets.2.spatial_res_block": "blocks.30", + "decoder.up_blocks.3.resnets.2.temporal_res_block": "blocks.31", + "decoder.up_blocks.3.resnets.2.time_mixer": "blocks.31", + } + suffix_rename_dict = { + "norm1.weight": "norm1.weight", + "conv1.weight": "conv1.weight", + "norm2.weight": "norm2.weight", + "conv2.weight": "conv2.weight", + "conv_shortcut.weight": "conv_shortcut.weight", + "norm1.bias": "norm1.bias", + "conv1.bias": "conv1.bias", + "norm2.bias": "norm2.bias", + "conv2.bias": "conv2.bias", + "conv_shortcut.bias": "conv_shortcut.bias", + "mix_factor": "mix_factor", + } + + state_dict_ = {} + for name in static_rename_dict: + state_dict_[static_rename_dict[name] + ".weight"] = state_dict[name + ".weight"] + state_dict_[static_rename_dict[name] + ".bias"] = state_dict[name + ".bias"] + for prefix_name in prefix_rename_dict: + for suffix_name in suffix_rename_dict: + name = prefix_name + "." + suffix_name + name_ = prefix_rename_dict[prefix_name] + "." 
+ suffix_rename_dict[suffix_name] + if name in state_dict: + state_dict_[name_] = state_dict[name] + + return state_dict_ + + + def from_civitai(self, state_dict): + rename_dict = { + "first_stage_model.decoder.conv_in.bias": "conv_in.bias", + "first_stage_model.decoder.conv_in.weight": "conv_in.weight", + "first_stage_model.decoder.conv_out.bias": "conv_out.bias", + "first_stage_model.decoder.conv_out.time_mix_conv.bias": "time_conv_out.bias", + "first_stage_model.decoder.conv_out.time_mix_conv.weight": "time_conv_out.weight", + "first_stage_model.decoder.conv_out.weight": "conv_out.weight", + "first_stage_model.decoder.mid.attn_1.k.bias": "blocks.2.transformer_blocks.0.to_k.bias", + "first_stage_model.decoder.mid.attn_1.k.weight": "blocks.2.transformer_blocks.0.to_k.weight", + "first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.2.norm.bias", + "first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.2.norm.weight", + "first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.2.transformer_blocks.0.to_out.bias", + "first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.2.transformer_blocks.0.to_out.weight", + "first_stage_model.decoder.mid.attn_1.q.bias": "blocks.2.transformer_blocks.0.to_q.bias", + "first_stage_model.decoder.mid.attn_1.q.weight": "blocks.2.transformer_blocks.0.to_q.weight", + "first_stage_model.decoder.mid.attn_1.v.bias": "blocks.2.transformer_blocks.0.to_v.bias", + "first_stage_model.decoder.mid.attn_1.v.weight": "blocks.2.transformer_blocks.0.to_v.weight", + "first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias", + "first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight", + "first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias", + "first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight", + "first_stage_model.decoder.mid.block_1.mix_factor": "blocks.1.mix_factor", + "first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias", + "first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight", + "first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias", + "first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight", + "first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.bias": "blocks.1.norm1.bias", + "first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.weight": "blocks.1.norm1.weight", + "first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.bias": "blocks.1.conv1.bias", + "first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.weight": "blocks.1.conv1.weight", + "first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.bias": "blocks.1.norm2.bias", + "first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.weight": "blocks.1.norm2.weight", + "first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.bias": "blocks.1.conv2.bias", + "first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.weight": "blocks.1.conv2.weight", + "first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.3.conv1.bias", + "first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.3.conv1.weight", + "first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.3.conv2.bias", + "first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.3.conv2.weight", + "first_stage_model.decoder.mid.block_2.mix_factor": "blocks.4.mix_factor", + "first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.3.norm1.bias", + "first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.3.norm1.weight", 
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.3.norm2.bias", + "first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.3.norm2.weight", + "first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.bias": "blocks.4.norm1.bias", + "first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.weight": "blocks.4.norm1.weight", + "first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.bias": "blocks.4.conv1.bias", + "first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.weight": "blocks.4.conv1.weight", + "first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.bias": "blocks.4.norm2.bias", + "first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.weight": "blocks.4.norm2.weight", + "first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.bias": "blocks.4.conv2.bias", + "first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.weight": "blocks.4.conv2.weight", + "first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias", + "first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight", + "first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.26.conv1.bias", + "first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.26.conv1.weight", + "first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.26.conv2.bias", + "first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.26.conv2.weight", + "first_stage_model.decoder.up.0.block.0.mix_factor": "blocks.27.mix_factor", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.26.conv_shortcut.bias", + "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.26.conv_shortcut.weight", + "first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.26.norm1.bias", + "first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.26.norm1.weight", + "first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.26.norm2.bias", + "first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.26.norm2.weight", + "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.bias": "blocks.27.norm1.bias", + "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.weight": "blocks.27.norm1.weight", + "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.bias": "blocks.27.conv1.bias", + "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.weight": "blocks.27.conv1.weight", + "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.bias": "blocks.27.norm2.bias", + "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.weight": "blocks.27.norm2.weight", + "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.bias": "blocks.27.conv2.bias", + "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.weight": "blocks.27.conv2.weight", + "first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.28.conv1.bias", + "first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.28.conv1.weight", + "first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.28.conv2.bias", + "first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.28.conv2.weight", + "first_stage_model.decoder.up.0.block.1.mix_factor": "blocks.29.mix_factor", + "first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.28.norm1.bias", + "first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.28.norm1.weight", + "first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.28.norm2.bias", + "first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.28.norm2.weight", + 
"first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.bias": "blocks.29.norm1.bias", + "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.weight": "blocks.29.norm1.weight", + "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.bias": "blocks.29.conv1.bias", + "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.weight": "blocks.29.conv1.weight", + "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.bias": "blocks.29.norm2.bias", + "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.weight": "blocks.29.norm2.weight", + "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.bias": "blocks.29.conv2.bias", + "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.weight": "blocks.29.conv2.weight", + "first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.30.conv1.bias", + "first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.30.conv1.weight", + "first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.30.conv2.bias", + "first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.30.conv2.weight", + "first_stage_model.decoder.up.0.block.2.mix_factor": "blocks.31.mix_factor", + "first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.30.norm1.bias", + "first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.30.norm1.weight", + "first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.30.norm2.bias", + "first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.30.norm2.weight", + "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.bias": "blocks.31.norm1.bias", + "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.weight": "blocks.31.norm1.weight", + "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.bias": "blocks.31.conv1.bias", + "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.weight": "blocks.31.conv1.weight", + "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.bias": "blocks.31.norm2.bias", + "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.weight": "blocks.31.norm2.weight", + "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.bias": "blocks.31.conv2.bias", + "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.weight": "blocks.31.conv2.weight", + "first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.19.conv1.bias", + "first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.19.conv1.weight", + "first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.19.conv2.bias", + "first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.19.conv2.weight", + "first_stage_model.decoder.up.1.block.0.mix_factor": "blocks.20.mix_factor", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.19.conv_shortcut.bias", + "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.19.conv_shortcut.weight", + "first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.19.norm1.bias", + "first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.19.norm1.weight", + "first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.19.norm2.bias", + "first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.19.norm2.weight", + "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.bias": "blocks.20.norm1.bias", + "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.weight": "blocks.20.norm1.weight", + "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.bias": "blocks.20.conv1.bias", + 
"first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.weight": "blocks.20.conv1.weight", + "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.bias": "blocks.20.norm2.bias", + "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.weight": "blocks.20.norm2.weight", + "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.bias": "blocks.20.conv2.bias", + "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.weight": "blocks.20.conv2.weight", + "first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.21.conv1.bias", + "first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.21.conv1.weight", + "first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.21.conv2.bias", + "first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.21.conv2.weight", + "first_stage_model.decoder.up.1.block.1.mix_factor": "blocks.22.mix_factor", + "first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.21.norm1.bias", + "first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.21.norm1.weight", + "first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.21.norm2.bias", + "first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.21.norm2.weight", + "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.bias": "blocks.22.norm1.bias", + "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.weight": "blocks.22.norm1.weight", + "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.bias": "blocks.22.conv1.bias", + "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.weight": "blocks.22.conv1.weight", + "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.bias": "blocks.22.norm2.bias", + "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.weight": "blocks.22.norm2.weight", + "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.bias": "blocks.22.conv2.bias", + "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.weight": "blocks.22.conv2.weight", + "first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.23.conv1.bias", + "first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.23.conv1.weight", + "first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.23.conv2.bias", + "first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.23.conv2.weight", + "first_stage_model.decoder.up.1.block.2.mix_factor": "blocks.24.mix_factor", + "first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.23.norm1.bias", + "first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.23.norm1.weight", + "first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.23.norm2.bias", + "first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.23.norm2.weight", + "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.bias": "blocks.24.norm1.bias", + "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.weight": "blocks.24.norm1.weight", + "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.bias": "blocks.24.conv1.bias", + "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.weight": "blocks.24.conv1.weight", + "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.bias": "blocks.24.norm2.bias", + "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.weight": "blocks.24.norm2.weight", + "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.bias": "blocks.24.conv2.bias", + "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.weight": "blocks.24.conv2.weight", + 
"first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.25.conv.bias", + "first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.25.conv.weight", + "first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.12.conv1.bias", + "first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.12.conv1.weight", + "first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.12.conv2.bias", + "first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.12.conv2.weight", + "first_stage_model.decoder.up.2.block.0.mix_factor": "blocks.13.mix_factor", + "first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.12.norm1.bias", + "first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.12.norm1.weight", + "first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.12.norm2.bias", + "first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.12.norm2.weight", + "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.bias": "blocks.13.norm1.bias", + "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.weight": "blocks.13.norm1.weight", + "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.bias": "blocks.13.conv1.bias", + "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.weight": "blocks.13.conv1.weight", + "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.bias": "blocks.13.norm2.bias", + "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.weight": "blocks.13.norm2.weight", + "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.bias": "blocks.13.conv2.bias", + "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.weight": "blocks.13.conv2.weight", + "first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.14.conv1.bias", + "first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.14.conv1.weight", + "first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.14.conv2.bias", + "first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.14.conv2.weight", + "first_stage_model.decoder.up.2.block.1.mix_factor": "blocks.15.mix_factor", + "first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.14.norm1.bias", + "first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.14.norm1.weight", + "first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.14.norm2.bias", + "first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.14.norm2.weight", + "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.bias": "blocks.15.norm1.bias", + "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.weight": "blocks.15.norm1.weight", + "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.bias": "blocks.15.conv1.bias", + "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.weight": "blocks.15.conv1.weight", + "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.bias": "blocks.15.norm2.bias", + "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.weight": "blocks.15.norm2.weight", + "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.bias": "blocks.15.conv2.bias", + "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.weight": "blocks.15.conv2.weight", + "first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.16.conv1.bias", + "first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.16.conv1.weight", + "first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.16.conv2.bias", + "first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.16.conv2.weight", + 
"first_stage_model.decoder.up.2.block.2.mix_factor": "blocks.17.mix_factor", + "first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.16.norm1.bias", + "first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.16.norm1.weight", + "first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.16.norm2.bias", + "first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.16.norm2.weight", + "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.bias": "blocks.17.norm1.bias", + "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.weight": "blocks.17.norm1.weight", + "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.bias": "blocks.17.conv1.bias", + "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.weight": "blocks.17.conv1.weight", + "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.bias": "blocks.17.norm2.bias", + "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.weight": "blocks.17.norm2.weight", + "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.bias": "blocks.17.conv2.bias", + "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.weight": "blocks.17.conv2.weight", + "first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.18.conv.bias", + "first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.18.conv.weight", + "first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.5.conv1.bias", + "first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.5.conv1.weight", + "first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.5.conv2.bias", + "first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.5.conv2.weight", + "first_stage_model.decoder.up.3.block.0.mix_factor": "blocks.6.mix_factor", + "first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.5.norm1.bias", + "first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.5.norm1.weight", + "first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.5.norm2.bias", + "first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.5.norm2.weight", + "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.bias": "blocks.6.norm1.bias", + "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.weight": "blocks.6.norm1.weight", + "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.bias": "blocks.6.conv1.bias", + "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.weight": "blocks.6.conv1.weight", + "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.bias": "blocks.6.norm2.bias", + "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.weight": "blocks.6.norm2.weight", + "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.bias": "blocks.6.conv2.bias", + "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.weight": "blocks.6.conv2.weight", + "first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.7.conv1.bias", + "first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.7.conv1.weight", + "first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.7.conv2.bias", + "first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.7.conv2.weight", + "first_stage_model.decoder.up.3.block.1.mix_factor": "blocks.8.mix_factor", + "first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.7.norm1.bias", + "first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.7.norm1.weight", + "first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.7.norm2.bias", + "first_stage_model.decoder.up.3.block.1.norm2.weight": 
"blocks.7.norm2.weight", + "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.bias": "blocks.8.norm1.bias", + "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.weight": "blocks.8.norm1.weight", + "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.bias": "blocks.8.conv1.bias", + "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.weight": "blocks.8.conv1.weight", + "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.bias": "blocks.8.norm2.bias", + "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.weight": "blocks.8.norm2.weight", + "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.bias": "blocks.8.conv2.bias", + "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.weight": "blocks.8.conv2.weight", + "first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.9.conv1.bias", + "first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.9.conv1.weight", + "first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.9.conv2.bias", + "first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.9.conv2.weight", + "first_stage_model.decoder.up.3.block.2.mix_factor": "blocks.10.mix_factor", + "first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.9.norm1.bias", + "first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.9.norm1.weight", + "first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.9.norm2.bias", + "first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.9.norm2.weight", + "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.bias": "blocks.10.norm1.bias", + "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.weight": "blocks.10.norm1.weight", + "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.bias": "blocks.10.conv1.bias", + "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.weight": "blocks.10.conv1.weight", + "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.bias": "blocks.10.norm2.bias", + "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.weight": "blocks.10.norm2.weight", + "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.bias": "blocks.10.conv2.bias", + "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.weight": "blocks.10.conv2.weight", + "first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.11.conv.bias", + "first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.11.conv.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if "blocks.2.transformer_blocks.0" in rename_dict[name]: + param = param.squeeze() + state_dict_[rename_dict[name]] = param + return state_dict_ diff --git a/diffsynth/models/svd_vae_encoder.py b/diffsynth/models/svd_vae_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3e84a59f9638c7b8429757711dc6ddf64d3a35e4 --- /dev/null +++ b/diffsynth/models/svd_vae_encoder.py @@ -0,0 +1,138 @@ +from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder + + +class SVDVAEEncoder(SDVAEEncoder): + def __init__(self): + super().__init__() + self.scaling_factor = 0.13025 + + def state_dict_converter(self): + return SVDVAEEncoderStateDictConverter() + + +class SVDVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter): + def __init__(self): + super().__init__() + + def from_diffusers(self, state_dict): + return super().from_diffusers(state_dict) + + def from_civitai(self, state_dict): + rename_dict = { + 
"conditioner.embedders.3.encoder.encoder.conv_in.bias": "conv_in.bias", + "conditioner.embedders.3.encoder.encoder.conv_in.weight": "conv_in.weight", + "conditioner.embedders.3.encoder.encoder.conv_out.bias": "conv_out.bias", + "conditioner.embedders.3.encoder.encoder.conv_out.weight": "conv_out.weight", + "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias", + "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight", + "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias", + "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight", + "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias", + "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight", + "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias", + "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight", + "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias", + "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight", + "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias", + "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight", + "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias", + "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight", + "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias", + "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight", + "conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias", + "conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight", + "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias", + "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight", + "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias", + "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight", + "conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias", + "conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight", + "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias", + "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight", + "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias", + "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight", + "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias", + "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight", + "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias", + "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight", + "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias", + 
"conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight", + "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias", + "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight", + "conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias", + "conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight", + "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias", + "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight", + "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias", + "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight", + "conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias", + "conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight", + "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias", + "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight", + "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias", + "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight", + "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias", + "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight", + "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias", + "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight", + "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias", + "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight", + "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias", + "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight", + "conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias", + "conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight", + "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias", + "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight", + "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias", + "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight", + "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias", + "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight", + "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias", + "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight", + "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias", + "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight", + "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias", + 
"conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight", + "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias", + "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight", + "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias", + "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight", + "conditioner.embedders.3.encoder.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias", + "conditioner.embedders.3.encoder.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight", + "conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias", + "conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight", + "conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias", + "conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight", + "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias", + "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight", + "conditioner.embedders.3.encoder.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias", + "conditioner.embedders.3.encoder.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight", + "conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias", + "conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight", + "conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias", + "conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight", + "conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias", + "conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight", + "conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias", + "conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight", + "conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias", + "conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight", + "conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias", + "conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight", + "conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias", + "conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight", + "conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias", + "conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight", + "conditioner.embedders.3.encoder.encoder.norm_out.bias": "conv_norm_out.bias", + "conditioner.embedders.3.encoder.encoder.norm_out.weight": "conv_norm_out.weight", + "conditioner.embedders.3.encoder.quant_conv.bias": "quant_conv.bias", + "conditioner.embedders.3.encoder.quant_conv.weight": "quant_conv.weight", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if "transformer_blocks" in rename_dict[name]: + param = param.squeeze() + 
state_dict_[rename_dict[name]] = param + return state_dict_ diff --git a/diffsynth/models/tiler.py b/diffsynth/models/tiler.py new file mode 100644 index 0000000000000000000000000000000000000000..af37ff62a0fe08e0822ed26a50420d5dced4b79c --- /dev/null +++ b/diffsynth/models/tiler.py @@ -0,0 +1,106 @@ +import torch +from einops import rearrange, repeat + + +class TileWorker: + def __init__(self): + pass + + + def mask(self, height, width, border_width): + # Create a mask with shape (height, width). + # The centre area is filled with 1, and the border line is filled with values in range (0, 1]. + x = torch.arange(height).repeat(width, 1).T + y = torch.arange(width).repeat(height, 1) + mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values + mask = (mask / border_width).clip(0, 1) + return mask + + + def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype): + # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num) + batch_size, channel, _, _ = model_input.shape + model_input = model_input.to(device=tile_device, dtype=tile_dtype) + unfold_operator = torch.nn.Unfold( + kernel_size=(tile_size, tile_size), + stride=(tile_stride, tile_stride) + ) + model_input = unfold_operator(model_input) + model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1)) + + return model_input + + + def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype): + # Call y=forward_fn(x) for each tile + tile_num = model_input.shape[-1] + model_output_stack = [] + + for tile_id in range(0, tile_num, tile_batch_size): + + # process input + tile_id_ = min(tile_id + tile_batch_size, tile_num) + x = model_input[:, :, :, :, tile_id: tile_id_] + x = x.to(device=inference_device, dtype=inference_dtype) + x = rearrange(x, "b c h w n -> (n b) c h w") + + # process output + y = forward_fn(x) + y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id) + y = y.to(device=tile_device, dtype=tile_dtype) + model_output_stack.append(y) + + model_output = torch.concat(model_output_stack, dim=-1) + return model_output + + + def io_scale(self, model_output, tile_size): + # Determine the size modification happend in forward_fn + # We only consider the same scale on height and width. 
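+        # e.g. a VAE decoder that upsamples 8x gives io_scale == 8, while an identity forward_fn gives io_scale == 1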
+ io_scale = model_output.shape[2] / tile_size + return io_scale + + + def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype): + # The reversed function of tile + mask = self.mask(tile_size, tile_size, border_width) + mask = mask.to(device=tile_device, dtype=tile_dtype) + mask = rearrange(mask, "h w -> 1 1 h w 1") + model_output = model_output * mask + + fold_operator = torch.nn.Fold( + output_size=(height, width), + kernel_size=(tile_size, tile_size), + stride=(tile_stride, tile_stride) + ) + mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1]) + model_output = rearrange(model_output, "b c h w n -> b (c h w) n") + model_output = fold_operator(model_output) / fold_operator(mask) + + return model_output + + + def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None): + # Prepare + inference_device, inference_dtype = model_input.device, model_input.dtype + height, width = model_input.shape[2], model_input.shape[3] + border_width = int(tile_stride*0.5) if border_width is None else border_width + + # tile + model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype) + + # inference + model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype) + + # resize + io_scale = self.io_scale(model_output, tile_size) + height, width = int(height*io_scale), int(width*io_scale) + tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale) + border_width = int(border_width*io_scale) + + # untile + model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype) + + # Done! + model_output = model_output.to(device=inference_device, dtype=inference_dtype) + return model_output \ No newline at end of file diff --git a/diffsynth/pipelines/__init__.py b/diffsynth/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9482fe3bf5e774d0886728d4c448b0d2a636f825 --- /dev/null +++ b/diffsynth/pipelines/__init__.py @@ -0,0 +1,6 @@ +from .stable_diffusion import SDImagePipeline +from .stable_diffusion_xl import SDXLImagePipeline +from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner +from .stable_diffusion_xl_video import SDXLVideoPipeline +from .stable_video_diffusion import SVDVideoPipeline +from .hunyuan_dit import HunyuanDiTImagePipeline diff --git a/diffsynth/pipelines/dancer.py b/diffsynth/pipelines/dancer.py new file mode 100644 index 0000000000000000000000000000000000000000..d01a7f3f2b614de790687479e29788f02748cea5 --- /dev/null +++ b/diffsynth/pipelines/dancer.py @@ -0,0 +1,174 @@ +import torch +from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel +from ..models.sd_unet import PushBlock, PopBlock +from ..controlnets import MultiControlNetManager + + +def lets_dance( + unet: SDUNet, + motion_modules: SDMotionModel = None, + controlnet: MultiControlNetManager = None, + sample = None, + timestep = None, + encoder_hidden_states = None, + ipadapter_kwargs_list = {}, + controlnet_frames = None, + unet_batch_size = 1, + controlnet_batch_size = 1, + cross_frame_attention = False, + tiled=False, + tile_size=64, + tile_stride=32, + device = "cuda", + vram_limit_level = 0, +): + # 1. ControlNet + # This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride. 
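+    # (Each ControlNet unit's residual stack is scaled and summed by MultiControlNetManager before being
+    #  injected into the UNet at controlnet_insert_block_id below.)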
+ # I leave it here because I intend to do something interesting on the ControlNets. + controlnet_insert_block_id = 30 + if controlnet is not None and controlnet_frames is not None: + res_stacks = [] + # process controlnet frames with batch + for batch_id in range(0, sample.shape[0], controlnet_batch_size): + batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0]) + res_stack = controlnet( + sample[batch_id: batch_id_], + timestep, + encoder_hidden_states[batch_id: batch_id_], + controlnet_frames[:, batch_id: batch_id_], + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + ) + if vram_limit_level >= 1: + res_stack = [res.cpu() for res in res_stack] + res_stacks.append(res_stack) + # concat the residual + additional_res_stack = [] + for i in range(len(res_stacks[0])): + res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0) + additional_res_stack.append(res) + else: + additional_res_stack = None + + # 2. time + time_emb = unet.time_proj(timestep[None]).to(sample.dtype) + time_emb = unet.time_embedding(time_emb) + + # 3. pre-process + height, width = sample.shape[2], sample.shape[3] + hidden_states = unet.conv_in(sample) + text_emb = encoder_hidden_states + res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states] + + # 4. blocks + for block_id, block in enumerate(unet.blocks): + # 4.1 UNet + if isinstance(block, PushBlock): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + if vram_limit_level>=1: + res_stack[-1] = res_stack[-1].cpu() + elif isinstance(block, PopBlock): + if vram_limit_level>=1: + res_stack[-1] = res_stack[-1].to(device) + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + else: + hidden_states_input = hidden_states + hidden_states_output = [] + for batch_id in range(0, sample.shape[0], unet_batch_size): + batch_id_ = min(batch_id + unet_batch_size, sample.shape[0]) + hidden_states, _, _, _ = block( + hidden_states_input[batch_id: batch_id_], + time_emb, + text_emb[batch_id: batch_id_], + res_stack, + cross_frame_attention=cross_frame_attention, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}), + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + ) + hidden_states_output.append(hidden_states) + hidden_states = torch.concat(hidden_states_output, dim=0) + # 4.2 AnimateDiff + if motion_modules is not None: + if block_id in motion_modules.call_block_id: + motion_module_id = motion_modules.call_block_id[block_id] + hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id]( + hidden_states, time_emb, text_emb, res_stack, + batch_size=1 + ) + # 4.3 ControlNet + if block_id == controlnet_insert_block_id and additional_res_stack is not None: + hidden_states += additional_res_stack.pop().to(device) + if vram_limit_level>=1: + res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)] + else: + res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)] + + # 5. 
output + hidden_states = unet.conv_norm_out(hidden_states) + hidden_states = unet.conv_act(hidden_states) + hidden_states = unet.conv_out(hidden_states) + + return hidden_states + + + + +def lets_dance_xl( + unet: SDXLUNet, + motion_modules: SDXLMotionModel = None, + controlnet: MultiControlNetManager = None, + sample = None, + add_time_id = None, + add_text_embeds = None, + timestep = None, + encoder_hidden_states = None, + ipadapter_kwargs_list = {}, + controlnet_frames = None, + unet_batch_size = 1, + controlnet_batch_size = 1, + cross_frame_attention = False, + tiled=False, + tile_size=64, + tile_stride=32, + device = "cuda", + vram_limit_level = 0, +): + # 2. time + t_emb = unet.time_proj(timestep[None]).to(sample.dtype) + t_emb = unet.time_embedding(t_emb) + + time_embeds = unet.add_time_proj(add_time_id) + time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1)) + add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(sample.dtype) + add_embeds = unet.add_time_embedding(add_embeds) + + time_emb = t_emb + add_embeds + + # 3. pre-process + height, width = sample.shape[2], sample.shape[3] + hidden_states = unet.conv_in(sample) + text_emb = encoder_hidden_states + res_stack = [hidden_states] + + # 4. blocks + for block_id, block in enumerate(unet.blocks): + hidden_states, time_emb, text_emb, res_stack = block( + hidden_states, time_emb, text_emb, res_stack, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}) + ) + # 4.2 AnimateDiff + if motion_modules is not None: + if block_id in motion_modules.call_block_id: + motion_module_id = motion_modules.call_block_id[block_id] + hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id]( + hidden_states, time_emb, text_emb, res_stack, + batch_size=1 + ) + + # 5. 
output
+    hidden_states = unet.conv_norm_out(hidden_states)
+    hidden_states = unet.conv_act(hidden_states)
+    hidden_states = unet.conv_out(hidden_states)
+
+    return hidden_states
\ No newline at end of file
diff --git a/diffsynth/pipelines/hunyuan_dit.py b/diffsynth/pipelines/hunyuan_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d592e4a97ab433980dd64b4d73bc8349183639a
--- /dev/null
+++ b/diffsynth/pipelines/hunyuan_dit.py
@@ -0,0 +1,298 @@
+from ..models.hunyuan_dit import HunyuanDiT
+from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
+from ..models.sdxl_vae_encoder import SDXLVAEEncoder
+from ..models.sdxl_vae_decoder import SDXLVAEDecoder
+from ..models import ModelManager
+from ..prompts import HunyuanDiTPrompter
+from ..schedulers import EnhancedDDIMScheduler
+import torch
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+
+
+
+class ImageSizeManager:
+    def __init__(self):
+        pass
+
+
+    def _to_tuple(self, x):
+        if isinstance(x, int):
+            return x, x
+        else:
+            return x
+
+
+    def get_fill_resize_and_crop(self, src, tgt):
+        th, tw = self._to_tuple(tgt)
+        h, w = self._to_tuple(src)
+
+        tr = th / tw  # aspect ratio of the base resolution
+        r = h / w  # aspect ratio of the target resolution
+
+        # resize
+        if r > tr:
+            resize_height = th
+            resize_width = int(round(th / h * w))
+        else:
+            resize_width = tw
+            resize_height = int(round(tw / w * h))  # scale the target resolution down to fit the base resolution
+
+        crop_top = int(round((th - resize_height) / 2.0))
+        crop_left = int(round((tw - resize_width) / 2.0))
+
+        return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+    def get_meshgrid(self, start, *args):
+        if len(args) == 0:
+            # start is grid_size
+            num = self._to_tuple(start)
+            start = (0, 0)
+            stop = num
+        elif len(args) == 1:
+            # start is start, args[0] is stop, step is 1
+            start = self._to_tuple(start)
+            stop = self._to_tuple(args[0])
+            num = (stop[0] - start[0], stop[1] - start[1])
+        elif len(args) == 2:
+            # start is start, args[0] is stop, args[1] is num
+            start = self._to_tuple(start)  # top-left corner, e.g. (12, 0)
+            stop = self._to_tuple(args[0])  # bottom-right corner, e.g. (20, 32)
+            num = self._to_tuple(args[1])  # target size, e.g. (32, 124)
+        else:
+            raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
+
+        grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)  # e.g. 32 evenly spaced values over 12-20, or 124 values over 0-32
+        grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
+        grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+        grid = np.stack(grid, axis=0)  # [2, W, H]
+        return grid
+
+
+    def get_2d_rotary_pos_embed(self, embed_dim, start, *args, use_real=True):
+        grid = self.get_meshgrid(start, *args)  # [2, H, w]
+        grid = grid.reshape([2, 1, *grid.shape[1:]])  # a sampling grid whose resolution matches the target resolution
+        pos_embed = self.get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+        return pos_embed
+
+
+    def get_2d_rotary_pos_embed_from_grid(self, embed_dim, grid, use_real=False):
+        assert embed_dim % 4 == 0
+
+        # use half of dimensions to encode grid_h
+        emb_h = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
+        emb_w = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)
+
+        if use_real:
+            cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
+            sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
+            return cos, sin
+        else:
+            emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
+            return emb
+
+
+    def get_1d_rotary_pos_embed(self, dim: int, pos, theta: float =
10000.0, use_real=False): + if isinstance(pos, int): + pos = np.arange(pos) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] + t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] + freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + + def calc_rope(self, height, width): + patch_size = 2 + head_size = 88 + th = height // 8 // patch_size + tw = width // 8 // patch_size + base_size = 512 // 8 // patch_size + start, stop = self.get_fill_resize_and_crop((th, tw), base_size) + sub_args = [start, stop, (th, tw)] + rope = self.get_2d_rotary_pos_embed(head_size, *sub_args) + return rope + + + +class HunyuanDiTImagePipeline(torch.nn.Module): + + def __init__(self, device="cuda", torch_dtype=torch.float16): + super().__init__() + self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03) + self.prompter = HunyuanDiTPrompter() + self.device = device + self.torch_dtype = torch_dtype + self.image_size_manager = ImageSizeManager() + # models + self.text_encoder: HunyuanDiTCLIPTextEncoder = None + self.text_encoder_t5: HunyuanDiTT5TextEncoder = None + self.dit: HunyuanDiT = None + self.vae_decoder: SDXLVAEDecoder = None + self.vae_encoder: SDXLVAEEncoder = None + + + def fetch_main_models(self, model_manager: ModelManager): + self.text_encoder = model_manager.hunyuan_dit_clip_text_encoder + self.text_encoder_t5 = model_manager.hunyuan_dit_t5_text_encoder + self.dit = model_manager.hunyuan_dit + self.vae_decoder = model_manager.vae_decoder + self.vae_encoder = model_manager.vae_encoder + + + def fetch_prompter(self, model_manager: ModelManager): + self.prompter.load_from_model_manager(model_manager) + + + @staticmethod + def from_model_manager(model_manager: ModelManager): + pipe = HunyuanDiTImagePipeline( + device=model_manager.device, + torch_dtype=model_manager.torch_dtype, + ) + pipe.fetch_main_models(model_manager) + pipe.fetch_prompter(model_manager) + return pipe + + + def preprocess_image(self, image): + image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) + return image + + + def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): + image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + image = image.cpu().permute(1, 2, 0).numpy() + image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) + return image + + + def prepare_extra_input(self, height=1024, width=1024, tiled=False, tile_size=64, tile_stride=32, batch_size=1): + if tiled: + height, width = tile_size * 16, tile_size * 16 + image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device) + freqs_cis_img = self.image_size_manager.calc_rope(height, width) + image_meta_size = torch.stack([image_meta_size] * batch_size) + return { + "size_emb": image_meta_size, + "freq_cis_img": (freqs_cis_img[0].to(dtype=self.torch_dtype, device=self.device), freqs_cis_img[1].to(dtype=self.torch_dtype, device=self.device)), + "tiled": tiled, + "tile_size": tile_size, + "tile_stride": tile_stride + } + + + @torch.no_grad() + def __call__( + self, + prompt, + negative_prompt="", + cfg_scale=7.5, + 
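+        # cfg_scale is the classifier-free guidance weight; setting it to 1.0 skips the negative-prompt pass below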
clip_skip=1, + clip_skip_2=1, + input_image=None, + reference_images=[], + reference_strengths=[0.4], + denoising_strength=1.0, + height=1024, + width=1024, + num_inference_steps=20, + tiled=False, + tile_size=64, + tile_stride=32, + progress_bar_cmd=tqdm, + progress_bar_st=None, + ): + # Prepare scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength) + + # Prepare latent tensors + noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) + if input_image is not None: + image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) + latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype) + latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) + else: + latents = noise.clone() + + # Prepare reference latents + reference_latents = [] + for reference_image in reference_images: + reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype) + reference_latents.append(self.vae_encoder(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)) + + # Encode prompts + prompt_emb_posi, attention_mask_posi, prompt_emb_t5_posi, attention_mask_t5_posi = self.prompter.encode_prompt( + self.text_encoder, + self.text_encoder_t5, + prompt, + clip_skip=clip_skip, + clip_skip_2=clip_skip_2, + positive=True, + device=self.device + ) + if cfg_scale != 1.0: + prompt_emb_nega, attention_mask_nega, prompt_emb_t5_nega, attention_mask_t5_nega = self.prompter.encode_prompt( + self.text_encoder, + self.text_encoder_t5, + negative_prompt, + clip_skip=clip_skip, + clip_skip_2=clip_skip_2, + positive=False, + device=self.device + ) + + # Prepare positional id + extra_input = self.prepare_extra_input(height, width, tiled, tile_size) + + # Denoise + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device) + + # In-context reference + for reference_latents_, reference_strength in zip(reference_latents, reference_strengths): + if progress_id < num_inference_steps * reference_strength: + noisy_reference_latents = self.scheduler.add_noise(reference_latents_, noise, self.scheduler.timesteps[progress_id]) + self.dit( + noisy_reference_latents, + prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi, + timestep, + **extra_input, + to_cache=True + ) + # Positive side + noise_pred_posi = self.dit( + latents, + prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi, + timestep, + **extra_input, + ) + if cfg_scale != 1.0: + # Negative side + noise_pred_nega = self.dit( + latents, + prompt_emb_nega, prompt_emb_t5_nega, attention_mask_nega, attention_mask_t5_nega, + timestep, + **extra_input + ) + # Classifier-free guidance + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) + + if progress_bar_st is not None: + progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) + + # Decode image + image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + + return image diff --git a/diffsynth/pipelines/stable_diffusion.py b/diffsynth/pipelines/stable_diffusion.py new file mode 100644 index 
0000000000000000000000000000000000000000..4f71f89d3dbab3144dc516e7daadb0519196b169 --- /dev/null +++ b/diffsynth/pipelines/stable_diffusion.py @@ -0,0 +1,167 @@ +from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder +from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator +from ..prompts import SDPrompter +from ..schedulers import EnhancedDDIMScheduler +from .dancer import lets_dance +from typing import List +import torch +from tqdm import tqdm +from PIL import Image +import numpy as np + + +class SDImagePipeline(torch.nn.Module): + + def __init__(self, device="cuda", torch_dtype=torch.float16): + super().__init__() + self.scheduler = EnhancedDDIMScheduler() + self.prompter = SDPrompter() + self.device = device + self.torch_dtype = torch_dtype + # models + self.text_encoder: SDTextEncoder = None + self.unet: SDUNet = None + self.vae_decoder: SDVAEDecoder = None + self.vae_encoder: SDVAEEncoder = None + self.controlnet: MultiControlNetManager = None + self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None + self.ipadapter: SDIpAdapter = None + + + def fetch_main_models(self, model_manager: ModelManager): + self.text_encoder = model_manager.text_encoder + self.unet = model_manager.unet + self.vae_decoder = model_manager.vae_decoder + self.vae_encoder = model_manager.vae_encoder + + + def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]): + controlnet_units = [] + for config in controlnet_config_units: + controlnet_unit = ControlNetUnit( + Annotator(config.processor_id), + model_manager.get_model_with_model_path(config.model_path), + config.scale + ) + controlnet_units.append(controlnet_unit) + self.controlnet = MultiControlNetManager(controlnet_units) + + + def fetch_ipadapter(self, model_manager: ModelManager): + if "ipadapter" in model_manager.model: + self.ipadapter = model_manager.ipadapter + if "ipadapter_image_encoder" in model_manager.model: + self.ipadapter_image_encoder = model_manager.ipadapter_image_encoder + + + def fetch_prompter(self, model_manager: ModelManager): + self.prompter.load_from_model_manager(model_manager) + + + @staticmethod + def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]): + pipe = SDImagePipeline( + device=model_manager.device, + torch_dtype=model_manager.torch_dtype, + ) + pipe.fetch_main_models(model_manager) + pipe.fetch_prompter(model_manager) + pipe.fetch_controlnet_models(model_manager, controlnet_config_units) + pipe.fetch_ipadapter(model_manager) + return pipe + + + def preprocess_image(self, image): + image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) + return image + + + def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): + image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + image = image.cpu().permute(1, 2, 0).numpy() + image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) + return image + + + @torch.no_grad() + def __call__( + self, + prompt, + negative_prompt="", + cfg_scale=7.5, + clip_skip=1, + input_image=None, + ipadapter_images=None, + ipadapter_scale=1.0, + controlnet_image=None, + denoising_strength=1.0, + height=512, + width=512, + num_inference_steps=20, + tiled=False, + tile_size=64, + tile_stride=32, + progress_bar_cmd=tqdm, + 
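+        # tiled / tile_size / tile_stride enable tiled VAE and UNet inference for large images;
+        # progress_bar_cmd wraps the denoising loop (tqdm by default) and progress_bar_st is an optional Streamlit progress bar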
progress_bar_st=None, + ): + # Prepare scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength) + + # Prepare latent tensors + if input_image is not None: + image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) + latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) + latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) + else: + latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) + + # Encode prompts + prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True) + prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False) + + # IP-Adapter + if ipadapter_images is not None: + ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images) + ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale) + ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding)) + else: + ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {} + + # Prepare ControlNets + if controlnet_image is not None: + controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype) + controlnet_image = controlnet_image.unsqueeze(1) + + # Denoise + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = torch.IntTensor((timestep,))[0].to(self.device) + + # Classifier-free guidance + noise_pred_posi = lets_dance( + self.unet, motion_modules=None, controlnet=self.controlnet, + sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, + ipadapter_kwargs_list=ipadapter_kwargs_list_posi, + device=self.device, vram_limit_level=0 + ) + noise_pred_nega = lets_dance( + self.unet, motion_modules=None, controlnet=self.controlnet, + sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, + ipadapter_kwargs_list=ipadapter_kwargs_list_nega, + device=self.device, vram_limit_level=0 + ) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + + # DDIM + latents = self.scheduler.step(noise_pred, timestep, latents) + + # UI + if progress_bar_st is not None: + progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) + + # Decode image + image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + + return image diff --git a/diffsynth/pipelines/stable_diffusion_video.py b/diffsynth/pipelines/stable_diffusion_video.py new file mode 100644 index 0000000000000000000000000000000000000000..b204cad43cbe1feafe2ba927062e3a974e3bae9e --- /dev/null +++ b/diffsynth/pipelines/stable_diffusion_video.py @@ -0,0 +1,356 @@ +from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel +from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator +from ..prompts import SDPrompter +from ..schedulers import EnhancedDDIMScheduler +from ..data import VideoData, save_frames, save_video +from .dancer import 
lets_dance +from ..processors.sequencial_processor import SequencialProcessor +from typing import List +import torch, os, json +from tqdm import tqdm +from PIL import Image +import numpy as np + + +def lets_dance_with_long_video( + unet: SDUNet, + motion_modules: SDMotionModel = None, + controlnet: MultiControlNetManager = None, + sample = None, + timestep = None, + encoder_hidden_states = None, + controlnet_frames = None, + animatediff_batch_size = 16, + animatediff_stride = 8, + unet_batch_size = 1, + controlnet_batch_size = 1, + cross_frame_attention = False, + device = "cuda", + vram_limit_level = 0, +): + num_frames = sample.shape[0] + hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)] + + for batch_id in range(0, num_frames, animatediff_stride): + batch_id_ = min(batch_id + animatediff_batch_size, num_frames) + + # process this batch + hidden_states_batch = lets_dance( + unet, motion_modules, controlnet, + sample[batch_id: batch_id_].to(device), + timestep, + encoder_hidden_states[batch_id: batch_id_].to(device), + controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None, + unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size, + cross_frame_attention=cross_frame_attention, + device=device, vram_limit_level=vram_limit_level + ).cpu() + + # update hidden_states + for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch): + bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2) + hidden_states, num = hidden_states_output[i] + hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias)) + hidden_states_output[i] = (hidden_states, num + bias) + + if batch_id_ == num_frames: + break + + # output + hidden_states = torch.stack([h for h, _ in hidden_states_output]) + return hidden_states + + +class SDVideoPipeline(torch.nn.Module): + + def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True): + super().__init__() + self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear") + self.prompter = SDPrompter() + self.device = device + self.torch_dtype = torch_dtype + # models + self.text_encoder: SDTextEncoder = None + self.unet: SDUNet = None + self.vae_decoder: SDVAEDecoder = None + self.vae_encoder: SDVAEEncoder = None + self.controlnet: MultiControlNetManager = None + self.motion_modules: SDMotionModel = None + + + def fetch_main_models(self, model_manager: ModelManager): + self.text_encoder = model_manager.text_encoder + self.unet = model_manager.unet + self.vae_decoder = model_manager.vae_decoder + self.vae_encoder = model_manager.vae_encoder + + + def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]): + controlnet_units = [] + for config in controlnet_config_units: + controlnet_unit = ControlNetUnit( + Annotator(config.processor_id), + model_manager.get_model_with_model_path(config.model_path), + config.scale + ) + controlnet_units.append(controlnet_unit) + self.controlnet = MultiControlNetManager(controlnet_units) + + + def fetch_motion_modules(self, model_manager: ModelManager): + if "motion_modules" in model_manager.model: + self.motion_modules = model_manager.motion_modules + + + def fetch_prompter(self, model_manager: ModelManager): + self.prompter.load_from_model_manager(model_manager) + + + 
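+    # Typical usage (sketch): load the checkpoints into a ModelManager and build the pipeline with the constructor below:
+    #     pipe = SDVideoPipeline.from_model_manager(model_manager)
+    #     frames = pipe(prompt="...", num_frames=16, num_inference_steps=20)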
@staticmethod + def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]): + pipe = SDVideoPipeline( + device=model_manager.device, + torch_dtype=model_manager.torch_dtype, + use_animatediff="motion_modules" in model_manager.model + ) + pipe.fetch_main_models(model_manager) + pipe.fetch_motion_modules(model_manager) + pipe.fetch_prompter(model_manager) + pipe.fetch_controlnet_models(model_manager, controlnet_config_units) + return pipe + + + def preprocess_image(self, image): + image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) + return image + + + def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): + image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + image = image.cpu().permute(1, 2, 0).numpy() + image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) + return image + + + def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32): + images = [ + self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + for frame_id in range(latents.shape[0]) + ] + return images + + + def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32): + latents = [] + for image in processed_images: + image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) + latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu() + latents.append(latent) + latents = torch.concat(latents, dim=0) + return latents + + + @torch.no_grad() + def __call__( + self, + prompt, + negative_prompt="", + cfg_scale=7.5, + clip_skip=1, + num_frames=None, + input_frames=None, + controlnet_frames=None, + denoising_strength=1.0, + height=512, + width=512, + num_inference_steps=20, + animatediff_batch_size = 16, + animatediff_stride = 8, + unet_batch_size = 1, + controlnet_batch_size = 1, + cross_frame_attention = False, + smoother=None, + smoother_progress_ids=[], + vram_limit_level=0, + progress_bar_cmd=tqdm, + progress_bar_st=None, + ): + # Prepare scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength) + + # Prepare latent tensors + if self.motion_modules is None: + noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1) + else: + noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype) + if input_frames is None or denoising_strength == 1.0: + latents = noise + else: + latents = self.encode_images(input_frames) + latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) + + # Encode prompts + prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True).cpu() + prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False).cpu() + prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1) + prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1) + + # Prepare ControlNets + if controlnet_frames is not None: + if isinstance(controlnet_frames[0], list): + controlnet_frames_ = [] + for processor_id in range(len(controlnet_frames)): + controlnet_frames_.append( + torch.stack([ + self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype) + for 
controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id]) + ], dim=1) + ) + controlnet_frames = torch.concat(controlnet_frames_, dim=0) + else: + controlnet_frames = torch.stack([ + self.controlnet.process_image(controlnet_frame).to(self.torch_dtype) + for controlnet_frame in progress_bar_cmd(controlnet_frames) + ], dim=1) + + # Denoise + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = torch.IntTensor((timestep,))[0].to(self.device) + + # Classifier-free guidance + noise_pred_posi = lets_dance_with_long_video( + self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet, + sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames, + animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride, + unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size, + cross_frame_attention=cross_frame_attention, + device=self.device, vram_limit_level=vram_limit_level + ) + noise_pred_nega = lets_dance_with_long_video( + self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet, + sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames, + animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride, + unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size, + cross_frame_attention=cross_frame_attention, + device=self.device, vram_limit_level=vram_limit_level + ) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + + # DDIM and smoother + if smoother is not None and progress_id in smoother_progress_ids: + rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True) + rendered_frames = self.decode_images(rendered_frames) + rendered_frames = smoother(rendered_frames, original_frames=input_frames) + target_latents = self.encode_images(rendered_frames) + noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents) + latents = self.scheduler.step(noise_pred, timestep, latents) + + # UI + if progress_bar_st is not None: + progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) + + # Decode image + output_frames = self.decode_images(latents) + + # Post-process + if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids): + output_frames = smoother(output_frames, original_frames=input_frames) + + return output_frames + + + +class SDVideoPipelineRunner: + def __init__(self, in_streamlit=False): + self.in_streamlit = in_streamlit + + + def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units): + # Load models + model_manager = ModelManager(torch_dtype=torch.float16, device=device) + model_manager.load_textual_inversions(textual_inversion_folder) + model_manager.load_models(model_list, lora_alphas=lora_alphas) + pipe = SDVideoPipeline.from_model_manager( + model_manager, + [ + ControlNetConfigUnit( + processor_id=unit["processor_id"], + model_path=unit["model_path"], + scale=unit["scale"] + ) for unit in controlnet_units + ] + ) + return model_manager, pipe + + + def load_smoother(self, model_manager, smoother_configs): + smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs) + return smoother + + + def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs): + torch.manual_seed(seed) + if self.in_streamlit: + 
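+            # Streamlit is imported lazily here so the runner also works outside a Streamlit app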
import streamlit as st + progress_bar_st = st.progress(0.0) + output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st) + progress_bar_st.progress(1.0) + else: + output_video = pipe(**pipeline_inputs, smoother=smoother) + model_manager.to("cpu") + return output_video + + + def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id): + video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width) + if start_frame_id is None: + start_frame_id = 0 + if end_frame_id is None: + end_frame_id = len(video) + frames = [video[i] for i in range(start_frame_id, end_frame_id)] + return frames + + + def add_data_to_pipeline_inputs(self, data, pipeline_inputs): + pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"]) + pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"]) + pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size + if len(data["controlnet_frames"]) > 0: + pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]] + return pipeline_inputs + + + def save_output(self, video, output_folder, fps, config): + os.makedirs(output_folder, exist_ok=True) + save_frames(video, os.path.join(output_folder, "frames")) + save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps) + config["pipeline"]["pipeline_inputs"]["input_frames"] = [] + config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = [] + with open(os.path.join(output_folder, "config.json"), 'w') as file: + json.dump(config, file, indent=4) + + + def run(self, config): + if self.in_streamlit: + import streamlit as st + if self.in_streamlit: st.markdown("Loading videos ...") + config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"]) + if self.in_streamlit: st.markdown("Loading videos ... done!") + if self.in_streamlit: st.markdown("Loading models ...") + model_manager, pipe = self.load_pipeline(**config["models"]) + if self.in_streamlit: st.markdown("Loading models ... done!") + if "smoother_configs" in config: + if self.in_streamlit: st.markdown("Loading smoother ...") + smoother = self.load_smoother(model_manager, config["smoother_configs"]) + if self.in_streamlit: st.markdown("Loading smoother ... done!") + else: + smoother = None + if self.in_streamlit: st.markdown("Synthesizing videos ...") + output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"]) + if self.in_streamlit: st.markdown("Synthesizing videos ... done!") + if self.in_streamlit: st.markdown("Saving videos ...") + self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config) + if self.in_streamlit: st.markdown("Saving videos ... 
done!") + if self.in_streamlit: st.markdown("Finished!") + video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb') + if self.in_streamlit: st.video(video_file.read()) diff --git a/diffsynth/pipelines/stable_diffusion_xl.py b/diffsynth/pipelines/stable_diffusion_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..cf43fb3ae164564c1cf4d4406e93cd847655afe0 --- /dev/null +++ b/diffsynth/pipelines/stable_diffusion_xl.py @@ -0,0 +1,175 @@ +from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder +# TODO: SDXL ControlNet +from ..prompts import SDXLPrompter +from ..schedulers import EnhancedDDIMScheduler +from .dancer import lets_dance_xl +import torch +from tqdm import tqdm +from PIL import Image +import numpy as np + + +class SDXLImagePipeline(torch.nn.Module): + + def __init__(self, device="cuda", torch_dtype=torch.float16): + super().__init__() + self.scheduler = EnhancedDDIMScheduler() + self.prompter = SDXLPrompter() + self.device = device + self.torch_dtype = torch_dtype + # models + self.text_encoder: SDXLTextEncoder = None + self.text_encoder_2: SDXLTextEncoder2 = None + self.unet: SDXLUNet = None + self.vae_decoder: SDXLVAEDecoder = None + self.vae_encoder: SDXLVAEEncoder = None + self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None + self.ipadapter: SDXLIpAdapter = None + # TODO: SDXL ControlNet + + def fetch_main_models(self, model_manager: ModelManager): + self.text_encoder = model_manager.text_encoder + self.text_encoder_2 = model_manager.text_encoder_2 + self.unet = model_manager.unet + self.vae_decoder = model_manager.vae_decoder + self.vae_encoder = model_manager.vae_encoder + + + def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs): + # TODO: SDXL ControlNet + pass + + + def fetch_ipadapter(self, model_manager: ModelManager): + if "ipadapter_xl" in model_manager.model: + self.ipadapter = model_manager.ipadapter_xl + if "ipadapter_xl_image_encoder" in model_manager.model: + self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder + + + def fetch_prompter(self, model_manager: ModelManager): + self.prompter.load_from_model_manager(model_manager) + + + @staticmethod + def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs): + pipe = SDXLImagePipeline( + device=model_manager.device, + torch_dtype=model_manager.torch_dtype, + ) + pipe.fetch_main_models(model_manager) + pipe.fetch_prompter(model_manager) + pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units) + pipe.fetch_ipadapter(model_manager) + return pipe + + + def preprocess_image(self, image): + image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) + return image + + + def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): + image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + image = image.cpu().permute(1, 2, 0).numpy() + image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) + return image + + + @torch.no_grad() + def __call__( + self, + prompt, + negative_prompt="", + cfg_scale=7.5, + clip_skip=1, + clip_skip_2=2, + input_image=None, + ipadapter_images=None, + ipadapter_scale=1.0, + controlnet_image=None, + denoising_strength=1.0, + height=1024, + width=1024, + num_inference_steps=20, + 
tiled=False, + tile_size=64, + tile_stride=32, + progress_bar_cmd=tqdm, + progress_bar_st=None, + ): + # Prepare scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength) + + # Prepare latent tensors + if input_image is not None: + image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) + latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype) + noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) + latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) + else: + latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) + + # Encode prompts + add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt( + self.text_encoder, + self.text_encoder_2, + prompt, + clip_skip=clip_skip, clip_skip_2=clip_skip_2, + device=self.device, + positive=True, + ) + if cfg_scale != 1.0: + add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt( + self.text_encoder, + self.text_encoder_2, + negative_prompt, + clip_skip=clip_skip, clip_skip_2=clip_skip_2, + device=self.device, + positive=False, + ) + + # Prepare positional id + add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device) + + # IP-Adapter + if ipadapter_images is not None: + ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images) + ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale) + ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding)) + else: + ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {} + + # Denoise + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = torch.IntTensor((timestep,))[0].to(self.device) + + # Classifier-free guidance + noise_pred_posi = lets_dance_xl( + self.unet, + sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, + add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, + ipadapter_kwargs_list=ipadapter_kwargs_list_posi, + ) + if cfg_scale != 1.0: + noise_pred_nega = lets_dance_xl( + self.unet, + sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, + add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, + ipadapter_kwargs_list=ipadapter_kwargs_list_nega, + ) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + latents = self.scheduler.step(noise_pred, timestep, latents) + + if progress_bar_st is not None: + progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) + + # Decode image + image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + + return image diff --git a/diffsynth/pipelines/stable_diffusion_xl_video.py b/diffsynth/pipelines/stable_diffusion_xl_video.py new file mode 100644 index 0000000000000000000000000000000000000000..ceb4c674cb9577221c8b27529c886acc73f6136c --- /dev/null +++ b/diffsynth/pipelines/stable_diffusion_xl_video.py @@ -0,0 +1,190 @@ +from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLMotionModel +from .dancer import lets_dance_xl +# TODO: SDXL ControlNet +from ..prompts import 
SDXLPrompter +from ..schedulers import EnhancedDDIMScheduler +import torch +from tqdm import tqdm +from PIL import Image +import numpy as np + + +class SDXLVideoPipeline(torch.nn.Module): + + def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True): + super().__init__() + self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear") + self.prompter = SDXLPrompter() + self.device = device + self.torch_dtype = torch_dtype + # models + self.text_encoder: SDXLTextEncoder = None + self.text_encoder_2: SDXLTextEncoder2 = None + self.unet: SDXLUNet = None + self.vae_decoder: SDXLVAEDecoder = None + self.vae_encoder: SDXLVAEEncoder = None + # TODO: SDXL ControlNet + self.motion_modules: SDXLMotionModel = None + + + def fetch_main_models(self, model_manager: ModelManager): + self.text_encoder = model_manager.text_encoder + self.text_encoder_2 = model_manager.text_encoder_2 + self.unet = model_manager.unet + self.vae_decoder = model_manager.vae_decoder + self.vae_encoder = model_manager.vae_encoder + + + def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs): + # TODO: SDXL ControlNet + pass + + + def fetch_motion_modules(self, model_manager: ModelManager): + if "motion_modules_xl" in model_manager.model: + self.motion_modules = model_manager.motion_modules_xl + + + def fetch_prompter(self, model_manager: ModelManager): + self.prompter.load_from_model_manager(model_manager) + + + @staticmethod + def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs): + pipe = SDXLVideoPipeline( + device=model_manager.device, + torch_dtype=model_manager.torch_dtype, + use_animatediff="motion_modules_xl" in model_manager.model + ) + pipe.fetch_main_models(model_manager) + pipe.fetch_motion_modules(model_manager) + pipe.fetch_prompter(model_manager) + pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units) + return pipe + + + def preprocess_image(self, image): + image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) + return image + + + def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): + image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + image = image.cpu().permute(1, 2, 0).numpy() + image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) + return image + + + def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32): + images = [ + self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + for frame_id in range(latents.shape[0]) + ] + return images + + + def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32): + latents = [] + for image in processed_images: + image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) + latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu() + latents.append(latent) + latents = torch.concat(latents, dim=0) + return latents + + + @torch.no_grad() + def __call__( + self, + prompt, + negative_prompt="", + cfg_scale=7.5, + clip_skip=1, + clip_skip_2=2, + num_frames=None, + input_frames=None, + controlnet_frames=None, + denoising_strength=1.0, + height=512, + width=512, + num_inference_steps=20, + animatediff_batch_size = 16, + animatediff_stride = 8, + unet_batch_size = 1, + controlnet_batch_size = 1, + 
cross_frame_attention = False, + smoother=None, + smoother_progress_ids=[], + vram_limit_level=0, + progress_bar_cmd=tqdm, + progress_bar_st=None, + ): + # Prepare scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength) + + # Prepare latent tensors + if self.motion_modules is None: + noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1) + else: + noise = torch.randn((num_frames, 4, height//8, width//8), device="cuda", dtype=self.torch_dtype) + if input_frames is None or denoising_strength == 1.0: + latents = noise + else: + latents = self.encode_images(input_frames) + latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) + + # Encode prompts + add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt( + self.text_encoder, + self.text_encoder_2, + prompt, + clip_skip=clip_skip, clip_skip_2=clip_skip_2, + device=self.device, + positive=True, + ) + if cfg_scale != 1.0: + add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt( + self.text_encoder, + self.text_encoder_2, + negative_prompt, + clip_skip=clip_skip, clip_skip_2=clip_skip_2, + device=self.device, + positive=False, + ) + + # Prepare positional id + add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device) + + # Denoise + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = torch.IntTensor((timestep,))[0].to(self.device) + + # Classifier-free guidance + noise_pred_posi = lets_dance_xl( + self.unet, motion_modules=self.motion_modules, controlnet=None, + sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi, + timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames, + cross_frame_attention=cross_frame_attention, + device=self.device, vram_limit_level=vram_limit_level + ) + if cfg_scale != 1.0: + noise_pred_nega = lets_dance_xl( + self.unet, motion_modules=self.motion_modules, controlnet=None, + sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega, + timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames, + cross_frame_attention=cross_frame_attention, + device=self.device, vram_limit_level=vram_limit_level + ) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + latents = self.scheduler.step(noise_pred, timestep, latents) + + if progress_bar_st is not None: + progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) + + # Decode image + image = self.decode_images(latents.to(torch.float32)) + + return image diff --git a/diffsynth/pipelines/stable_video_diffusion.py b/diffsynth/pipelines/stable_video_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..431303c5394b08a5ae3835bb0989d57083ab8029 --- /dev/null +++ b/diffsynth/pipelines/stable_video_diffusion.py @@ -0,0 +1,307 @@ +from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder +from ..schedulers import ContinuousODEScheduler +import torch +from tqdm import tqdm +from PIL import Image +import numpy as np +from einops import rearrange, repeat + + + +class SVDVideoPipeline(torch.nn.Module): + + def __init__(self, device="cuda", torch_dtype=torch.float16): + super().__init__() + self.scheduler = ContinuousODEScheduler() + self.device = device + self.torch_dtype = torch_dtype + # models + self.image_encoder: 
SVDImageEncoder = None + self.unet: SVDUNet = None + self.vae_encoder: SVDVAEEncoder = None + self.vae_decoder: SVDVAEDecoder = None + + + def fetch_main_models(self, model_manager: ModelManager): + self.image_encoder = model_manager.image_encoder + self.unet = model_manager.unet + self.vae_encoder = model_manager.vae_encoder + self.vae_decoder = model_manager.vae_decoder + + + @staticmethod + def from_model_manager(model_manager: ModelManager, **kwargs): + pipe = SVDVideoPipeline(device=model_manager.device, torch_dtype=model_manager.torch_dtype) + pipe.fetch_main_models(model_manager) + return pipe + + + def preprocess_image(self, image): + image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) + return image + + + def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): + image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + image = image.cpu().permute(1, 2, 0).numpy() + image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) + return image + + + def encode_image_with_clip(self, image): + image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) + image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 + mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype) + std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype) + image = (image - mean) / std + image_emb = self.image_encoder(image) + return image_emb + + + def encode_image_with_vae(self, image, noise_aug_strength): + image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) + noise = torch.randn(image.shape, device="cpu", dtype=self.torch_dtype).to(self.device) + image = image + noise_aug_strength * noise + image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor + return image_emb + + + def encode_video_with_vae(self, video): + video = torch.concat([self.preprocess_image(frame) for frame in video], dim=0) + video = rearrange(video, "T C H W -> 1 C T H W") + video = video.to(device=self.device, dtype=self.torch_dtype) + latents = self.vae_encoder.encode_video(video) + latents = rearrange(latents[0], "C T H W -> T C H W") + return latents + + + def tensor2video(self, frames): + frames = rearrange(frames, "C T H W -> T H W C") + frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8) + frames = [Image.fromarray(frame) for frame in frames] + return frames + + + def calculate_noise_pred( + self, + latents, + timestep, + add_time_id, + cfg_scales, + image_emb_vae_posi, image_emb_clip_posi, + image_emb_vae_nega, image_emb_clip_nega + ): + # Positive side + noise_pred_posi = self.unet( + torch.cat([latents, image_emb_vae_posi], dim=1), + timestep, image_emb_clip_posi, add_time_id + ) + # Negative side + noise_pred_nega = self.unet( + torch.cat([latents, image_emb_vae_nega], dim=1), + timestep, image_emb_clip_nega, add_time_id + ) + + # Classifier-free guidance + noise_pred = noise_pred_nega + cfg_scales * (noise_pred_posi - noise_pred_nega) + + return noise_pred + + + def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0): + if post_normalize: + mean, std = latents.mean(), latents.std() + latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], 
keepdim=True) * std + mean + latents = latents * contrast_enhance_scale + return latents + + + @torch.no_grad() + def __call__( + self, + input_image=None, + input_video=None, + mask_frames=[], + mask_frame_ids=[], + min_cfg_scale=1.0, + max_cfg_scale=3.0, + denoising_strength=1.0, + num_frames=25, + height=576, + width=1024, + fps=7, + motion_bucket_id=127, + noise_aug_strength=0.02, + num_inference_steps=20, + post_normalize=True, + contrast_enhance_scale=1.2, + progress_bar_cmd=tqdm, + progress_bar_st=None, + ): + # Prepare scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength) + + # Prepare latent tensors + noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).to(self.device) + if denoising_strength == 1.0: + latents = noise.clone() + else: + latents = self.encode_video_with_vae(input_video) + latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0]) + + # Prepare mask frames + if len(mask_frames) > 0: + mask_latents = self.encode_video_with_vae(mask_frames) + + # Encode image + image_emb_clip_posi = self.encode_image_with_clip(input_image) + image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi) + image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength), "B C H W -> (B T) C H W", T=num_frames) + image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi) + + # Prepare classifier-free guidance + cfg_scales = torch.linspace(min_cfg_scale, max_cfg_scale, num_frames) + cfg_scales = cfg_scales.reshape(num_frames, 1, 1, 1).to(device=self.device, dtype=self.torch_dtype) + + # Prepare positional id + add_time_id = torch.tensor([[fps-1, motion_bucket_id, noise_aug_strength]], device=self.device) + + # Denoise + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + + # Mask frames + for frame_id, mask_frame_id in enumerate(mask_frame_ids): + latents[mask_frame_id] = self.scheduler.add_noise(mask_latents[frame_id], noise[mask_frame_id], timestep) + + # Fetch model output + noise_pred = self.calculate_noise_pred( + latents, timestep, add_time_id, cfg_scales, + image_emb_vae_posi, image_emb_clip_posi, image_emb_vae_nega, image_emb_clip_nega + ) + + # Forward Euler + latents = self.scheduler.step(noise_pred, timestep, latents) + + # Update progress bar + if progress_bar_st is not None: + progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) + + # Decode image + latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale) + video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd) + video = self.tensor2video(video) + + return video + + + +class SVDCLIPImageProcessor: + def __init__(self): + pass + + def resize_with_antialiasing(self, input, size, interpolation="bicubic", align_corners=True): + h, w = input.shape[-2:] + factors = (h / size[0], w / size[1]) + + # First, we have to determine sigma + # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 + sigmas = ( + max((factors[0] - 1.0) / 2.0, 0.001), + max((factors[1] - 1.0) / 2.0, 0.001), + ) + + # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma + # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 + # But they do it in the 2 passes, which gives better results. 
Let's try 2 sigmas for now + ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) + + # Make sure it is odd + if (ks[0] % 2) == 0: + ks = ks[0] + 1, ks[1] + + if (ks[1] % 2) == 0: + ks = ks[0], ks[1] + 1 + + input = self._gaussian_blur2d(input, ks, sigmas) + + output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) + return output + + + def _compute_padding(self, kernel_size): + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + + def _filter2d(self, input, kernel): + # prepare kernel + b, c, h, w = input.shape + tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + padding_shape: list[int] = self._compute_padding([height, width]) + input = torch.nn.functional.pad(input, padding_shape, mode="reflect") + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + + out = output.view(b, c, h, w) + return out + + + def _gaussian(self, window_size: int, sigma): + if isinstance(sigma, float): + sigma = torch.tensor([[sigma]]) + + batch_size = sigma.shape[0] + + x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) + + if window_size % 2 == 0: + x = x + 0.5 + + gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) + + return gauss / gauss.sum(-1, keepdim=True) + + + def _gaussian_blur2d(self, input, kernel_size, sigma): + if isinstance(sigma, tuple): + sigma = torch.tensor([sigma], dtype=input.dtype) + else: + sigma = sigma.to(dtype=input.dtype) + + ky, kx = int(kernel_size[0]), int(kernel_size[1]) + bs = sigma.shape[0] + kernel_x = self._gaussian(kx, sigma[:, 1].view(bs, 1)) + kernel_y = self._gaussian(ky, sigma[:, 0].view(bs, 1)) + out_x = self._filter2d(input, kernel_x[..., None, :]) + out = self._filter2d(out_x, kernel_y[..., None]) + + return out diff --git a/diffsynth/processors/FastBlend.py b/diffsynth/processors/FastBlend.py new file mode 100644 index 0000000000000000000000000000000000000000..fed33f4fdd215c8c9dc46f3b07d9453a12cc6b98 --- /dev/null +++ b/diffsynth/processors/FastBlend.py @@ -0,0 +1,142 @@ +from PIL import Image +import cupy as cp +import numpy as np +from tqdm import tqdm +from ..extensions.FastBlend.patch_match import PyramidPatchMatcher +from ..extensions.FastBlend.runners.fast import TableManager +from .base import VideoProcessor + + +class FastBlendSmoother(VideoProcessor): + def __init__( + self, + inference_mode="fast", batch_size=8, window_size=60, + minimum_patch_size=5, threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, initialize="identity", 
tracking_window_size=0 + ): + self.inference_mode = inference_mode + self.batch_size = batch_size + self.window_size = window_size + self.ebsynth_config = { + "minimum_patch_size": minimum_patch_size, + "threads_per_block": threads_per_block, + "num_iter": num_iter, + "gpu_id": gpu_id, + "guide_weight": guide_weight, + "initialize": initialize, + "tracking_window_size": tracking_window_size + } + + @staticmethod + def from_model_manager(model_manager, **kwargs): + # TODO: fetch GPU ID from model_manager + return FastBlendSmoother(**kwargs) + + def inference_fast(self, frames_guide, frames_style): + table_manager = TableManager() + patch_match_engine = PyramidPatchMatcher( + image_height=frames_style[0].shape[0], + image_width=frames_style[0].shape[1], + channel=3, + **self.ebsynth_config + ) + # left part + table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, self.batch_size, desc="Fast Mode Step 1/4") + table_l = table_manager.remapping_table_to_blending_table(table_l) + table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 2/4") + # right part + table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, self.batch_size, desc="Fast Mode Step 3/4") + table_r = table_manager.remapping_table_to_blending_table(table_r) + table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 4/4")[::-1] + # merge + frames = [] + for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r): + weight_m = -1 + weight = weight_l + weight_m + weight_r + frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight) + frames.append(frame) + frames = [frame.clip(0, 255).astype("uint8") for frame in frames] + frames = [Image.fromarray(frame) for frame in frames] + return frames + + def inference_balanced(self, frames_guide, frames_style): + patch_match_engine = PyramidPatchMatcher( + image_height=frames_style[0].shape[0], + image_width=frames_style[0].shape[1], + channel=3, + **self.ebsynth_config + ) + output_frames = [] + # tasks + n = len(frames_style) + tasks = [] + for target in range(n): + for source in range(target - self.window_size, target + self.window_size + 1): + if source >= 0 and source < n and source != target: + tasks.append((source, target)) + # run + frames = [(None, 1) for i in range(n)] + for batch_id in tqdm(range(0, len(tasks), self.batch_size), desc="Balanced Mode"): + tasks_batch = tasks[batch_id: min(batch_id+self.batch_size, len(tasks))] + source_guide = np.stack([frames_guide[source] for source, target in tasks_batch]) + target_guide = np.stack([frames_guide[target] for source, target in tasks_batch]) + source_style = np.stack([frames_style[source] for source, target in tasks_batch]) + _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) + for (source, target), result in zip(tasks_batch, target_style): + frame, weight = frames[target] + if frame is None: + frame = frames_style[target] + frames[target] = ( + frame * (weight / (weight + 1)) + result / (weight + 1), + weight + 1 + ) + if weight + 1 == min(n, target + self.window_size + 1) - max(0, target - self.window_size): + frame = frame.clip(0, 255).astype("uint8") + output_frames.append(Image.fromarray(frame)) + frames[target] = (None, 1) + return output_frames + + 
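+    # Accurate mode: every frame inside the window is remapped onto the target
+    # frame via patch matching, and the remapped frames are averaged per target.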
def inference_accurate(self, frames_guide, frames_style): + patch_match_engine = PyramidPatchMatcher( + image_height=frames_style[0].shape[0], + image_width=frames_style[0].shape[1], + channel=3, + use_mean_target_style=True, + **self.ebsynth_config + ) + output_frames = [] + # run + n = len(frames_style) + for target in tqdm(range(n), desc="Accurate Mode"): + l, r = max(target - self.window_size, 0), min(target + self.window_size + 1, n) + remapped_frames = [] + for i in range(l, r, self.batch_size): + j = min(i + self.batch_size, r) + source_guide = np.stack([frames_guide[source] for source in range(i, j)]) + target_guide = np.stack([frames_guide[target]] * (j - i)) + source_style = np.stack([frames_style[source] for source in range(i, j)]) + _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) + remapped_frames.append(target_style) + frame = np.concatenate(remapped_frames, axis=0).mean(axis=0) + frame = frame.clip(0, 255).astype("uint8") + output_frames.append(Image.fromarray(frame)) + return output_frames + + def release_vram(self): + mempool = cp.get_default_memory_pool() + pinned_mempool = cp.get_default_pinned_memory_pool() + mempool.free_all_blocks() + pinned_mempool.free_all_blocks() + + def __call__(self, rendered_frames, original_frames=None, **kwargs): + rendered_frames = [np.array(frame) for frame in rendered_frames] + original_frames = [np.array(frame) for frame in original_frames] + if self.inference_mode == "fast": + output_frames = self.inference_fast(original_frames, rendered_frames) + elif self.inference_mode == "balanced": + output_frames = self.inference_balanced(original_frames, rendered_frames) + elif self.inference_mode == "accurate": + output_frames = self.inference_accurate(original_frames, rendered_frames) + else: + raise ValueError("inference_mode must be fast, balanced or accurate") + self.release_vram() + return output_frames diff --git a/diffsynth/processors/PILEditor.py b/diffsynth/processors/PILEditor.py new file mode 100644 index 0000000000000000000000000000000000000000..01011d8724f61283550d503c5c20ae6fd0375ec7 --- /dev/null +++ b/diffsynth/processors/PILEditor.py @@ -0,0 +1,28 @@ +from PIL import ImageEnhance +from .base import VideoProcessor + + +class ContrastEditor(VideoProcessor): + def __init__(self, rate=1.5): + self.rate = rate + + @staticmethod + def from_model_manager(model_manager, **kwargs): + return ContrastEditor(**kwargs) + + def __call__(self, rendered_frames, **kwargs): + rendered_frames = [ImageEnhance.Contrast(i).enhance(self.rate) for i in rendered_frames] + return rendered_frames + + +class SharpnessEditor(VideoProcessor): + def __init__(self, rate=1.5): + self.rate = rate + + @staticmethod + def from_model_manager(model_manager, **kwargs): + return SharpnessEditor(**kwargs) + + def __call__(self, rendered_frames, **kwargs): + rendered_frames = [ImageEnhance.Sharpness(i).enhance(self.rate) for i in rendered_frames] + return rendered_frames diff --git a/diffsynth/processors/RIFE.py b/diffsynth/processors/RIFE.py new file mode 100644 index 0000000000000000000000000000000000000000..4186eb31496e9a1bf38df06eb64921226f07ee09 --- /dev/null +++ b/diffsynth/processors/RIFE.py @@ -0,0 +1,77 @@ +import torch +import numpy as np +from PIL import Image +from .base import VideoProcessor + + +class RIFESmoother(VideoProcessor): + def __init__(self, model, device="cuda", scale=1.0, batch_size=4, interpolate=True): + self.model = model + self.device = device + + # IFNet only does not support float16 + 
self.torch_dtype = torch.float32
+
+        # Other parameters
+        self.scale = scale
+        self.batch_size = batch_size
+        self.interpolate = interpolate
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        return RIFESmoother(model_manager.RIFE, device=model_manager.device, **kwargs)
+
+    def process_image(self, image):
+        width, height = image.size
+        if width % 32 != 0 or height % 32 != 0:
+            width = (width + 31) // 32 * 32
+            height = (height + 31) // 32 * 32
+            image = image.resize((width, height))
+        image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
+        return image
+
+    def process_images(self, images):
+        images = [self.process_image(image) for image in images]
+        images = torch.stack(images)
+        return images
+
+    def decode_images(self, images):
+        images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
+        images = [Image.fromarray(image) for image in images]
+        return images
+
+    def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
+        output_tensor = []
+        for batch_id in range(0, input_tensor.shape[0], batch_size):
+            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+            batch_input_tensor = input_tensor[batch_id: batch_id_]
+            batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
+            flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
+            output_tensor.append(merged[2].cpu())
+        output_tensor = torch.concat(output_tensor, dim=0)
+        return output_tensor
+
+    @torch.no_grad()
+    def __call__(self, rendered_frames, **kwargs):
+        # Preprocess
+        processed_images = self.process_images(rendered_frames)
+
+        # Input
+        input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
+
+        # Interpolate
+        output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
+
+        if self.interpolate:
+            # Blend
+            input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
+            output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
+            processed_images[1:-1] = output_tensor
+        else:
+            processed_images[1:-1] = (processed_images[1:-1] + output_tensor) / 2
+
+        # To images
+        output_images = self.decode_images(processed_images)
+        if output_images[0].size != rendered_frames[0].size:
+            output_images = [image.resize(rendered_frames[0].size) for image in output_images]
+        return output_images
diff --git a/diffsynth/processors/__init__.py b/diffsynth/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/diffsynth/processors/base.py b/diffsynth/processors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..278a9c1b74044987cc116de35292a96de8b13737
--- /dev/null
+++ b/diffsynth/processors/base.py
@@ -0,0 +1,6 @@
+class VideoProcessor:
+    def __init__(self):
+        pass
+
+    def __call__(self):
+        raise NotImplementedError
diff --git a/diffsynth/processors/sequencial_processor.py b/diffsynth/processors/sequencial_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5bc9454f0b9d74f10bb4a6bff92db77f26325c
--- /dev/null
+++ b/diffsynth/processors/sequencial_processor.py
@@ -0,0 +1,41 @@
+from .base import VideoProcessor
+
+
+class AutoVideoProcessor(VideoProcessor):
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def from_model_manager(model_manager, processor_type, **kwargs):
+        if processor_type == "FastBlend":
+            from .FastBlend import FastBlendSmoother
+
return FastBlendSmoother.from_model_manager(model_manager, **kwargs) + elif processor_type == "Contrast": + from .PILEditor import ContrastEditor + return ContrastEditor.from_model_manager(model_manager, **kwargs) + elif processor_type == "Sharpness": + from .PILEditor import SharpnessEditor + return SharpnessEditor.from_model_manager(model_manager, **kwargs) + elif processor_type == "RIFE": + from .RIFE import RIFESmoother + return RIFESmoother.from_model_manager(model_manager, **kwargs) + else: + raise ValueError(f"invalid processor_type: {processor_type}") + + +class SequencialProcessor(VideoProcessor): + def __init__(self, processors=[]): + self.processors = processors + + @staticmethod + def from_model_manager(model_manager, configs): + processors = [ + AutoVideoProcessor.from_model_manager(model_manager, config["processor_type"], **config["config"]) + for config in configs + ] + return SequencialProcessor(processors) + + def __call__(self, rendered_frames, **kwargs): + for processor in self.processors: + rendered_frames = processor(rendered_frames, **kwargs) + return rendered_frames diff --git a/diffsynth/prompts/__init__.py b/diffsynth/prompts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b07578b9117740fe0b602fd737ce682e757f434a --- /dev/null +++ b/diffsynth/prompts/__init__.py @@ -0,0 +1,3 @@ +from .sd_prompter import SDPrompter +from .sdxl_prompter import SDXLPrompter +from .hunyuan_dit_prompter import HunyuanDiTPrompter diff --git a/diffsynth/prompts/hunyuan_dit_prompter.py b/diffsynth/prompts/hunyuan_dit_prompter.py new file mode 100644 index 0000000000000000000000000000000000000000..ee12cbda36e57d33247a51fd8f7e3b0ad48b8060 --- /dev/null +++ b/diffsynth/prompts/hunyuan_dit_prompter.py @@ -0,0 +1,56 @@ +from .utils import Prompter +from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer +import warnings + + +class HunyuanDiTPrompter(Prompter): + def __init__( + self, + tokenizer_path="configs/hunyuan_dit/tokenizer", + tokenizer_t5_path="configs/hunyuan_dit/tokenizer_t5" + ): + super().__init__() + self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path) + + + def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + clip_skip=clip_skip + ) + return prompt_embeds, attention_mask + + + def encode_prompt( + self, + text_encoder: BertModel, + text_encoder_t5: T5EncoderModel, + prompt, + clip_skip=1, + clip_skip_2=1, + positive=True, + device="cuda" + ): + prompt = self.process_prompt(prompt, positive=positive) + + # CLIP + prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device) + + # T5 + prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device) + + return prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5 diff --git a/diffsynth/prompts/sd_prompter.py 
b/diffsynth/prompts/sd_prompter.py new file mode 100644 index 0000000000000000000000000000000000000000..6d4407c9ee619bf9993496dca811f76b0c4283fc --- /dev/null +++ b/diffsynth/prompts/sd_prompter.py @@ -0,0 +1,17 @@ +from .utils import Prompter, tokenize_long_prompt +from transformers import CLIPTokenizer +from ..models import SDTextEncoder + + +class SDPrompter(Prompter): + def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) + + def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True): + prompt = self.process_prompt(prompt, positive=positive) + input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device) + prompt_emb = text_encoder(input_ids, clip_skip=clip_skip) + prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) + + return prompt_emb \ No newline at end of file diff --git a/diffsynth/prompts/sdxl_prompter.py b/diffsynth/prompts/sdxl_prompter.py new file mode 100644 index 0000000000000000000000000000000000000000..44b721e230171a3751922bd818f1f28f22c86091 --- /dev/null +++ b/diffsynth/prompts/sdxl_prompter.py @@ -0,0 +1,43 @@ +from .utils import Prompter, tokenize_long_prompt +from transformers import CLIPTokenizer +from ..models import SDXLTextEncoder, SDXLTextEncoder2 +import torch + + +class SDXLPrompter(Prompter): + def __init__( + self, + tokenizer_path="configs/stable_diffusion/tokenizer", + tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2" + ): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) + self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path) + + def encode_prompt( + self, + text_encoder: SDXLTextEncoder, + text_encoder_2: SDXLTextEncoder2, + prompt, + clip_skip=1, + clip_skip_2=2, + positive=True, + device="cuda" + ): + prompt = self.process_prompt(prompt, positive=positive) + + # 1 + input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device) + prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip) + + # 2 + input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device) + add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2) + + # Merge + prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1) + + # For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`. + add_text_embeds = add_text_embeds[0:1] + prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) + return add_text_embeds, prompt_emb diff --git a/diffsynth/prompts/utils.py b/diffsynth/prompts/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f0412286b7b70432d2b6a749e3a279501745d242 --- /dev/null +++ b/diffsynth/prompts/utils.py @@ -0,0 +1,123 @@ +from transformers import CLIPTokenizer, AutoTokenizer +from ..models import ModelManager +import os + + +def tokenize_long_prompt(tokenizer, prompt): + # Get model_max_length from self.tokenizer + length = tokenizer.model_max_length + + # To avoid the warning. set self.tokenizer.model_max_length to +oo. + tokenizer.model_max_length = 99999999 + + # Tokenize it! + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + # Determine the real length. + max_length = (input_ids.shape[1] + length - 1) // length * length + + # Restore tokenizer.model_max_length + tokenizer.model_max_length = length + + # Tokenize it again with fixed length. 
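+    # For example, a prompt that tokenizes to 100 ids with model_max_length 77
+    # is padded to max_length 154 and then reshaped to (2, 77).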
+ input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True + ).input_ids + + # Reshape input_ids to fit the text encoder. + num_sentence = input_ids.shape[1] // length + input_ids = input_ids.reshape((num_sentence, length)) + + return input_ids + + +class BeautifulPrompt: + def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + self.model = model + self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:' + + def __call__(self, raw_prompt): + model_input = self.template.format(raw_prompt=raw_prompt) + input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device) + outputs = self.model.generate( + input_ids, + max_new_tokens=384, + do_sample=True, + temperature=0.9, + top_k=50, + top_p=0.95, + repetition_penalty=1.1, + num_return_sequences=1 + ) + prompt = raw_prompt + ", " + self.tokenizer.batch_decode( + outputs[:, input_ids.size(1):], + skip_special_tokens=True + )[0].strip() + return prompt + + +class Translator: + def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + self.model = model + + def __call__(self, prompt): + input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device) + output_ids = self.model.generate(input_ids) + prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] + return prompt + + +class Prompter: + def __init__(self): + self.tokenizer: CLIPTokenizer = None + self.keyword_dict = {} + self.translator: Translator = None + self.beautiful_prompt: BeautifulPrompt = None + + def load_textual_inversion(self, textual_inversion_dict): + self.keyword_dict = {} + additional_tokens = [] + for keyword in textual_inversion_dict: + tokens, _ = textual_inversion_dict[keyword] + additional_tokens += tokens + self.keyword_dict[keyword] = " " + " ".join(tokens) + " " + self.tokenizer.add_tokens(additional_tokens) + + def load_beautiful_prompt(self, model, model_path): + model_folder = os.path.dirname(model_path) + self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model) + if model_folder.endswith("v2"): + self.beautiful_prompt.template = """Converts a simple image description into a prompt. \ +Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \ +or use a number to specify the weight. 
You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \ +but make sure there is a correlation between the input and output.\n\ +### Input: {raw_prompt}\n### Output:""" + + def load_translator(self, model, model_path): + model_folder = os.path.dirname(model_path) + self.translator = Translator(tokenizer_path=model_folder, model=model) + + def load_from_model_manager(self, model_manager: ModelManager): + self.load_textual_inversion(model_manager.textual_inversion_dict) + if "translator" in model_manager.model: + self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"]) + if "beautiful_prompt" in model_manager.model: + self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"]) + + def process_prompt(self, prompt, positive=True): + for keyword in self.keyword_dict: + if keyword in prompt: + prompt = prompt.replace(keyword, self.keyword_dict[keyword]) + if positive and self.translator is not None: + prompt = self.translator(prompt) + print(f"Your prompt is translated: \"{prompt}\"") + if positive and self.beautiful_prompt is not None: + prompt = self.beautiful_prompt(prompt) + print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"") + return prompt diff --git a/diffsynth/schedulers/__init__.py b/diffsynth/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1620e13ad7674cdce701858a8933bd1652c9c8e4 --- /dev/null +++ b/diffsynth/schedulers/__init__.py @@ -0,0 +1,2 @@ +from .ddim import EnhancedDDIMScheduler +from .continuous_ode import ContinuousODEScheduler diff --git a/diffsynth/schedulers/continuous_ode.py b/diffsynth/schedulers/continuous_ode.py new file mode 100644 index 0000000000000000000000000000000000000000..e6cd8378ad0c09b5ed02c2b22feba5562bd22aff --- /dev/null +++ b/diffsynth/schedulers/continuous_ode.py @@ -0,0 +1,59 @@ +import torch + + +class ContinuousODEScheduler(): + + def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0): + self.sigma_max = sigma_max + self.sigma_min = sigma_min + self.rho = rho + self.set_timesteps(num_inference_steps) + + + def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0): + ramp = torch.linspace(1-denoising_strength, 1, num_inference_steps) + min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho)) + max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho)) + self.sigmas = torch.pow(max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho) + self.timesteps = torch.log(self.sigmas) * 0.25 + + + def step(self, model_output, timestep, sample, to_final=False): + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + sample *= (sigma*sigma + 1).sqrt() + estimated_sample = -sigma / (sigma*sigma + 1).sqrt() * model_output + 1 / (sigma*sigma + 1) * sample + if to_final or timestep_id + 1 >= len(self.timesteps): + prev_sample = estimated_sample + else: + sigma_ = self.sigmas[timestep_id + 1] + derivative = 1 / sigma * (sample - estimated_sample) + prev_sample = sample + derivative * (sigma_ - sigma) + prev_sample /= (sigma_*sigma_ + 1).sqrt() + return prev_sample + + + def return_to_timestep(self, timestep, sample, sample_stablized): + # This scheduler doesn't support this function. 
+ pass + + + def add_noise(self, original_samples, noise, timestep): + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt() + return sample + + + def training_target(self, sample, noise, timestep): + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise + return target + + + def training_weight(self, timestep): + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + weight = (1 + sigma*sigma).sqrt() / sigma + return weight diff --git a/diffsynth/schedulers/ddim.py b/diffsynth/schedulers/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..4b4d9806d9937da4d64955ebda8fb59475c907c6 --- /dev/null +++ b/diffsynth/schedulers/ddim.py @@ -0,0 +1,73 @@ +import torch, math + + +class EnhancedDDIMScheduler(): + + def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon"): + self.num_train_timesteps = num_train_timesteps + if beta_schedule == "scaled_linear": + betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32)) + elif beta_schedule == "linear": + betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented") + self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist() + self.set_timesteps(10) + self.prediction_type = prediction_type + + + def set_timesteps(self, num_inference_steps, denoising_strength=1.0): + # The timesteps are aligned to 999...0, which is different from other implementations, + # but I think this implementation is more reasonable in theory. 
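+        # For example, num_inference_steps=20 with denoising_strength=1.0 gives
+        # timesteps 999, 946, 894, ..., 53, 0.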
+ max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0) + num_inference_steps = min(num_inference_steps, max_timestep + 1) + if num_inference_steps == 1: + self.timesteps = [max_timestep] + else: + step_length = max_timestep / (num_inference_steps - 1) + self.timesteps = [round(max_timestep - i*step_length) for i in range(num_inference_steps)] + + + def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev): + if self.prediction_type == "epsilon": + weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t) + weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t) + prev_sample = sample * weight_x + model_output * weight_e + elif self.prediction_type == "v_prediction": + weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev)) + weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev)) + prev_sample = sample * weight_x + model_output * weight_e + else: + raise NotImplementedError(f"{self.prediction_type} is not implemented") + return prev_sample + + + def step(self, model_output, timestep, sample, to_final=False): + alpha_prod_t = self.alphas_cumprod[timestep] + timestep_id = self.timesteps.index(timestep) + if to_final or timestep_id + 1 >= len(self.timesteps): + alpha_prod_t_prev = 1.0 + else: + timestep_prev = self.timesteps[timestep_id + 1] + alpha_prod_t_prev = self.alphas_cumprod[timestep_prev] + + return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev) + + + def return_to_timestep(self, timestep, sample, sample_stablized): + alpha_prod_t = self.alphas_cumprod[timestep] + noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t) + return noise_pred + + + def add_noise(self, original_samples, noise, timestep): + sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep]) + sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep]) + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def training_target(self, sample, noise, timestep): + sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep]) + sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep]) + target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return target
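
A minimal usage sketch of the EnhancedDDIMScheduler defined above (illustrative only, not part of the diff; `denoiser` is a placeholder for a real epsilon-prediction UNet):

import torch
from diffsynth.schedulers import EnhancedDDIMScheduler

def denoiser(latents, timestep):
    # Placeholder for a real epsilon-prediction model; returns zeros so the loop runs.
    return torch.zeros_like(latents)

scheduler = EnhancedDDIMScheduler(prediction_type="epsilon")
scheduler.set_timesteps(num_inference_steps=20, denoising_strength=1.0)

latents = torch.randn((1, 4, 64, 64))  # pure noise, matching denoising_strength == 1.0
for timestep in scheduler.timesteps:
    noise_pred = denoiser(latents, timestep)
    latents = scheduler.step(noise_pred, timestep, latents)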