from utils.utils import instantiate_from_config
import torch
import copy
from omegaconf import OmegaConf
import logging

main_logger = logging.getLogger("main_logger")


def expand_conv_kernel(pretrained_dict):
    """expand 2d conv parameters from 4D -> 5D"""
    for k, v in pretrained_dict.items():
        if v.dim() == 4 and not k.startswith("first_stage_model"):
            v = v.unsqueeze(2)
            pretrained_dict[k] = v
    return pretrained_dict


def print_state_dict(state_dict):
    print("====== Dumping State Dict ======")
    for k, v in state_dict.items():
        print(k, v.shape)


def load_from_pretrainedSD_checkpoint(
    model,
    pretained_ckpt,
    expand_to_3d=True,
    adapt_keyname=False,
    echo_empty_params=False,
):
    sd_state_dict = torch.load(pretained_ckpt, map_location="cpu")
    if "state_dict" in list(sd_state_dict.keys()):
        sd_state_dict = sd_state_dict["state_dict"]
    model_state_dict = model.state_dict()
    # delete ema_weights just for <precise param counting>
    for k in list(sd_state_dict.keys()):
        if k.startswith("model_ema"):
            del sd_state_dict[k]
    main_logger.info(
        f"Num of model params of Source:{len(sd_state_dict.keys())} VS. Target:{len(model_state_dict.keys())}"
    )
    # print_state_dict(model_state_dict)
    # print_state_dict(sd_state_dict)

    if adapt_keyname:
        # adapting to standard 2d network: modify the key name because of the add of temporal-attention
        mapping_dict = {
            "middle_block.2": "middle_block.3",
            "output_blocks.5.2": "output_blocks.5.3",
            "output_blocks.8.2": "output_blocks.8.3",
        }
        cnt = 0
        for k in list(sd_state_dict.keys()):
            for src_word, dst_word in mapping_dict.items():
                if src_word in k:
                    new_key = k.replace(src_word, dst_word)
                    sd_state_dict[new_key] = sd_state_dict[k]
                    del sd_state_dict[k]
                    cnt += 1
        main_logger.info(f"[renamed {cnt} Source keys to match Target model]")

    pretrained_dict = {
        k: v for k, v in sd_state_dict.items() if k in model_state_dict
    }  # drop extra keys
    empty_paras = [
        k for k, v in model_state_dict.items() if k not in pretrained_dict
    ]  # log no pretrained keys
    assert len(empty_paras) + len(pretrained_dict.keys()) == len(
        model_state_dict.keys()
    )

    if expand_to_3d:
        # adapting to 2d inflated network
        pretrained_dict = expand_conv_kernel(pretrained_dict)

    # overwrite entries in the existing state dict
    model_state_dict.update(pretrained_dict)

    # load the new state dict
    try:
        model.load_state_dict(model_state_dict)
    except:
        skipped = []
        model_dict_ori = model.state_dict()
        for n, p in model_state_dict.items():
            if p.shape != model_dict_ori[n].shape:
                # skip by using original empty paras
                model_state_dict[n] = model_dict_ori[n]
                main_logger.info(
                    f"Skip para: {n}, size={pretrained_dict[n].shape} in pretrained, {model_state_dict[n].shape} in current model"
                )
                skipped.append(n)
        main_logger.info(
            f"[INFO] Skip {len(skipped)} parameters becasuse of size mismatch!"
        )
        model.load_state_dict(model_state_dict)
        empty_paras += skipped

    # only count Unet  part of depth estimation model
    unet_empty_paras = [
        name for name in empty_paras if name.startswith("model.diffusion_model")
    ]
    main_logger.info(
        f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)} [Unet:{len(unet_empty_paras)}]"
    )
    if echo_empty_params:
        print("Printing empty parameters:")
        for k in empty_paras:
            print(k)
    return model, empty_paras


# Below: written by Yingqing --------------------------------------------------------


def load_model_from_config(config, ckpt, verbose=False):
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        main_logger.info("missing keys:")
        main_logger.info(m)
    if len(u) > 0 and verbose:
        main_logger.info("unexpected keys:")
        main_logger.info(u)

    model.eval()
    return model


