import argparse from typing import Any, Dict import torch from huggingface_hub import hf_hub_download from safetensors.torch import load_file from diffusers import AutoencoderDC def remap_qkv_(key: str, state_dict: Dict[str, Any]): qkv = state_dict.pop(key) q, k, v = torch.chunk(qkv, 3, dim=0) parent_module, _, _ = key.rpartition(".qkv.conv.weight") state_dict[f"{parent_module}.to_q.weight"] = q.squeeze() state_dict[f"{parent_module}.to_k.weight"] = k.squeeze() state_dict[f"{parent_module}.to_v.weight"] = v.squeeze() def remap_proj_conv_(key: str, state_dict: Dict[str, Any]): parent_module, _, _ = key.rpartition(".proj.conv.weight") state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze() AE_KEYS_RENAME_DICT = { # common "main.": "", "op_list.": "", "context_module": "attn", "local_module": "conv_out", # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1 # If there were more scales, there would be more layers, so a loop would be better to handle this "aggreg.0.0": "to_qkv_multiscale.0.proj_in", "aggreg.0.1": "to_qkv_multiscale.0.proj_out", "depth_conv.conv": "conv_depth", "inverted_conv.conv": "conv_inverted", "point_conv.conv": "conv_point", "point_conv.norm": "norm", "conv.conv.": "conv.", "conv1.conv": "conv1", "conv2.conv": "conv2", "conv2.norm": "norm", "proj.norm": "norm_out", # encoder "encoder.project_in.conv": "encoder.conv_in", "encoder.project_out.0.conv": "encoder.conv_out", "encoder.stages": "encoder.down_blocks", # decoder "decoder.project_in.conv": "decoder.conv_in", "decoder.project_out.0": "decoder.norm_out", "decoder.project_out.2.conv": "decoder.conv_out", "decoder.stages": "decoder.up_blocks", } AE_F32C32_KEYS = { # encoder "encoder.project_in.conv": "encoder.conv_in.conv", # decoder "decoder.project_out.2.conv": "decoder.conv_out.conv", } AE_F64C128_KEYS = { # encoder "encoder.project_in.conv": "encoder.conv_in.conv", # decoder "decoder.project_out.2.conv": "decoder.conv_out.conv", } AE_F128C512_KEYS = { # encoder "encoder.project_in.conv": "encoder.conv_in.conv", # decoder "decoder.project_out.2.conv": "decoder.conv_out.conv", } AE_SPECIAL_KEYS_REMAP = { "qkv.conv.weight": remap_qkv_, "proj.conv.weight": remap_proj_conv_, } def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: state_dict = saved_dict if "model" in saved_dict.keys(): state_dict = state_dict["model"] if "module" in saved_dict.keys(): state_dict = state_dict["module"] if "state_dict" in saved_dict.keys(): state_dict = state_dict["state_dict"] return state_dict def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: state_dict[new_key] = state_dict.pop(old_key) def convert_ae(config_name: str, dtype: torch.dtype): config = get_ae_config(config_name) hub_id = f"mit-han-lab/{config_name}" ckpt_path = hf_hub_download(hub_id, "model.safetensors") original_state_dict = get_state_dict(load_file(ckpt_path)) ae = AutoencoderDC(**config).to(dtype=dtype) for key in list(original_state_dict.keys()): new_key = key[:] for replace_key, rename_key in AE_KEYS_RENAME_DICT.items(): new_key = new_key.replace(replace_key, rename_key) update_state_dict_(original_state_dict, key, new_key) for key in list(original_state_dict.keys()): for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items(): if special_key not in key: continue handler_fn_inplace(key, original_state_dict) ae.load_state_dict(original_state_dict, strict=True) return ae def get_ae_config(name: str): if name in ["dc-ae-f32c32-sana-1.0"]: config = { "latent_channels": 32, "encoder_block_types": ( "ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", ), "decoder_block_types": ( "ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", ), "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024), "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024), "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)), "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)), "encoder_layers_per_block": (2, 2, 2, 3, 3, 3), "decoder_layers_per_block": [3, 3, 3, 3, 3, 3], "downsample_block_type": "conv", "upsample_block_type": "interpolate", "decoder_norm_types": "rms_norm", "decoder_act_fns": "silu", "scaling_factor": 0.41407, } elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]: AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS) config = { "latent_channels": 32, "encoder_block_types": [ "ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", ], "decoder_block_types": [ "ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", ], "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024], "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024], "encoder_layers_per_block": [0, 4, 8, 2, 2, 2], "decoder_layers_per_block": [0, 5, 10, 2, 2, 2], "encoder_qkv_multiscales": ((), (), (), (), (), ()), "decoder_qkv_multiscales": ((), (), (), (), (), ()), "decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"], "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"], } if name == "dc-ae-f32c32-in-1.0": config["scaling_factor"] = 0.3189 elif name == "dc-ae-f32c32-mix-1.0": config["scaling_factor"] = 0.4552 elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]: AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS) config = { "latent_channels": 128, "encoder_block_types": [ "ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", ], "decoder_block_types": [ "ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", ], "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048], "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048], "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2], "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2], "encoder_qkv_multiscales": ((), (), (), (), (), (), ()), "decoder_qkv_multiscales": ((), (), (), (), (), (), ()), "decoder_norm_types": [ "batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm", "rms_norm", ], "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"], } if name == "dc-ae-f64c128-in-1.0": config["scaling_factor"] = 0.2889 elif name == "dc-ae-f64c128-mix-1.0": config["scaling_factor"] = 0.4538 elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]: AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS) config = { "latent_channels": 512, "encoder_block_types": [ "ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", ], "decoder_block_types": [ "ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock", ], "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048], "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048], "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2], "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2], "encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()), "decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()), "decoder_norm_types": [ "batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm", "rms_norm", "rms_norm", ], "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"], } if name == "dc-ae-f128c512-in-1.0": config["scaling_factor"] = 0.4883 elif name == "dc-ae-f128c512-mix-1.0": config["scaling_factor"] = 0.3620 else: raise ValueError("Invalid config name provided.") return config def get_args(): parser = argparse.ArgumentParser() parser.add_argument( "--config_name", type=str, default="dc-ae-f32c32-sana-1.0", choices=[ "dc-ae-f32c32-sana-1.0", "dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0", "dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0", "dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0", ], help="The DCAE checkpoint to convert", ) parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") return parser.parse_args() DTYPE_MAPPING = { "fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16, } VARIANT_MAPPING = { "fp32": None, "fp16": "fp16", "bf16": "bf16", } if __name__ == "__main__": args = get_args() dtype = DTYPE_MAPPING[args.dtype] variant = VARIANT_MAPPING[args.dtype] ae = convert_ae(args.config_name, dtype) ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)