import argparse
import math
from pathlib import Path

import cv2
import numpy as np
import PIL.Image as Image
import selfcontact
import selfcontact.losses
import shapely.geometry
import torch
import torch.nn as nn
import torch.optim as optim
import torchgeometry
import tqdm
import trimesh
from skimage import measure

import fist_pose
import hist_cub
import losses
import pose_estimation
import spin

# Maps each pose-estimation keypoint name (the entries of
# ``pose_estimation.KPS``) to the joint name that ``get_selector`` looks up in
# ``spin.JOINT_NAMES``.
# NOTE(review): "Right Shoulder" / "Left Shoulder" map to "... ForeArm" rather
# than to a same-named joint -- presumably deliberate given the SPIN joint
# naming; confirm against ``spin.JOINT_NAMES``.
PE_KSP_TO_SPIN = {
    "Head": "Head",
    "Neck": "Neck",
    "Right Shoulder": "Right ForeArm",
    "Right Arm": "Right Arm",
    "Right Hand": "Right Hand",
    "Left Shoulder": "Left ForeArm",
    "Left Arm": "Left Arm",
    "Left Hand": "Left Hand",
    "Spine": "Spine1",
    "Hips": "Hips",
    "Right Upper Leg": "Right Upper Leg",
    "Right Leg": "Right Leg",
    "Right Foot": "Right Foot",
    "Left Upper Leg": "Left Upper Leg",
    "Left Leg": "Left Leg",
    "Left Foot": "Left Foot",
    "Left Toe": "Left Toe",
    "Right Toe": "Right Toe",
}
# Directory prefix used to build the default model/data paths in ``parse_args``.
MODELS_DIR = "models"


def parse_args(argv=None):
    """Parse the command-line options of the fitting pipeline.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            parameterless version of this function did.

    Returns:
        argparse.Namespace: the parsed options. ``--save-path`` and
        ``--img-path`` are required; every other option has a default.
    """
    parser = argparse.ArgumentParser()

    # Model checkpoints.
    parser.add_argument(
        "--pose-estimation-model-path",
        type=str,
        default=f"./{MODELS_DIR}/hrn_w48_384x288.onnx",
        help="Pose Estimation model",
    )
    parser.add_argument(
        "--contact-model-path",
        type=str,
        default=f"./{MODELS_DIR}/contact_hrn_w32_256x192.onnx",
        help="Contact model",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cpu", "cuda"],
        help="Torch device",
    )
    parser.add_argument(
        "--spin-model-path",
        type=str,
        default=f"./{MODELS_DIR}/spin_model_smplx_eft_18.pt",
        help="SPIN model path",
    )

    # SMPL body model resources.
    parser.add_argument(
        "--smpl-type",
        type=str,
        default="smplx",
        choices=["smplx"],
        help="SMPL model type",
    )
    parser.add_argument(
        "--smpl-model-dir",
        type=str,
        default=f"./{MODELS_DIR}/models/smplx",
        help="SMPL model dir",
    )
    parser.add_argument(
        "--smpl-mean-params-path",
        type=str,
        default=f"./{MODELS_DIR}/data/smpl_mean_params.npz",
        help="SMPL mean params",
    )
    parser.add_argument(
        "--essentials-dir",
        type=str,
        default=f"./{MODELS_DIR}/smplify-xmc-essentials",
        help="SMPL Essentials folder for contacts",
    )

    # Mesh parametrization files.
    parser.add_argument(
        "--parametrization-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/parametrization.npy",
        help="Parametrization path",
    )
    parser.add_argument(
        "--bone-parametrization-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/bone_to_param2.npy",
        help="Bone parametrization path",
    )
    parser.add_argument(
        "--foot-inds-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/foot_inds.npy",
        help="Foot indices",
    )

    # Input / output locations.
    parser.add_argument(
        "--save-path",
        type=str,
        required=True,
        help="Path to save the results",
    )
    parser.add_argument(
        "--img-path",
        type=str,
        required=True,
        help="Path to img to test",
    )

    # Feature switches.
    parser.add_argument(
        "--use-contacts",
        action="store_true",
        help="Use contact model",
    )
    parser.add_argument(
        "--use-msc",
        action="store_true",
        help="Use MSC loss",
    )
    parser.add_argument(
        "--use-natural",
        action="store_true",
        help="Use regularity",
    )
    parser.add_argument(
        "--use-cos",
        action="store_true",
        help="Use cos model",
    )
    parser.add_argument(
        "--use-angle-transf",
        action="store_true",
        help="Use cube foreshortening transformation",
    )

    # Loss weights.
    parser.add_argument(
        "--c-mse",
        type=float,
        default=0,
        help="MSE weight",
    )
    parser.add_argument(
        "--c-par",
        type=float,
        default=10,
        help="Parallel weight",
    )
    parser.add_argument(
        "--c-f",
        type=float,
        default=1000,
        help="Cos coef",
    )
    parser.add_argument(
        "--c-parallel",
        type=float,
        default=100,
        help="Parallel weight",
    )
    parser.add_argument(
        "--c-reg",
        type=float,
        default=1000,
        help="Regularity weight",
    )
    parser.add_argument(
        "--c-cont2d",
        type=float,
        default=1,
        help="Contact 2D weight",
    )
    parser.add_argument(
        "--c-msc",
        type=float,
        default=17_500,
        help="MSC weight",
    )

    # Optional hand poses; valid values are the keys of fist_pose.INT_TO_FIST.
    parser.add_argument(
        "--fist",
        nargs="+",
        type=str,
        choices=list(fist_pose.INT_TO_FIST),
        help="Hand pose key(s) to apply (keys of fist_pose.INT_TO_FIST)",
    )

    return parser.parse_args(argv)


def freeze_layers(model):
    """Switch batch-norm and dropout layers of ``model`` to eval mode and freeze them.

    Every sub-module that is a ``_BatchNorm`` subclass or an ``nn.Dropout`` is
    put into eval mode and all of its parameters get ``requires_grad = False``.
    Other modules are left untouched.

    Args:
        model: an ``nn.Module`` whose sub-modules are visited via ``modules()``.
    """
    frozen_types = (nn.modules.batchnorm._BatchNorm, nn.Dropout)
    for module in model.modules():
        if not isinstance(module, frozen_types):
            continue

        module.eval()
        for param in module.parameters():
            param.requires_grad = False


def project_and_normalize_to_spin(vertices_3d, camera):
    """Shift, scale and rescale 3D points into SPIN image units.

    ``camera[0]`` is used as the scale and ``camera[1:]`` as the offset added
    to the first two coordinates; the third coordinate gets a zero offset.
    All three coordinates are kept in the result.

    Returns:
        ``IMG_RES / 2 * (scale * (vertices_3d + offset) + 1)``.
    """
    scale = camera[0]

    # Three-component offset whose last entry stays zero.
    offset = scale.new_zeros(3)
    offset[:2] = camera[1:]

    shifted = vertices_3d + offset
    normalized = scale * shifted + 1

    return spin.constants.IMG_RES / 2 * normalized