def init_and_load_ldm_model(config_path, ckpt_path, device=None):
    assert config_path.endswith(".yaml"), f"config_path = {config_path}"
    assert ckpt_path.endswith(".ckpt"), f"ckpt_path = {ckpt_path}"
    config = OmegaConf.load(config_path)
    model = load_model_from_config(config, ckpt_path)
    if device is not None:
        model = model.to(device)
    return model


def load_img_model_to_video_model(
    model,
    device=None,
    expand_to_3d=True,
    adapt_keyname=False,
    config_path="configs/latent-diffusion/txt2img-1p4B-eval.yaml",
    ckpt_path="models/ldm/text2img-large/model.ckpt",
):
    pretrained_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
    model, empty_paras = load_partial_weights(
        model,
        pretrained_ldm.state_dict(),
        expand_to_3d=expand_to_3d,
        adapt_keyname=adapt_keyname,
    )
    return model, empty_paras


def load_partial_weights(
    model, pretrained_dict, expand_to_3d=True, adapt_keyname=False
):
    model2 = copy.deepcopy(model)
    model_dict = model.state_dict()
    model_dict_ori = copy.deepcopy(model_dict)

    main_logger.info(f"[Load pretrained LDM weights]")
    main_logger.info(
        f"Num of parameters of source model:{len(pretrained_dict.keys())} VS. target model:{len(model_dict.keys())}"
    )

    if adapt_keyname:
        # adapting to menghan's standard 2d network: modify the key name because of the add of temporal-attention
        mapping_dict = {
            "middle_block.2": "middle_block.3",
            "output_blocks.5.2": "output_blocks.5.3",
            "output_blocks.8.2": "output_blocks.8.3",
        }
        cnt = 0
        newpretrained_dict = copy.deepcopy(pretrained_dict)
        for k, v in newpretrained_dict.items():
            for src_word, dst_word in mapping_dict.items():
                if src_word in k:
                    new_key = k.replace(src_word, dst_word)
                    pretrained_dict[new_key] = v
                    pretrained_dict.pop(k)
                    cnt += 1
        main_logger.info(f"--renamed {cnt} source keys to match target model.")
    pretrained_dict = {
        k: v for k, v in pretrained_dict.items() if k in model_dict
    }  # drop extra keys
    empty_paras = [
        k for k, v in model_dict.items() if k not in pretrained_dict
    ]  # log no pretrained keys
    main_logger.info(
        f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)}"
    )
    # disable info
    # main_logger.info(f'Empty parameters: {empty_paras} ')
    assert len(empty_paras) + len(pretrained_dict.keys()) == len(model_dict.keys())

    if expand_to_3d:
        # adapting to yingqing's 2d inflation network
        pretrained_dict = expand_conv_kernel(pretrained_dict)

    # overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)

    # load the new state dict
    try:
        model2.load_state_dict(model_dict)
    except:
        # if parameter size mismatch, skip them
        skipped = []
        for n, p in model_dict.items():
            if p.shape != model_dict_ori[n].shape:
                # skip by using original empty paras
                model_dict[n] = model_dict_ori[n]
                main_logger.info(
                    f"Skip para: {n}, size={pretrained_dict[n].shape} in pretrained, {model_dict[n].shape} in current model"
                )
                skipped.append(n)
        main_logger.info(
            f"[INFO] Skip {len(skipped)} parameters becasuse of size mismatch!"
        )
        model2.load_state_dict(model_dict)
        empty_paras += skipped
        main_logger.info(f"Empty parameters: {len(empty_paras)} ")

    main_logger.info(f"Finished.")
    return model2, empty_paras


def load_autoencoder(model, config_path=None, ckpt_path=None, device=None):
    if config_path is None:
        config_path = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
    if ckpt_path is None:
        ckpt_path = "models/ldm/text2img-large/model.ckpt"

    pretrained_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
    autoencoder_dict = {}
    for n, p in pretrained_ldm.state_dict().items():
        if n.startswith("first_stage_model"):
            autoencoder_dict[n] = p
    model_dict = model.state_dict()
    model_dict.update(autoencoder_dict)
    main_logger.info(f"Load [{len(autoencoder_dict)}] autoencoder parameters!")

    model.load_state_dict(model_dict)

    return model