# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright © 2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG), acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import os
import gc

import logging
from lib.common.config import cfg
from lib.dataset.mesh_util import (
    load_checkpoint,
    update_mesh_shape_prior_losses,
    blend_rgb_norm,
    unwrap,
    remesh,
    tensor2variable,
)

from lib.dataset.TestDataset import TestDataset
from lib.net.local_affine import LocalAffine
from pytorch3d.structures import Meshes
from apps.ICON import ICON

from termcolor import colored
import numpy as np
from PIL import Image
import trimesh
from tqdm import tqdm

import torch
torch.backends.cudnn.benchmark = True

logging.getLogger("trimesh").setLevel(logging.ERROR)


def generate_model(in_path, model_type):
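    """Run the full ICON pipeline on a single input image.

    Args:
        in_path: path to the input image (forwarded to TestDataset as
            ``image_path``).
        model_type: reconstruction backbone; 'ICON' is mapped to the
            'icon-filter' config, any other value (e.g. 'pifu', 'pamir') is
            lower-cased and used to pick ``./configs/<model_type>.yaml``.

    Returns:
        A list of output paths: SMPL glb/obj/npy, refined cloth glb/obj,
        the self-rotation video (listed twice), and the overlap image.
    """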

    torch.cuda.empty_cache()
    
    if model_type == 'ICON':
        model_type = 'icon-filter'
    else:
        model_type = model_type.lower()

    config_dict = {'loop_smpl': 100,
                   'loop_cloth': 200,
                   'patience': 5,
                   'out_dir': './results',
                   'hps_type': 'pymaf',
                   'config': f"./configs/{model_type}.yaml"}

    # cfg read and merge
    cfg.merge_from_file(config_dict['config'])
    cfg.merge_from_file("./lib/pymaf/configs/pymaf_config.yaml")

    os.makedirs(config_dict['out_dir'], exist_ok=True)

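    # flat [key, value, key, value, ...] override list in the format consumed
    # by the yacs-style cfg.merge_from_list below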
    cfg_show_list = [
        "test_gpus",
        [0],
        "mcube_res",
        256,
        "clean_mesh",
        True,
    ]

    cfg.merge_from_list(cfg_show_list)
    cfg.freeze()

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    device = torch.device("cuda:0")

    # load model and dataloader
    model = ICON(cfg)
    model = load_checkpoint(model, cfg)

    dataset_param = {
        'image_path': in_path,
        'seg_dir': None,
        'has_det': True,            # w/ or w/o detection
        'hps_type': 'pymaf'   # pymaf/pare/pixie
    }

    if config_dict['hps_type'] == "pixie" and "pamir" in config_dict['config']:
        print(colored("PIXIE isn't compatible with PaMIR, thus switch to PyMAF", "red"))
        dataset_param["hps_type"] = "pymaf"

    dataset = TestDataset(dataset_param, device)

    print(colored(f"Dataset Size: {len(dataset)}", "green"))

    pbar = tqdm(dataset)

    for data in pbar:
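        # per-image pipeline: (1) SMPL body fitting, (2) implicit clothed
        # reconstruction, (3) remeshing + local-affine cloth refinement,
        # (4) export of meshes, overlap image and self-rotation video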

        pbar.set_description(f"{data['name']}")

        in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["image"]}

        # The optimizer and variables
        optimed_pose = torch.tensor(
            data["body_pose"], device=device, requires_grad=True
        )  # [1,23,3,3]
        optimed_trans = torch.tensor(
            data["trans"], device=device, requires_grad=True
        )  # [3]
        optimed_betas = torch.tensor(
            data["betas"], device=device, requires_grad=True
        )  # [1,10]
        optimed_orient = torch.tensor(
            data["global_orient"], device=device, requires_grad=True
        )  # [1,1,3,3]

        optimizer_smpl = torch.optim.SGD(
            [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
            lr=1e-3,
            momentum=0.9,
        )
        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_smpl,
            mode="min",
            factor=0.5,
            verbose=0,
            min_lr=1e-5,
            patience=config_dict['patience'],
        )

        losses = {
            # Cloth: Normal_recon - Normal_pred
            "cloth": {"weight": 1e1, "value": 0.0},
            # Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
            "stiffness": {"weight": 1e5, "value": 0.0},
            # Cloth: det(R) = 1
            "rigid": {"weight": 1e5, "value": 0.0},
            # Cloth: edge length
            "edge": {"weight": 0, "value": 0.0},
            # Cloth: normal consistency
            "nc": {"weight": 0, "value": 0.0},
            # Cloth: Laplacian smoothing
            "laplacian": {"weight": 1e2, "value": 0.0},
            # Body: Normal_pred - Normal_smpl
            "normal": {"weight": 1e0, "value": 0.0},
            # Body: Silhouette_pred - Silhouette_smpl
            "silhouette": {"weight": 1e0, "value": 0.0},
        }

        # SMPL optimization: fit pose/shape/orientation/translation to the
        # predicted clothed normal maps and silhouette

        loop_smpl = tqdm(range(config_dict['loop_smpl']))

        for _ in loop_smpl:

            optimizer_smpl.zero_grad()

            if dataset_param["hps_type"] != "pixie":
                smpl_out = dataset.smpl_model(
                    betas=optimed_betas,
                    body_pose=optimed_pose,
                    global_orient=optimed_orient,
                    pose2rot=False,
                )

                smpl_verts = (smpl_out.vertices + optimed_trans) * data["scale"]
            else:
                smpl_verts, _, _ = dataset.smpl_model(
                    shape_params=optimed_betas,
                    expression_params=tensor2variable(data["exp"], device),
                    body_pose=optimed_pose,
                    global_pose=optimed_orient,
                    jaw_pose=tensor2variable(data["jaw_pose"], device),
                    left_hand_pose=tensor2variable(
                        data["left_hand_pose"], device),
                    right_hand_pose=tensor2variable(
                        data["right_hand_pose"], device),
                )

                smpl_verts = (smpl_verts + optimed_trans) * data["scale"]

            # render optimized mesh (normal, T_normal, image [-1,1])
            in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
                smpl_verts * torch.tensor([1.0, -1.0, -1.0]).to(device),
                in_tensor["smpl_faces"],
            )
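            # front/back silhouettes of the current SMPL estimate (presumably
            # rendered from the meshes loaded by render_normal above)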
            T_mask_F, T_mask_B = dataset.render.get_silhouette_image()

            with torch.no_grad():
                in_tensor["normal_F"], in_tensor["normal_B"] = model.netG.normal_filter(
                    in_tensor
                )

            diff_F_smpl = torch.abs(
                in_tensor["T_normal_F"] - in_tensor["normal_F"])
            diff_B_smpl = torch.abs(
                in_tensor["T_normal_B"] - in_tensor["normal_B"])

            losses["normal"]["value"] = (diff_F_smpl + diff_B_smpl).mean()

            # silhouette loss: build a pseudo ground-truth silhouette by
            # thresholding the predicted normal maps against the 0.5-gray background
            smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
            gt_arr = torch.cat(
                [in_tensor["normal_F"][0], in_tensor["normal_B"][0]], dim=2
            ).permute(1, 2, 0)
            gt_arr = ((gt_arr + 1.0) * 0.5).to(device)
            bg_color = (
                torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(
                    0).unsqueeze(0).to(device)
            )
            gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
            diff_S = torch.abs(smpl_arr - gt_arr)
            losses["silhouette"]["value"] = diff_S.mean()

            # Weighted sum of the losses
            smpl_loss = 0.0
            pbar_desc = "Body Fitting --- "
            for k in ["normal", "silhouette"]:
                pbar_desc += f"{k}: {losses[k]['value'] * losses[k]['weight']:.3f} | "
                smpl_loss += losses[k]["value"] * losses[k]["weight"]
            pbar_desc += f"Total: {smpl_loss:.3f}"
            loop_smpl.set_description(pbar_desc)

            smpl_loss.backward()
            optimizer_smpl.step()
            scheduler_smpl.step(smpl_loss)
            in_tensor["smpl_verts"] = smpl_verts * \
                torch.tensor([1.0, 1.0, -1.0]).to(device)

        # visualize the optimization process
        # 1. SMPL Fitting
        # 2. Clothes Refinement

        os.makedirs(os.path.join(config_dict['out_dir'], cfg.name,
                    "refinement"), exist_ok=True)

        # visualize the final results in self-rotation mode
        os.makedirs(os.path.join(config_dict['out_dir'],
                    cfg.name, "vid"), exist_ok=True)

        # final results rendered as image
        # 1. Render the final fitted SMPL (xxx_smpl.png)
        # 2. Render the final reconstructed clothed human (xxx_cloth.png)
        # 3. Blend the original image with predicted cloth normal (xxx_overlap.png)

        os.makedirs(os.path.join(config_dict['out_dir'],
                    cfg.name, "png"), exist_ok=True)

        # final reconstruction meshes
        # 1. SMPL mesh (xxx_smpl.obj)
        # 2. SMPL params (xxx_smpl.npy)
        # 3. clothed mesh (xxx_recon.obj)
        # 4. remeshed clothed mesh (xxx_remesh.obj)
        # 5. refined clothed mesh (xxx_refine.obj)

        os.makedirs(os.path.join(config_dict['out_dir'],
                    cfg.name, "obj"), exist_ok=True)

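        # map the predicted front normal map from [-1, 1] to an 8-bit RGB image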
        norm_pred = (
            ((in_tensor["normal_F"][0].permute(1, 2, 0) + 1.0) * 255.0 / 2.0)
            .detach()
            .cpu()
            .numpy()
            .astype(np.uint8)
        )

        norm_orig = unwrap(norm_pred, data)
        mask_orig = unwrap(
            np.repeat(
                data["mask"].permute(1, 2, 0).detach().cpu().numpy(), 3, axis=2
            ).astype(np.uint8),
            data,
        )
        rgb_norm = blend_rgb_norm(data["ori_image"], norm_orig, mask_orig)

        Image.fromarray(
            np.concatenate(
                [data["ori_image"].astype(np.uint8), rgb_norm], axis=1)
        ).save(os.path.join(config_dict['out_dir'], cfg.name, f"png/{data['name']}_overlap.png"))

        smpl_obj = trimesh.Trimesh(
            in_tensor["smpl_verts"].detach().cpu()[0] *
            torch.tensor([1.0, -1.0, 1.0]),
            in_tensor['smpl_faces'].detach().cpu()[0],
            process=False,
            maintain_order=True,
        )
        smpl_obj.visual.vertex_colors = (smpl_obj.vertex_normals + 1.0) * 0.5 * 255.0
        smpl_obj.export(
            f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.obj")
        smpl_obj.export(
            f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb")

        smpl_info = {'betas': optimed_betas,
                     'pose': optimed_pose,
                     'orient': optimed_orient,
                     'trans': optimed_trans}

        np.save(
            f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.npy", smpl_info, allow_pickle=True)

        # ------------------------------------------------------------------------------------------------------------------

        # cloth optimization

        # cloth recon
        in_tensor.update(
            dataset.compute_vis_cmap(
                in_tensor["smpl_verts"][0], in_tensor["smpl_faces"][0]
            )
        )

        if cfg.net.prior_type == "pamir":
            in_tensor.update(
                dataset.compute_voxel_verts(
                    optimed_pose,
                    optimed_orient,
                    optimed_betas,
                    optimed_trans,
                    data["scale"],
                )
            )

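        # reconstruct the clothed mesh from the implicit model (marching cubes
        # presumably at the mcube_res=256 override set above)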
        with torch.no_grad():
            verts_pr, faces_pr, _ = model.test_single(in_tensor)

        recon_obj = trimesh.Trimesh(
            verts_pr, faces_pr, process=False, maintain_order=True
        )
        recon_obj.visual.vertex_colors = (
            recon_obj.vertex_normals + 1.0) * 0.5 * 255.0
        recon_obj.export(
            os.path.join(config_dict['out_dir'], cfg.name,
                         f"obj/{data['name']}_recon.obj")
        )

        # Isotropic Explicit Remeshing for better geometry topology
        verts_refine, faces_refine = remesh(os.path.join(config_dict['out_dir'], cfg.name,
                                                         f"obj/{data['name']}_recon.obj"), 0.5, device)

        # per-vertex local affine deformation of the remeshed surface
        mesh_pr = Meshes(verts_refine, faces_refine).to(device)
        local_affine_model = LocalAffine(
            mesh_pr.verts_padded().shape[1], mesh_pr.verts_padded().shape[0], mesh_pr.edges_packed()).to(device)
        optimizer_cloth = torch.optim.Adam(
            [{'params': local_affine_model.parameters()}], lr=1e-4, amsgrad=True)
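        # LocalAffine assigns one affine transform (R, T) per vertex; its
        # stiffness/rigid outputs feed the "stiffness" ([RT]_v1 - [RT]_v2
        # across edges) and "rigid" (det(R) = 1) losses defined above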

        scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_cloth,
            mode="min",
            factor=0.1,
            verbose=0,
            min_lr=1e-5,
            patience=config_dict['patience'],
        )

        final = None

        if config_dict['loop_cloth'] > 0:

            loop_cloth = tqdm(range(config_dict['loop_cloth']))

            for _ in loop_cloth:

                optimizer_cloth.zero_grad()

                deformed_verts, stiffness, rigid = local_affine_model(
                    verts_refine.to(device), return_stiff=True)
                mesh_pr = mesh_pr.update_padded(deformed_verts)

                # losses for laplacian, edge, normal consistency
                update_mesh_shape_prior_losses(mesh_pr, losses)

                in_tensor["P_normal_F"], in_tensor["P_normal_B"] = dataset.render_normal(
                    mesh_pr.verts_padded(), mesh_pr.faces_padded())

                diff_F_cloth = torch.abs(
                    in_tensor["P_normal_F"] - in_tensor["normal_F"])
                diff_B_cloth = torch.abs(
                    in_tensor["P_normal_B"] - in_tensor["normal_B"])

                losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
                losses["stiffness"]["value"] = torch.mean(stiffness)
                losses["rigid"]["value"] = torch.mean(rigid)

                # Weighted sum of the losses
                cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
                pbar_desc = "Cloth Refinement --- "

                for k in losses.keys():
                    if k not in ["normal", "silhouette"] and losses[k]["weight"] > 0.0:
                        cloth_loss = cloth_loss + \
                            losses[k]["value"] * losses[k]["weight"]
                        pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.5f} | "

                pbar_desc += f"Total: {cloth_loss:.5f}"
                loop_cloth.set_description(pbar_desc)

                # update params
                cloth_loss.backward()
                optimizer_cloth.step()
                scheduler_cloth.step(cloth_loss)

            final = trimesh.Trimesh(
                mesh_pr.verts_packed().detach().squeeze(0).cpu(),
                mesh_pr.faces_packed().detach().squeeze(0).cpu(),
                process=False, maintain_order=True
            )

            # without front texture
            final_colors = (mesh_pr.verts_normals_padded().squeeze(
                0).detach().cpu() + 1.0) * 0.5 * 255.0
            final.visual.vertex_colors = final_colors
            final.export(
                f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.obj")
            final.export(
                f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.glb")

        # export the visualized video even when cloth refinement was skipped
        # (final stays None when loop_cloth == 0)
        meshes = [smpl_obj] if final is None else [smpl_obj, final]
        verts_lst = [m.vertices for m in meshes]
        faces_lst = [m.faces for m in meshes]

        # self-rotated video
        dataset.render.load_meshes(
            verts_lst, faces_lst)
        dataset.render.get_rendered_video(
            [data["ori_image"], rgb_norm],
            os.path.join(config_dict['out_dir'], cfg.name,
                         f"vid/{data['name']}_cloth.mp4"),
        )

    smpl_obj_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.obj"
    smpl_glb_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb"
    smpl_npy_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.npy"
    refine_obj_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.obj"
    refine_glb_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.glb"

    video_path = os.path.join(
        config_dict['out_dir'], cfg.name, f"vid/{data['name']}_cloth.mp4")
    overlap_path = os.path.join(
        config_dict['out_dir'], cfg.name, f"png/{data['name']}_overlap.png")

    # clean all the variables (keep only the *_path strings used below);
    # note that `del locals()[name]` is a no-op inside a CPython function,
    # so the heavyweight references are dropped explicitly instead
    del model, dataset, local_affine_model, mesh_pr, smpl_obj, final, in_tensor, data
    gc.collect()
    torch.cuda.empty_cache()
    
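    # video_path is returned twice, presumably to match two separate video
    # outputs expected by the caller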
    return [smpl_glb_path, smpl_obj_path, smpl_npy_path,
            refine_glb_path, refine_obj_path,
            video_path, video_path, overlap_path]