def project_and_normalize_to_spin_legs(vertices_3d, A, camera):
    """Compute 2D image-space vectors for the right and left leg bones.

    ``A`` is unpacked into two items and the first element of each is used:
    the first is indexed as per-joint transforms (``[joint, :3, :3]``), the
    second as per-joint positions -- presumably the SMPL transform chain and
    rest-pose joints; confirm against the body model. ``vertices_3d`` is only
    used to create the constant matrices with a matching dtype/device.

    Returns:
        ``(rleg, lleg)``: the first two components of the scaled, rotated
        bone vectors ``J[5] - J[2]`` (right) and ``J[4] - J[1]`` (left).
    """
    transforms, joint_positions = A
    transforms = transforms[0]
    joint_positions = joint_positions[0]

    # Fixed correction rotations applied after the joint rotations.
    left_fix = vertices_3d.new_tensor(
        [
            [0.98619063, 0.16560926, 0.00127302],
            [-0.16560601, 0.98603675, 0.01749799],
            [0.00164258, -0.01746717, 0.99984609],
        ]
    )
    right_fix = vertices_3d.new_tensor(
        [
            [0.9910211, -0.13368178, -0.0025208],
            [0.13367888, 0.99027076, 0.03864949],
            [-0.00267045, -0.03863944, 0.99924965],
        ]
    )

    pixel_scale = camera[0] * spin.constants.IMG_RES / 2

    right_rot = transforms[2, :3, :3] @ right_fix  # joint 2 - right
    left_rot = transforms[1, :3, :3] @ left_fix  # joint 1 - left
    right_bone = joint_positions[5] - joint_positions[2]
    left_bone = joint_positions[4] - joint_positions[1]

    right_leg = (pixel_scale * right_rot) @ right_bone
    left_leg = (pixel_scale * left_rot) @ left_bone

    return right_leg[:2], left_leg[:2]


def rotation_matrix_to_angle_axis(rotmat):
    """Convert per-joint rotation matrices to a flat axis-angle vector per sample.

    Args:
        rotmat: tensor whose first two dimensions are ``(batch, n_joints)``
            and which holds one 3x3 rotation matrix per joint.

    Returns:
        Tensor of shape ``(batch, 3 * n_joints)``.
    """
    bs, n_joints, *_ = rotmat.size()
    flat = rotmat.reshape(-1, 3, 3)

    # torchgeometry expects (N, 3, 4) matrices, so append a constant column to
    # every matrix. The column is expanded over all bs * n_joints matrices so
    # that batches larger than one work as well.
    pad_column = (
        flat.new_tensor([0.0, 0.0, 1.0]).view(1, 3, 1).expand(flat.size(0), -1, -1)
    )
    padded = torch.cat([flat, pad_column], dim=-1)

    aa = torchgeometry.rotation_matrix_to_angle_axis(padded)

    return aa.reshape(bs, 3 * n_joints)


def get_smpl_output(smpl, rotmat, betas, use_betas=True, zero_hands=False):
    """Run the body model on predicted rotations.

    For a model named ``"SMPL"`` the rotation matrices are passed through
    unchanged. For ``"SMPL-X"`` they are first converted to axis-angle; with
    ``zero_hands`` the three components of joints 20 and 21 and the second
    component of joints 12 and 15 (neck, head) are set to zero.

    Returns:
        ``(smpl_output, pose)`` where ``pose`` is the input ``rotmat`` for
        SMPL and the (possibly modified) axis-angle vector for SMPL-X.

    Raises:
        NotImplementedError: for any other model name.
    """
    model_name = smpl.name()
    shape = betas if use_betas else None

    if model_name == "SMPL":
        output = smpl(
            betas=shape,
            body_pose=rotmat[:, 1:],
            global_orient=rotmat[:, 0].unsqueeze(1),
            pose2rot=False,
        )
        return output, rotmat

    if model_name != "SMPL-X":
        raise NotImplementedError

    pose = rotation_matrix_to_angle_axis(rotmat)
    if zero_hands:
        for joint in (20, 21):
            pose[:, 3 * joint : 3 * (joint + 1)] = 0

        for joint in (12, 15):  # neck, head
            pose[:, 3 * joint + 1] = 0  # y

    output = smpl(
        betas=shape,
        body_pose=pose[:, 3:],
        global_orient=pose[:, :3],
        pose2rot=True,
    )

    return output, pose


def get_predictions(model_hmr, smpl, input_img, use_betas=True, zero_hands=False):
    """Run the regressor on one image and feed its output to the body model.

    A batch dimension is added to ``input_img`` before the forward pass and
    removed again from every returned tensor except ``smpl_output``.

    Returns:
        ``(pose, betas, camera, smpl_output, joints)``.
    """
    batch = input_img.unsqueeze(0)
    rotmat, betas, camera = model_hmr(batch)

    smpl_output, pose = get_smpl_output(
        smpl, rotmat, betas, use_betas=use_betas, zero_hands=zero_hands
    )

    return (
        pose.squeeze(0),
        betas.squeeze(0),
        camera.squeeze(0),
        smpl_output,
        smpl_output.joints.squeeze(0),
    )


def get_pred_and_data(
    model_hmr, smpl, selector, input_img, use_betas=True, zero_hands=False
):
    """Predict pose/shape/camera and derive the projected quantities used by the losses.

    Returns:
        A 9-tuple ``(rotmat, betas, camera, joints_2d, joints_3d, vertices_2d,
        smpl_output, (rleg, lleg), joints_2d_orig)``. ``joints_2d`` and
        ``joints_3d`` are restricted to ``selector``; ``joints_2d_orig`` holds
        all projected joints.
    """
    rotmat, betas, camera, smpl_output, joints_3d = get_predictions(
        model_hmr, smpl, input_img, use_betas=use_betas, zero_hands=zero_hands
    )

    all_joints = smpl_output.joints.squeeze(0)
    all_joints_2d = project_and_normalize_to_spin(all_joints, camera)
    right_leg, left_leg = project_and_normalize_to_spin_legs(
        all_joints, smpl_output.A, camera
    )

    all_vertices = smpl_output.vertices.squeeze(0)
    vertices_2d = project_and_normalize_to_spin(all_vertices, camera)

    selected_2d = all_joints_2d[selector]
    selected_3d = joints_3d[selector]

    return (
        rotmat,
        betas,
        camera,
        selected_2d,
        selected_3d,
        vertices_2d,
        smpl_output,
        (right_leg, left_leg),
        all_joints_2d,
    )


def normalize_keypoints_to_spin(keypoints_2d, img_size):
    """Map keypoints from an ``(h, w)`` image into SPIN's square resolution.

    One keypoint column is shifted by half the difference of the two image
    sides and everything is then scaled by ``IMG_RES / img_size[ax2]``. For
    ``h > w`` column 0 is shifted, otherwise column 1.

    Returns:
        ``(normalized_copy, shift, scale, ax2)``; the last three values are
        what ``unnormalize_keypoints_from_spin`` needs to undo the mapping.
    """
    h, w = img_size
    # (ax1, ax2): vertical images use (1, 0), horizontal/square ones (0, 1).
    ax1, ax2 = (1, 0) if h > w else (0, 1)

    shift = (img_size[ax1] - img_size[ax2]) / 2
    scale = spin.constants.IMG_RES / img_size[ax2]

    normalized = np.copy(keypoints_2d)
    normalized[:, ax2] -= shift
    normalized *= scale

    return normalized, shift, scale, ax2


def unnormalize_keypoints_from_spin(keypoints_2d, shift, scale, ax2):
    """Undo ``normalize_keypoints_to_spin``: divide by ``scale``, then add ``shift`` to column ``ax2``.

    The input is not modified; a copy is returned.
    """
    restored = np.copy(keypoints_2d)
    # In-place division on the copy, same as ``restored /= scale``.
    np.true_divide(restored, scale, out=restored)
    restored[:, ax2] += shift

    return restored


def get_vertices_in_heatmap(contact_heatmap):
    """Return one convex hull per connected component of ``contact_heatmap``.

    Components are found with ``skimage.measure.label``. The non-zero
    positions of each component are reversed into (last-axis, first-axis)
    order, normalised with ``normalize_keypoints_to_spin``, truncated to
    integers and wrapped in the convex hull of a shapely ``MultiPoint``.
    """
    heatmap_size = contact_heatmap.shape[:2]
    components = measure.label(contact_heatmap)

    def component_hull(component_id):
        positions = np.nonzero(components == component_id)
        points = np.vstack(positions[::-1]).T.astype("float")
        points_scaled, *_ = normalize_keypoints_to_spin(points, heatmap_size)
        int_points = torch.from_numpy(points_scaled).int().tolist()
        return shapely.geometry.MultiPoint(int_points).convex_hull

    # Label 0 is the background, so components start at 1.
    return [component_hull(i) for i in range(1, components.max() + 1)]


def get_contact_heatmap(model_contact, img_path, thresh=0.5):
    """Infer the contact heatmap for an image and derive its binary and display forms.

    Args:
        model_contact: model handed to ``pose_estimation.infer_single_image``.
        img_path: path of the image to run inference on.
        thresh: threshold applied to the min-max normalised heatmap.

    Returns:
        ``(binary_mask, heatmap_rgb, heatmap_orig)``: a uint8 mask that is 255
        where the normalised heatmap exceeds ``thresh``, the normalised
        heatmap repeated to three channels as uint8, and a copy of the raw
        model output.
    """
    raw = pose_estimation.infer_single_image(
        model_contact,
        img_path,
        input_img_size=(192, 256),
        return_kps=False,
    )
    raw = raw.squeeze(0)
    contact_heatmap_orig = raw.copy()

    lo = raw.min()
    hi = raw.max()
    value_range = hi - lo
    if value_range > 0:
        normalized = (raw - lo) / value_range
    else:
        # A constant heatmap has no range to normalise; treat it as "no
        # contact" instead of dividing by zero and propagating NaNs.
        normalized = np.zeros_like(raw)

    binary_mask = ((normalized > thresh) * 255).astype("uint8")

    heatmap_rgb = np.repeat(normalized[..., None], repeats=3, axis=-1)
    heatmap_rgb = (heatmap_rgb * 255).astype("uint8")

    return binary_mask, heatmap_rgb, contact_heatmap_orig


def discretize(parametrization, n_bins=100):
    """Snap values to the lower edge of their bin on a uniform grid over [0, 1].

    The grid has ``n_bins + 1`` edges. Values below 0 are snapped to 0 and
    values of 1 or above to 1.

    Args:
        parametrization: array-like of values, expected in ``[0, 1]``.
        n_bins: number of equal-width bins.

    Returns:
        Array of the same shape holding the bin edge assigned to each value.
    """
    bins = np.linspace(0, 1, n_bins + 1)
    inds = np.digitize(parametrization, bins)
    # digitize returns 0 for values below the first edge; without clamping,
    # ``bins[inds - 1]`` would wrap around to the last edge (1.0).
    inds = np.clip(inds, 1, n_bins + 1)

    return bins[inds - 1]


def get_mapping_from_params_to_verts(verts, params):
    """Group ``verts`` by their paired entry in ``params``.

    Returns:
        dict mapping each distinct parameter value to the list of vertices
        paired with it, in input order. Pairing stops at the shorter input.
    """
    grouped = {}
    for vert, param in zip(verts, params):
        if param in grouped:
            grouped[param].append(vert)
        else:
            grouped[param] = [vert]

    return grouped


def find_contacts(y_data_conts, keypoints_2d, bone_to_params, thresh=12, step=0.0072246375):
    """Find mesh vertices lying under each 2D contact region.

    For every region in ``y_data_conts`` the region is grown by ``thresh`` and
    intersected with each skeleton bone drawn between its two 2D keypoints.
    The intersected part of the bone is expressed as a normalised position
    range along the bone and matched against the discretised per-vertex bone
    parameters from ``bone_to_params``.

    Args:
        y_data_conts: iterable of shapely geometries (one per contact region).
        keypoints_2d: 2D keypoints indexed by the joint ids used in
            ``pose_estimation.SKELETON``.
        bone_to_params: mapping ``(i, j) -> (verts, t3d)``; presumably vertex
            ids and their normalised position along the bone -- confirm
            against the file that produces it.
        thresh: buffer distance applied to every region.
        step: discretisation step; ``ceil(1 / step) - 1`` bins are used.

    Returns:
        ``(contact, contact_2d, for_mask)``:
        ``contact`` -- per region, a sorted int array of unique vertex ids;
        ``contact_2d`` -- per region, a list of ``(i, j, t_mid)`` with the
        midpoint of the intersected range on bone ``(i, j)``;
        ``for_mask`` -- for regions with at least one match, the buffered
        outline as an int array of shape ``(n_points, 1, 2)``.
    """
    n_bins = int(math.ceil(1 / step)) - 1  # mean face's circumradius
    contact = []
    contact_2d = []
    for_mask = []
    for y_data_cont in y_data_conts:
        contact_loc = []
        contact_2d_loc = []
        # Grow the region so that bones passing close by still intersect it.
        buffer = y_data_cont.buffer(thresh)
        mask_add = False
        for i, j in pose_estimation.SKELETON:
            verts, t3d = bone_to_params[(i, j)]
            if len(verts) == 0:
                continue

            # Group the bone's vertices by discretised parameter, sorted by parameter.
            t3d = discretize(t3d, n_bins=n_bins)
            t3d_to_verts = get_mapping_from_params_to_verts(verts, t3d)
            t3d_to_verts_sorted = sorted(t3d_to_verts.items(), key=lambda x: x[0])
            t3d_sorted_np = np.array([x for x, _ in t3d_to_verts_sorted])

            line = shapely.geometry.LineString([keypoints_2d[i], keypoints_2d[j]])
            lint = buffer.intersection(line)
            # Fewer than two boundary points: no usable segment on this bone.
            if len(lint.boundary.geoms) < 2:
                continue

            # Normalised positions of the segment's two end points along the bone.
            t2d_start = line.project(lint.boundary.geoms[0], normalized=True)
            t2d_end = line.project(lint.boundary.geoms[1], normalized=True)
            assert t2d_start <= t2d_end

            t2ds = discretize(
                np.linspace(t2d_start, t2d_end, n_bins + 1), n_bins=n_bins
            )
            to_add = False
            for t2d in t2ds:
                # Skip positions outside the parameter range covered by the vertices.
                if t2d < t3d_sorted_np[0] or t2d > t3d_sorted_np[-1]:
                    continue

                t2d_ind = np.searchsorted(t3d_sorted_np, t2d)
                c = t3d_to_verts_sorted[t2d_ind][1]

                contact_loc.extend(c)
                to_add = True
                mask_add = True

                # Also take the vertices of the neighbouring parameter values.
                if t2d_ind + 1 < len(t3d_to_verts_sorted):
                    c = t3d_to_verts_sorted[t2d_ind + 1][1]
                    contact_loc.extend(c)

                if t2d_ind > 0:
                    c = t3d_to_verts_sorted[t2d_ind - 1][1]
                    contact_loc.extend(c)

            if to_add:
                # Record the midpoint of the intersected range on this bone.
                contact_2d_loc.append((i, j, t2d_start + 0.5 * (t2d_end - t2d_start)))

        if mask_add:
            for_mask.append(buffer.exterior.coords.xy)

        # De-duplicate and sort the vertex ids collected for this region.
        contact_loc = sorted(set(contact_loc))
        contact_loc = np.array(contact_loc, dtype="int")
        contact.append(contact_loc)
        contact_2d.append(contact_2d_loc)

    # (x, y) coordinate sequences -> int arrays of shape (n_points, 1, 2).
    for_mask = [np.stack((x, y), axis=0).T[:, None].astype("int") for x, y in for_mask]

    return contact, contact_2d, for_mask


def optimize(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse=None,
    loss_parallel=None,
    c_mse=0.0,
    c_new_mse=1.0,
    c_beta=1e-3,
    sc_crit=None,
    msc_crit=None,
    contact=None,
    n_steps=60,
    i_ini=0,
):
    """Run ``n_steps`` optimisation steps through ``model_hmr`` and return the final prediction.

    Every step re-runs ``get_pred_and_data``, sums the enabled loss terms,
    back-propagates and calls ``optimizer.step()``. The optimizer is supplied
    by the caller; presumably it holds the parameters of ``model_hmr`` --
    confirm at the call site.

    Args:
        model_hmr, smpl, selector, input_img: forwarded to ``get_pred_and_data``.
        keypoints_2d: target 2D keypoints passed to the losses.
        optimizer: optimizer stepped once per iteration.
        args: namespace providing ``c_f``, ``c_parallel``, ``c_reg``,
            ``c_cont2d`` and ``c_msc``.
        loss_mse: optional 2D keypoint loss, used when ``c_mse > 0``.
        loss_parallel: optional loss object returning seven terms, used when
            ``c_new_mse > 0``; optional attributes on it enable the foot and
            silhouette terms.
        c_mse, c_new_mse, c_beta: loss weights.
        sc_crit: optional criterion returning ``(contact_loss, faces_angle_loss)``.
        msc_crit: optional criterion called as ``msc_crit(contact_i, vertices)``.
        contact: contact vertex indices, or a list of such index sets.
        n_steps: number of iterations.
        i_ini: offset added to the iteration index to form ``global_step``.

    Returns:
        ``(rotmat_pred, betas_pred, camera_pred, keypoints_3d_pred,
        vertices_2d_pred, smpl_output, z, joints_2d_orig)`` from a final
        forward pass under ``torch.no_grad()`` with ``zero_hands=True``.
    """
    # Per-side reference value for the foot term, fixed the first time it is needed.
    mean_zfoot_val = {}
    with tqdm.trange(n_steps) as pbar:
        for i in pbar:
            global_step = i + i_ini
            optimizer.zero_grad()

            (
                rotmat_pred,
                betas_pred,
                camera_pred,
                keypoints_3d_pred,
                z,
                vertices_2d_pred,
                smpl_output,
                (rleg, lleg),
                joints_2d_orig,
            ) = get_pred_and_data(
                model_hmr,
                smpl,
                selector,
                input_img,
            )
            keypoints_2d_pred = keypoints_3d_pred[:, :2]

            # 2D keypoint term.
            loss = l2 = 0.0
            if c_mse > 0 and loss_mse is not None:
                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
                loss = loss + c_mse * l2

            vertices_pred = smpl_output.vertices

            # NOTE: z_loss is never updated below; it is always reported as 0.0.
            lpar = z_loss = loss_sh = 0.0
            if c_new_mse > 0 and loss_parallel is not None:
                Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(
                    keypoints_3d_pred,
                    keypoints_2d,
                    z,
                    (rleg, lleg),
                    global_step=global_step,
                )
                lpar = (
                    Ltan
                    + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)
                    + Lspine
                    + args.c_reg * Lgr
                    + args.c_reg * Lstraight3d
                    + args.c_cont2d * Lcon2d
                )
                loss = loss + 300 * lpar

                # Foot term: pull coordinate 1 of the listed foot vertices
                # towards the median value seen on the first step.
                for side in ["left", "right"]:
                    attr = f"{side}_foot_inds"
                    if hasattr(loss_parallel, attr):
                        foot_inds = getattr(loss_parallel, attr)
                        zind = 1
                        if attr not in mean_zfoot_val:
                            with torch.no_grad():
                                mean_zfoot_val[attr] = torch.median(
                                    vertices_pred[0, foot_inds, zind], dim=0
                                ).values

                        loss_foot = (
                            (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])
                            ** 2
                        ).sum()
                        loss = loss + args.c_reg * loss_foot

                # Silhouette term: pull coordinate 1 of the selected vertices
                # towards ``loss_parallel.ground``.
                if hasattr(loss_parallel, "silhuette_vertices_inds"):
                    inds = loss_parallel.silhuette_vertices_inds
                    loss_sh = (
                        (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2
                    ).sum()
                    loss = loss + args.c_reg * loss_sh

            # Shape regulariser and a penalty that grows as camera_pred[0] decreases.
            lbeta = (betas_pred**2).mean()
            lcam = ((torch.exp(-camera_pred[0] * 10)) ** 2).mean()
            loss = loss + c_beta * lbeta + lcam

            lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0
            if sc_crit is not None:
                gsc_contact_loss, faces_angle_loss = sc_crit(
                    vertices_pred,
                )
                lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss
                loss = loss + lgsc_a

            msc_loss = 0.0
            if contact is not None and len(contact) > 0 and msc_crit is not None:
                # A single index set is wrapped so that both forms are handled alike.
                if not isinstance(contact, list):
                    contact = [contact]

                # Every contact set contributes; msc_loss keeps only the last
                # value for the progress bar.
                for cntct in contact:
                    msc_loss = msc_crit(
                        cntct,
                        vertices_pred,
                    )
                    loss = loss + args.c_msc * msc_loss

            loss.backward()
            optimizer.step()

            epoch_loss = loss.item()
            pbar.set_postfix(
                **{
                    "l": f"{epoch_loss:.3}",
                    "l2": f"{l2:.3}",
                    "par": f"{lpar:.3}",
                    "beta": f"{lbeta:.3}",
                    "cam": f"{lcam:.3}",
                    "z": f"{z_loss:.3}",
                    "gsc_contact": f"{float(gsc_contact_loss):.3}",
                    "faces_angle": f"{float(faces_angle_loss):.3}",
                    "msc": f"{float(msc_loss):.3}",
                }
            )

    # Final forward pass without gradients; this one zeroes the hand/neck/head
    # components via zero_hands=True.
    with torch.no_grad():
        (
            rotmat_pred,
            betas_pred,
            camera_pred,
            keypoints_3d_pred,
            z,
            vertices_2d_pred,
            smpl_output,
            (rleg, lleg),
            joints_2d_orig,
        ) = get_pred_and_data(
            model_hmr,
            smpl,
            selector,
            input_img,
            zero_hands=True,
        )

    return (
        rotmat_pred,
        betas_pred,
        camera_pred,
        keypoints_3d_pred,
        vertices_2d_pred,
        smpl_output,
        z,
        joints_2d_orig,
    )


def optimize_ft(
    theta,
    camera,
    smpl,
    selector,
    keypoints_2d,
    args,
    loss_mse=None,
    loss_parallel=None,
    c_mse=0.0,
    c_new_mse=1.0,
    sc_crit=None,
    msc_crit=None,
    contact=None,
    n_steps=60,
    i_ini=0,
    zero_hands=False,
    fist=None,
):
    """Optimise the pose vector and camera directly, without the regressor.

    Detached copies of ``theta`` and ``camera`` become the parameters of a new
    Adam optimizer (lr 1e-3). A squared-distance prior keeps them close to
    their starting values; the remaining loss terms mirror the ones used in
    ``optimize``. Shape parameters are not passed to ``smpl`` here.

    Args:
        theta: flat pose vector; the first three entries are used as the
            global orientation and the rest as the body pose (``pose2rot=True``).
        camera: camera vector, indexed like in ``project_and_normalize_to_spin``.
        smpl, selector, keypoints_2d: body model, joint selector and 2D targets.
        args: namespace providing ``c_f``, ``c_parallel``, ``c_reg``,
            ``c_cont2d`` and ``c_msc``.
        loss_mse, loss_parallel, c_mse, c_new_mse, sc_crit, msc_crit, contact:
            optional loss terms and weights, as in ``optimize``.
        n_steps: number of iterations.
        i_ini: offset added to the iteration index to form ``global_step``.
        zero_hands: zero selected hand/neck/head components after optimisation.
        fist: optional iterable of keys of ``fist_pose.INT_TO_FIST``.

    Returns:
        ``(rotmat_pred, smpl_output)``: the detached, possibly edited pose
        vector and the body-model output computed from it.

    Raises:
        RuntimeError: if a ``fist`` key starts with neither ``"l"`` nor ``"r"``.
    """
    # Per-side reference value for the foot term, fixed the first time it is needed.
    mean_zfoot_val = {}

    # Detached copies serve both as the starting point and as the prior target.
    theta = theta.detach().clone()
    camera = camera.detach().clone()
    rotmat_pred = nn.Parameter(theta)
    camera_pred = nn.Parameter(camera)
    optimizer = torch.optim.Adam(
        [
            rotmat_pred,
            camera_pred,
        ],
        lr=1e-3,
    )
    global_step = i_ini

    with tqdm.trange(n_steps) as pbar:
        for i in pbar:
            global_step = i + i_ini
            optimizer.zero_grad()

            global_orient = rotmat_pred[:3]
            body_pose = rotmat_pred[3:]
            smpl_output = smpl(
                global_orient=global_orient.unsqueeze(0),
                body_pose=body_pose.unsqueeze(0),
                pose2rot=True,
            )

            z = smpl_output.joints
            z = z.squeeze(0)

            joints = smpl_output.joints.squeeze(0)
            joints_2d = project_and_normalize_to_spin(joints, camera_pred)
            rleg, lleg = project_and_normalize_to_spin_legs(
                joints, smpl_output.A, camera_pred
            )
            joints_2d = joints_2d[selector]
            z = z[selector]
            keypoints_3d_pred = joints_2d

            keypoints_2d_pred = keypoints_3d_pred[:, :2]

            # Prior: stay close to the initial pose and camera.
            lprior = ((rotmat_pred - theta) ** 2).sum() + (
                (camera_pred - camera) ** 2
            ).sum()
            loss = lprior

            l2 = 0.0
            if c_mse > 0 and loss_mse is not None:
                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
                loss = loss + c_mse * l2

            vertices_pred = smpl_output.vertices

            # NOTE: z_loss is never updated below; it is always reported as 0.0.
            lpar = z_loss = loss_sh = 0.0
            if c_new_mse > 0 and loss_parallel is not None:
                Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(
                    keypoints_3d_pred,
                    keypoints_2d,
                    z,
                    (rleg, lleg),
                    global_step=global_step,
                )
                lpar = (
                    Ltan
                    + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)
                    + Lspine
                    + args.c_reg * Lgr
                    + args.c_reg * Lstraight3d
                    + args.c_cont2d * Lcon2d
                )
                loss = loss + 300 * lpar

                # Foot term: pull coordinate 1 of the listed foot vertices
                # towards the median value seen on the first step.
                for side in ["left", "right"]:
                    attr = f"{side}_foot_inds"
                    if hasattr(loss_parallel, attr):
                        foot_inds = getattr(loss_parallel, attr)
                        zind = 1
                        if attr not in mean_zfoot_val:
                            with torch.no_grad():
                                mean_zfoot_val[attr] = torch.median(
                                    vertices_pred[0, foot_inds, zind], dim=0
                                ).values

                        loss_foot = (
                            (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])
                            ** 2
                        ).sum()
                        loss = loss + args.c_reg * loss_foot

                # Silhouette term: pull coordinate 1 of the selected vertices
                # towards ``loss_parallel.ground``.
                if hasattr(loss_parallel, "silhuette_vertices_inds"):
                    inds = loss_parallel.silhuette_vertices_inds
                    loss_sh = (
                        (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2
                    ).sum()
                    loss = loss + args.c_reg * loss_sh

            lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0
            if sc_crit is not None:
                gsc_contact_loss, faces_angle_loss = sc_crit(vertices_pred)
                lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss
                loss = loss + lgsc_a

            msc_loss = 0.0
            if contact is not None and len(contact) > 0 and msc_crit is not None:
                # A single index set is wrapped so that both forms are handled alike.
                if not isinstance(contact, list):
                    contact = [contact]

                # Every contact set contributes; msc_loss keeps only the last
                # value for the progress bar.
                for cntct in contact:
                    msc_loss = msc_crit(
                        cntct,
                        vertices_pred,
                    )
                    loss = loss + args.c_msc * msc_loss

            loss.backward()
            optimizer.step()

            epoch_loss = loss.item()
            pbar.set_postfix(
                **{
                    "l": f"{epoch_loss:.3}",
                    "l2": f"{l2:.3}",
                    "par": f"{lpar:.3}",
                    "z": f"{z_loss:.3}",
                    "gsc_contact": f"{float(gsc_contact_loss):.3}",
                    "faces_angle": f"{float(faces_angle_loss):.3}",
                    "msc": f"{float(msc_loss):.3}",
                }
            )

    # From here on the pose is edited in place, outside of autograd.
    rotmat_pred = rotmat_pred.detach()

    if zero_hands:
        # Zero all three components of joints 20 and 21.
        for i in [20, 21]:
            rotmat_pred[3 * i : 3 * (i + 1)] = 0

        for i in [12, 15]:  # neck, head
            rotmat_pred[3 * i + 1] = 0  # y

    # Both are slices of rotmat_pred, so writes to body_pose below also change
    # the returned rotmat_pred.
    global_orient = rotmat_pred[:3]
    body_pose = rotmat_pred[3:]
    left_hand_pose = None
    right_hand_pose = None
    if fist is not None:
        # Start from the relaxed hand poses, then apply each requested key.
        left_hand_pose = rotmat_pred.new_tensor(fist_pose.LEFT_RELAXED).unsqueeze(0)
        right_hand_pose = rotmat_pred.new_tensor(fist_pose.RIGHT_RELAXED).unsqueeze(0)
        for f in fist:
            pp = fist_pose.INT_TO_FIST[f]
            if pp is not None:
                pp = rotmat_pred.new_tensor(pp).unsqueeze(0)

            # "lf"/"rf" keys replace the hand pose; other "l"/"r" keys overwrite
            # body-pose entry 19/20 and reset that side's hand pose to None.
            # NOTE(review): the "l"/"r" branches assign pp even when it is
            # None -- presumably those keys never map to None; confirm in
            # fist_pose.INT_TO_FIST.
            if f.startswith("lf"):
                left_hand_pose = pp
            elif f.startswith("rf"):
                right_hand_pose = pp
            elif f.startswith("l"):
                body_pose[19 * 3 : 19 * 3 + 3] = pp
                left_hand_pose = None
            elif f.startswith("r"):
                body_pose[20 * 3 : 20 * 3 + 3] = pp
                right_hand_pose = None
            else:
                raise RuntimeError(f"No such hand pose: {f}")

    with torch.no_grad():
        smpl_output = smpl(
            global_orient=global_orient.unsqueeze(0),
            body_pose=body_pose.unsqueeze(0),
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            pose2rot=True,
        )

    return rotmat_pred, smpl_output


def create_bone(i, j, keypoints_2d):
    """Return the unit vector pointing from keypoint ``i`` to keypoint ``j``."""
    direction = keypoints_2d[j] - keypoints_2d[i]

    return torch.nn.functional.normalize(direction, dim=0)


def is_parallel_to_plane(bone, thresh=21):
    """Tell whether ``|bone[0]|`` exceeds the cosine of ``thresh`` degrees.

    For a unit-length ``bone`` this means its direction lies within
    ``thresh`` degrees of the first axis.
    """
    min_component = math.cos(math.radians(thresh))

    return abs(bone[0]) > min_component


def is_close_to_plane(bone, plane, thresh):
    """Tell whether ``bone[0]`` differs from ``plane`` by less than ``thresh``."""
    return abs(bone[0] - plane) < thresh


def get_selector():
    """Return, for every pose-estimation keypoint, the index of its SPIN joint.

    The keypoint names in ``pose_estimation.KPS`` are translated through
    ``PE_KSP_TO_SPIN`` and looked up in ``spin.JOINT_NAMES``.
    """
    return [
        spin.JOINT_NAMES.index(PE_KSP_TO_SPIN[name]) for name in pose_estimation.KPS
    ]


def calc_cos(joints_2d, joints_3d):
    """Compute one dot product per skeleton bone between its 2D and 3D directions.

    For every bone ``(i, j)`` in ``pose_estimation.SKELETON`` the normalised
    2D direction is multiplied with the first two components of the
    normalised 3D direction and summed.

    Returns:
        1-D tensor with one value per bone, in skeleton order.
    """

    def unit(vector):
        return nn.functional.normalize(vector, dim=0)

    per_bone = [
        (unit(joints_2d[i] - joints_2d[j]) * unit(joints_3d[i] - joints_3d[j])[:2]).sum()
        for i, j in pose_estimation.SKELETON
    ]

    return torch.stack(per_bone, dim=0)


def get_natural(keypoints_2d, vertices, right_foot_inds, left_foot_inds, loss_parallel, smpl):
    """Derive regularity constraints from the 2D pose and store them on ``loss_parallel``.

    Sets ``loss_parallel.ground_parallel`` and ``loss_parallel.parallel_in_3d``
    and, when the respective conditions hold, ``right_foot_inds`` /
    ``left_foot_inds`` and ``silhuette_vertices_inds`` / ``ground``.
    Nothing is returned.

    NOTE(review): the closeness test compares the first component of a
    unit-length bone direction with ``plane_2d`` (a keypoint coordinate) --
    confirm this is the intended quantity.
    """
    kps = pose_estimation.KPS

    first_coord_max = keypoints_2d.max(dim=0).values[0]
    first_coord_min = keypoints_2d.min(dim=0).values[0]
    height_2d = (first_coord_max - first_coord_min).item()
    plane_2d = first_coord_max.item()
    near_plane = 0.1 * height_2d
    straight_cos = math.cos(math.radians(21))

    # Three-joint limb chains whose two bones are compared with each other.
    limb_chains = [
        ("Right Upper Leg", "Right Leg", "Right Foot"),
        ("Right Leg", "Right Foot", "Right Toe"),  # to remove?
        ("Left Upper Leg", "Left Leg", "Left Foot"),
        ("Left Leg", "Left Foot", "Left Toe"),  # to remove?
        ("Right Shoulder", "Right Arm", "Right Hand"),
        ("Left Shoulder", "Left Arm", "Left Hand"),
    ]

    ground_parallel = []
    parallel_in_3d = []
    parallel3d_bones = set()
    for first_name, middle_name, last_name in limb_chains:
        a = kps.index(first_name)
        b = kps.index(middle_name)
        c = kps.index(last_name)
        upper = create_bone(a, b, keypoints_2d)
        lower = create_bone(b, c, keypoints_2d)

        if (
            is_parallel_to_plane(upper)
            and is_parallel_to_plane(lower)
            and (
                is_close_to_plane(upper, plane_2d, thresh=near_plane)
                or is_close_to_plane(lower, plane_2d, thresh=near_plane)
            )
        ):
            ground_parallel.append(((a, b), 1))
            ground_parallel.append(((b, c), 1))

        # Nearly collinear in 2D: constrain both bones to be parallel in 3D.
        if (upper * lower).sum() > straight_cos:
            parallel_in_3d.append(((a, b), (b, c)))
            parallel3d_bones.add((a, b))
            parallel3d_bones.add((b, c))

    # Feet that are not already part of a straight chain.
    foot_bones = [
        ("Right Foot", "Right Toe"),
        ("Left Foot", "Left Toe"),
    ]
    for foot_name, toe_name in foot_bones:
        a = kps.index(foot_name)
        b = kps.index(toe_name)
        if (a, b) in parallel3d_bones:
            continue

        if not is_parallel_to_plane(create_bone(a, b, keypoints_2d), thresh=25):
            continue

        if "Right" in kps[a]:
            loss_parallel.right_foot_inds = right_foot_inds
        else:
            loss_parallel.left_foot_inds = left_foot_inds

    loss_parallel.ground_parallel = ground_parallel
    loss_parallel.parallel_in_3d = parallel_in_3d

    vertices_np = vertices[0].cpu().numpy()
    if len(ground_parallel) == 0:
        return

    # Silhouette vertices: small third normal component and close to the
    # largest value of coordinate 1.
    mesh = trimesh.Trimesh(vertices=vertices_np, faces=smpl.faces, process=False)
    small_normal_z = np.abs(mesh.vertex_normals[..., 2]) < 2e-1

    coord_1 = vertices_np[:, 1]
    height_3d = coord_1.max() - coord_1.min()
    plane_3d = coord_1.max()
    near_ground = np.abs(coord_1 - plane_3d) < 0.15 * height_3d

    (silhuette_vertices_inds,) = np.where(np.logical_and(small_normal_z, near_ground))
    if len(silhuette_vertices_inds) > 0:
        loss_parallel.silhuette_vertices_inds = silhuette_vertices_inds
        loss_parallel.ground = plane_3d


def get_cos(keypoints_3d_pred, use_angle_transf, loss_parallel):
    """Compute per-bone target cosines and store them in ``loss_parallel.cos``.

    The raw cosines come from ``calc_cos`` (without gradients). With
    ``use_angle_transf`` the corresponding angles are shifted per bone group
    and passed through ``hist_cub.cub`` before the cosine is taken again.

    Returns:
        The raw cosines from ``calc_cos``.
    """
    keypoints_2d_pred = keypoints_3d_pred[:, :2]
    with torch.no_grad():
        cos_r = calc_cos(keypoints_2d_pred, keypoints_3d_pred)

    alpha = torch.acos(cos_r)
    if use_angle_transf:
        leg_bones = [5, 6, 7, 8]  # right leg (5, 6), left leg (7, 8)
        foot_bones = [15, 16]
        all_bones = set(range(len(pose_estimation.SKELETON)))
        other_bones = sorted(all_bones - set(leg_bones) - set(foot_bones))

        # Non-leg bones are shifted by their own minimum angle ...
        alpha[other_bones] = alpha[other_bones] - alpha[other_bones].min()

        # ... while legs and feet are shifted by the minimum over the legs only.
        leg_min = alpha[leg_bones].min()
        leg_and_foot_bones = leg_bones + foot_bones
        alpha[leg_and_foot_bones] = alpha[leg_and_foot_bones] - leg_min

        half_pi = math.pi / 2
        remapped = hist_cub.cub(
            alpha.detach().cpu().numpy() / half_pi,
            a=1.2121212121212122,
            b=-1.105527638190953,
            c=0.787878787878789,
        ) * half_pi
        alpha = alpha.new_tensor(remapped)

    loss_parallel.cos = torch.cos(alpha)

    return cos_r


def get_contacts(
    args,
    sc_module,
    y_data_conts,
    keypoints_2d,
    vertices,
    bone_to_params,
    loss_parallel,
):
    """Choose the contact vertices according to ``args.use_contacts`` / ``args.use_msc``.

    With ``use_contacts`` the contacts come from ``find_contacts`` (and the 2D
    contacts are stored on ``loss_parallel.contact_2d`` when non-empty),
    falling back to the self-contact module when nothing is returned. With
    only ``use_msc`` the self-contact module is used directly. Otherwise an
    empty array is returned.

    Raises:
        AssertionError: if ``use_contacts`` is set while ``args.c_mse`` is non-zero.
    """
    use_contacts = args.use_contacts
    use_msc = args.use_msc
    c_mse = args.c_mse

    def self_contact_indices():
        _, indices = sc_module.verts_in_contact(vertices, return_idx=True)
        return indices.cpu().numpy().ravel()

    if not use_contacts:
        return self_contact_indices() if use_msc else np.array([])

    assert c_mse == 0
    contact, contact_2d, _ = find_contacts(y_data_conts, keypoints_2d, bone_to_params)
    if len(contact_2d) > 0:
        loss_parallel.contact_2d = contact_2d

    if len(contact) == 0:
        contact = self_contact_indices()

    return contact


def save_mesh(
    smpl,
    smpl_output,
    save_path,
    fname,
):
    """Export the first mesh of ``smpl_output`` to ``save_path / "<fname>.glb"``.

    Before export the mesh is rotated by pi about the x axis and then by pi
    about the y axis.
    """
    vertices = smpl_output.vertices[0].cpu().numpy()
    mesh = trimesh.Trimesh(vertices=vertices, faces=smpl.faces, process=False)

    for axis in ([1, 0, 0], [0, 1, 0]):
        half_turn = trimesh.transformations.rotation_matrix(np.pi, axis)
        mesh.apply_transform(half_turn)

    mesh.export(save_path / f"{fname}.glb")


def eft_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse,
    loss_parallel,
    c_beta,
    sc_module,
    y_data_conts,
    bone_to_params,
):
    """First stage: fit with the MSE term only, then look up contacts.

    ``optimize`` is run for 60 + 90 steps with ``c_mse=1``, ``c_new_mse=0`` and
    every contact-related argument disabled. The resulting vertices are
    detached and handed to ``get_contacts``.

    Returns:
        ``(vertices, keypoints_3d_pred, contact)`` where ``vertices`` is the
        detached vertex tensor of the fitted model output.
    """
    n_fit_steps = 60 + 90

    fit_result = optimize(
        model_hmr,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        optimizer,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=1,
        c_new_mse=0,
        c_beta=c_beta,
        sc_crit=None,
        msc_crit=None,
        contact=None,
        n_steps=n_fit_steps,
    )
    # Only the 3D keypoints and the body-model output are needed here.
    _, _, _, keypoints_3d_pred, _, smpl_output, _, _ = fit_result

    # find contacts
    vertices = smpl_output.vertices.detach()
    contact = get_contacts(
        args, sc_module, y_data_conts, keypoints_2d, vertices, bone_to_params, loss_parallel
    )

    return vertices, keypoints_3d_pred, contact


def dc_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse,
    loss_parallel,
    c_mse,
    c_new_mse,
    c_beta,
    sc_crit,
    msc_crit,
    contact,
    use_contacts,
    use_msc,
):
    """Second stage: continue ``optimize`` with the caller-supplied weights.

    ``msc_crit`` and ``contact`` are forwarded only when ``use_contacts`` or
    ``use_msc`` is set. The stage runs 60 steps when ``c_new_mse > 0`` or a
    contact mode is enabled, and 0 steps otherwise; the step counter starts at
    60 + 90.

    Returns:
        The first element of the ``optimize`` result (``rotmat_pred``).
    """
    contact_enabled = use_contacts or use_msc
    stage_active = c_new_mse > 0 or contact_enabled

    result = optimize(
        model_hmr,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        optimizer,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=c_mse,
        c_new_mse=c_new_mse,
        c_beta=c_beta,
        sc_crit=sc_crit,
        msc_crit=msc_crit if contact_enabled else None,
        contact=contact if contact_enabled else None,
        n_steps=60 if stage_active else 0,
        i_ini=60 + 90,
    )
    rotmat_pred, *_ = result

    return rotmat_pred


def us_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    rotmat_pred,
    keypoints_2d,
    args,
    loss_mse,
    loss_parallel,
    c_mse,
    c_new_mse,
    sc_crit,
    msc_crit,
    contact,
    use_contacts,
    use_msc,
    save_path,
):
    """Final stage: refine ``rotmat_pred`` with ``optimize_ft`` and save the mesh.

    A prediction is first obtained from ``get_pred_and_data`` (``use_betas=False``,
    ``zero_hands=True``); only its camera is kept. ``optimize_ft`` then runs for
    60 steps when ``use_contacts`` or ``use_msc`` is set (0 otherwise), with the
    step counter starting at 60 + 90 + 60. The resulting mesh is written to
    ``save_path`` under the name ``us``.
    """
    prediction = get_pred_and_data(
        model_hmr,
        smpl,
        selector,
        input_img,
        use_betas=False,
        zero_hands=True,
    )
    # Third entry is the predicted camera; the rest is not needed here.
    _, _, camera_pred_us, _, _, _, _, _, _ = prediction

    contact_enabled = use_contacts or use_msc

    _, smpl_output_us = optimize_ft(
        rotmat_pred,
        camera_pred_us,
        smpl,
        selector,
        keypoints_2d,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=c_mse,
        c_new_mse=c_new_mse,
        sc_crit=sc_crit,
        msc_crit=msc_crit if contact_enabled else None,
        contact=contact if contact_enabled else None,
        n_steps=60 if contact_enabled else 0,
        i_ini=60 + 90 + 60,
        zero_hands=True,
        fist=args.fist,
    )

    save_mesh(smpl, smpl_output_us, save_path, "us")


def main():
    """Command-line entry point: fit a body mesh to each input image.

    Loads the pose-estimation and contact ONNX networks, the HMR regressor and
    the SMPL-X body model, then, for every ``.jpg``/``.jpeg``/``.png`` file
    given by ``--img-path`` (a single file or a directory), runs three stages:

    1. ``eft_step``: initial fit and contact lookup,
    2. ``dc_step``: refinement with the configured loss weights,
    3. ``us_step``: final refinement; writes ``us.glb`` into
       ``<save_path>/<image stem>/``.

    The regressor weights are reloaded from the checkpoint before each image,
    so images are fitted independently of one another.
    """
    args = parse_args()
    print(args)

    # models
    model_pose = cv2.dnn.readNetFromONNX(
        args.pose_estimation_model_path
    )  # "hrn_w48_384x288.onnx"
    model_contact = cv2.dnn.readNetFromONNX(
        args.contact_model_path
    )  # "contact_hrn_w32_256x192.onnx"

    # Fall back to CPU whenever CUDA is unavailable, whatever --device says.
    device = (
        torch.device(args.device) if torch.cuda.is_available() else torch.device("cpu")
    )
    model_hmr = spin.hmr(args.smpl_mean_params_path)  # "smpl_mean_params.npz"
    model_hmr.to(device)
    checkpoint = torch.load(
        args.spin_model_path,  # "spin_model_smplx_eft_18.pt"
        map_location="cpu",
    )

    smpl = spin.SMPLX(
        args.smpl_model_dir,  # "models/smplx"
        batch_size=1,
        create_transl=False,
        use_pca=False,
        flat_hand_mean=args.fist is not None,
    )
    smpl.to(device)

    selector = get_selector()

    use_contacts = args.use_contacts
    use_msc = args.use_msc

    bone_to_params = np.load(args.bone_parametrization_path, allow_pickle=True).item()
    foot_inds = np.load(args.foot_inds_path, allow_pickle=True).item()
    left_foot_inds = foot_inds["left_foot_inds"]
    right_foot_inds = foot_inds["right_foot_inds"]

    # The self-contact module and its losses are needed by both contact modes:
    # get_contacts() calls sc_module.verts_in_contact() when only use_msc is
    # set, and dc_step()/us_step() forward msc_crit in that case as well.
    if use_contacts or use_msc:
        model_type = args.smpl_type
        sc_module = selfcontact.SelfContact(
            essentials_folder=args.essentials_dir,  # "smplify-xmc-essentials"
            geothres=0.3,
            euclthres=0.02,
            test_segments=True,
            compute_hd=True,
            model_type=model_type,
            device=device,
        )
        sc_module.to(device)

        sc_crit = selfcontact.losses.SelfContactLoss(
            contact_module=sc_module,
            inside_loss_weight=0.5,
            outside_loss_weight=0.0,
            contact_loss_weight=0.5,
            align_faces=True,
            use_hd=True,
            test_segments=True,
            device=device,
            model_type=model_type,
        )
        sc_crit.to(device)

        msc_crit = losses.MimickedSelfContactLoss(geodesics_mask=sc_module.geomask)
        msc_crit.to(device)
    else:
        sc_module = None
        sc_crit = None
        msc_crit = None

    loss_mse = losses.MSE([1, 10, 13])  # Neck + Right Upper Leg + Left Upper Leg

    ignore = (
        (1, 2),  # Neck + Right Shoulder
        (1, 5),  # Neck + Left Shoulder
        (9, 10),  # Hips + Right Upper Leg
        (9, 13),  # Hips + Left Upper Leg
    )
    # NOTE(review): this object is shared by all images and get_contacts() only
    # ever overwrites `contact_2d` when contacts are found, so a value from a
    # previous image may persist — confirm against losses.Parallel.
    loss_parallel = losses.Parallel(
        skeleton=pose_estimation.SKELETON,
        ignore=ignore,
    )

    c_mse = args.c_mse
    c_new_mse = args.c_par
    c_beta = 1e-3

    # Exactly one of the two keypoint terms is expected to be active.
    if c_mse > 0:
        assert c_new_mse == 0
    elif c_mse == 0:
        assert c_new_mse > 0

    root_path = Path(args.save_path)
    root_path.mkdir(exist_ok=True, parents=True)

    path_to_imgs = Path(args.img_path)
    if path_to_imgs.is_dir():
        # iterdir() yields entries in arbitrary order; sort for reproducible runs.
        path_to_imgs = sorted(path_to_imgs.iterdir())
    else:
        path_to_imgs = [path_to_imgs]

    image_extensions = (".jpg", ".png", ".jpeg")

    for img_path in path_to_imgs:
        if not img_path.name.lower().endswith(image_extensions):
            continue

        img_name = img_path.stem

        # use 2d keypoints detection
        (
            img_original,
            predicted_keypoints_2d,
            _,
            _,
        ) = pose_estimation.infer_single_image(
            model_pose,
            img_path,
            input_img_size=pose_estimation.IMG_SIZE,
            return_kps=True,
        )

        save_path = root_path / img_name
        save_path.mkdir(exist_ok=True, parents=True)

        img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)
        img_size_original = img_original.shape[:2]
        keypoints_2d, *_ = normalize_keypoints_to_spin(
            predicted_keypoints_2d, img_size_original
        )
        keypoints_2d = torch.from_numpy(keypoints_2d)
        keypoints_2d = keypoints_2d.to(device)

        # Only the first returned heatmap is consumed below.
        predicted_contact_heatmap, _, _ = get_contact_heatmap(model_contact, img_path)
        y_data_conts = get_vertices_in_heatmap(predicted_contact_heatmap)

        # Start every image from the same pretrained weights.
        model_hmr.load_state_dict(checkpoint["model"], strict=True)
        model_hmr.train()
        freeze_layers(model_hmr)

        _, input_img = spin.process_image(img_path, input_res=spin.constants.IMG_RES)
        input_img = input_img.to(device)

        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model_hmr.parameters()),
            lr=1e-6,
        )

        vertices, keypoints_3d_pred, contact = eft_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            keypoints_2d,
            optimizer,
            args,
            loss_mse,
            loss_parallel,
            c_beta,
            sc_module,
            y_data_conts,
            bone_to_params,
        )

        if args.use_natural:
            get_natural(
                keypoints_2d, vertices, right_foot_inds, left_foot_inds, loss_parallel, smpl,
            )

        if args.use_cos:
            get_cos(keypoints_3d_pred, args.use_angle_transf, loss_parallel)

        rotmat_pred = dc_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            keypoints_2d,
            optimizer,
            args,
            loss_mse,
            loss_parallel,
            c_mse,
            c_new_mse,
            c_beta,
            sc_crit,
            msc_crit,
            contact,
            use_contacts,
            use_msc,
        )

        us_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            rotmat_pred,
            keypoints_2d,
            args,
            loss_mse,
            loss_parallel,
            c_mse,
            c_new_mse,
            sc_crit,
            msc_crit,
            contact,
            use_contacts,
            use_msc,
            save_path,
        )


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()