import numpy as np
import torch
from smplx import SMPL as _SMPL
from smplx import SMPLX as _SMPLX
from smplx.body_models import SMPLOutput, SMPLXOutput
from smplx.lbs import vertices2joints

from .constants import JOINT_MAP, JOINT_NAMES

# Hand joint reordering: wrist first, then thumb through pinky.
# Source positions index a 21-entry hand list laid out as
# [wrist, index1-3, middle1-3, pinky1-3, ring1-3, thumb1-3, five fingertips].
SMPLX_HAND_TO_PANOPTIC = [
    0,  # wrist
    13, 14, 15, 16,  # thumb (three joints + tip)
    1, 2, 3, 17,  # index
    4, 5, 6, 18,  # middle
    10, 11, 12, 19,  # ring
    7, 8, 9, 20,  # pinky
]



class SMPL(_SMPL):
    """SMPL body model that additionally returns regressed extra joints.

    On top of the parent model, ``forward`` regresses extra joints from the
    output vertices with ``J_regressor_extra``, appends them to the parent's
    joints and reorders the result through ``joint_map``.
    """

    # Human-readable joint names, one per skeleton node.
    JOINTS = (
        "Hips", "Left Upper Leg", "Right Upper Leg",
        "Spine", "Left Leg", "Right Leg",
        "Spine1", "Left Foot", "Right Foot",
        "Thorax", "Left Toe", "Right Toe",
        "Neck", "Left Shoulder", "Right Shoulder",
        "Head", "Left ForeArm", "Right ForeArm",
        "Left Arm", "Right Arm",
        "Left Hand", "Right Hand",
        "Left Finger", "Right Finger",
    )

    # Bone list; presumably (parent, child) index pairs into JOINTS.
    SKELETON = (
        (0, 1), (0, 2), (0, 3),
        (1, 4), (2, 5), (3, 6),
        (4, 7), (5, 8), (6, 9),
        (7, 10), (8, 11), (9, 12),
        (12, 13), (12, 14), (12, 15),
        (13, 16), (14, 17),
        (16, 18), (17, 19),
        (18, 20), (19, 21),
        (20, 22), (21, 23),
    )

    def __init__(self, *args, **kwargs):
        """Build the parent model and load the extra joint regressor.

        All arguments are forwarded unchanged to the parent constructor.
        The keyword ``joint_regressor_extra_path`` is required: it is the
        path of a NumPy file holding the extra joint regressor.

        Raises:
            KeyError: if ``joint_regressor_extra_path`` is not given, or if a
                name in ``JOINT_NAMES`` is missing from ``JOINT_MAP``.
        """
        super().__init__(*args, **kwargs)
        mapped_indices = [JOINT_MAP[name] for name in JOINT_NAMES]
        regressor_path = kwargs["joint_regressor_extra_path"]
        regressor_weights = np.load(regressor_path)
        # Stored as a buffer so it follows the module across devices.
        self.register_buffer(
            "J_regressor_extra",
            torch.tensor(regressor_weights, dtype=torch.float32),
        )
        # Plain attribute (not a buffer): index tensor used to pick joints.
        self.joint_map = torch.tensor(mapped_indices, dtype=torch.long)

    def forward(self, *args, **kwargs):
        """Run the parent model and return an ``SMPLOutput`` with remapped joints.

        ``get_skin`` is always forced to ``True`` because the extra joints
        are regressed from the output vertices.
        """
        kwargs["get_skin"] = True
        base_output = super().forward(*args, **kwargs)

        # Extra joints regressed from the mesh (9 according to the original
        # note; see doc/J_regressor_extra.png).
        regressed_joints = vertices2joints(
            self.J_regressor_extra, base_output.vertices
        )
        # Parent joints followed by the regressed ones, then reordered.
        combined_joints = torch.cat([base_output.joints, regressed_joints], dim=1)
        selected_joints = combined_joints[:, self.joint_map, :]

        return SMPLOutput(
            vertices=base_output.vertices,
            global_orient=base_output.global_orient,
            body_pose=base_output.body_pose,
            joints=selected_joints,
            betas=base_output.betas,
            full_pose=base_output.full_pose,
        )


class SMPLX(_SMPLX):
    """Extension of the official SMPL-X implementation to support more joints.

    ``forward`` converts the SMPL-X joints to a 45-joint SMPL-style layout,
    reorders them through ``joint_map``, and additionally extracts 21 joints
    per hand reordered with ``SMPLX_HAND_TO_PANOPTIC``.
    """

    # Human-readable joint names, one per skeleton node.
    JOINTS = (
        "Hips",
        "Left Upper Leg",
        "Right Upper Leg",
        "Spine",
        "Left Leg",
        "Right Leg",
        "Spine1",
        "Left Foot",
        "Right Foot",
        "Thorax",
        "Left Toe",
        "Right Toe",
        "Neck",
        "Left Shoulder",
        "Right Shoulder",
        "Head",
        "Left ForeArm",
        "Right ForeArm",
        "Left Arm",
        "Right Arm",
        "Left Hand",
        "Right Hand",
    )

    # Bone list; presumably (parent, child) index pairs into JOINTS.
    SKELETON = (
        (0, 1),
        (0, 2),
        (0, 3),
        (1, 4),
        (2, 5),
        (3, 6),
        (4, 7),
        (5, 8),
        (6, 9),
        (7, 10),
        (8, 11),
        (9, 12),
        (12, 13),
        (12, 14),
        (12, 15),
        (13, 16),
        (14, 17),
        (16, 18),
        (17, 19),
        (18, 20),
        (19, 21),
    )

    def __init__(self, *args, **kwargs):
        """Build the parent model, always loading it from a ``pkl`` file.

        All arguments are forwarded to the parent constructor; any ``ext``
        keyword supplied by the caller is overridden with ``"pkl"``.
        """
        kwargs["ext"] = "pkl"  # We have pkl file
        super(SMPLX, self).__init__(*args, **kwargs)
        joints = [JOINT_MAP[i] for i in JOINT_NAMES]
        self.joint_map = torch.tensor(joints, dtype=torch.long)

    def forward(self, *args, **kwargs):
        """Run SMPL-X and return an ``SMPLXOutput`` with remapped joints.

        ``get_skin`` is always forced to ``True``. If a keyword ``body_pose``
        is given in SMPL layout (second dimension of 69 or 23), the last two
        joints are dropped before it is passed on. A ``body_pose`` passed
        positionally, set to ``None`` or without a second dimension is
        forwarded untouched.

        NOTE: in the returned object, ``right_hand_pose`` and
        ``left_hand_pose`` carry hand joint positions of shape (N, 21, 3),
        not hand pose parameters.
        """
        kwargs["get_skin"] = True

        # If the pose parameter is for SMPL (23 joints, ignoring root), drop
        # the last two joints (which are on the palm and not used here).
        # Explicit guards replace a former bare ``except: pass`` so that
        # unrelated errors are no longer silently swallowed.
        body_pose = kwargs.get("body_pose")
        pose_shape = getattr(body_pose, "shape", None)
        if pose_shape is not None and len(pose_shape) >= 2:
            if pose_shape[1] == 69:
                # Flattened layout: 3 values per joint.
                kwargs["body_pose"] = body_pose[:, : -2 * 3]
            elif pose_shape[1] == 23:
                # One entry per joint along dimension 1.
                kwargs["body_pose"] = body_pose[:, :-2]

        smpl_output = super(SMPLX, self).forward(*args, **kwargs)

        # SMPL-X Joint order: https://docs.google.com/spreadsheets/d/1_1dLdaX-sbMkCKr_JzJW_RZCpwBwd7rcKkWT_VgAQ_0/edit#gid=0
        smplx_to_smpl = (
            list(range(0, 22)) + [28, 43] + list(range(55, 76))
        )  # 28: left middle finger 1, 43: right middle finger 1
        # Convert SMPL-X joints to the 45-joint SMPL-style layout.
        smpl_joints = smpl_output.joints[:, smplx_to_smpl, :]
        joints = smpl_joints[:, self.joint_map, :]

        smplx_lhand = (
            [20] + list(range(25, 40)) + list(range(66, 71))
        )  # 20 for left wrist, then 20 finger joints
        lhand_joints = smpl_output.joints[:, smplx_lhand, :]  # (N,21,3)
        # Convert SMPL-X hand order to Panoptic hand order.
        lhand_joints = lhand_joints[:, SMPLX_HAND_TO_PANOPTIC, :]

        smplx_rhand = (
            [21] + list(range(40, 55)) + list(range(71, 76))
        )  # 21 for right wrist, then 20 finger joints
        rhand_joints = smpl_output.joints[:, smplx_rhand, :]  # (N,21,3)
        # Convert SMPL-X hand order to Panoptic hand order.
        rhand_joints = rhand_joints[:, SMPLX_HAND_TO_PANOPTIC, :]

        output = SMPLXOutput(
            vertices=smpl_output.vertices,
            global_orient=smpl_output.global_orient,
            body_pose=smpl_output.body_pose,
            joints=joints,
            right_hand_pose=rhand_joints,  # N,21,3 joint positions
            left_hand_pose=lhand_joints,  # N,21,3 joint positions
            betas=smpl_output.betas,
            full_pose=smpl_output.full_pose,
            A=smpl_output.A,
        )
        return output


"""
0	pelvis',
1	left_hip',
2	right_hip',
3	spine1',
4	left_knee',
5	right_knee',
6	spine2',
7	left_ankle',
8	right_ankle',
9	spine3',
10	left_foot',
11	right_foot',
12	neck',
13	left_collar',
14	right_collar',
15	head',
16	left_shoulder',
17	right_shoulder',
18	left_elbow',
19	right_elbow',
20	left_wrist',
21	right_wrist',
22	jaw',
23	left_eye_smplhf',
24	right_eye_smplhf',
25	left_index1',
26	left_index2',
27	left_index3',
28	left_middle1',
29	left_middle2',
30	left_middle3',
31	left_pinky1',
32	left_pinky2',
33	left_pinky3',
34	left_ring1',
35	left_ring2',
36	left_ring3',
37	left_thumb1',
38	left_thumb2',
39	left_thumb3',
40	right_index1',
41	right_index2',
42	right_index3',
43	right_middle1',
44	right_middle2',
45	right_middle3',
46	right_pinky1',
47	right_pinky2',
48	right_pinky3',
49	right_ring1',
50	right_ring2',
51	right_ring3',
52	right_thumb1',
53	right_thumb2',
54	right_thumb3',
55	nose',
56	right_eye',
57	left_eye',
58	right_ear',
59	left_ear',
60	left_big_toe',
61	left_small_toe',
62	left_heel',
63	right_big_toe',
64	right_small_toe',
65	right_heel',
66	left_thumb',
67	left_index',
68	left_middle',
69	left_ring',
70	left_pinky',
71	right_thumb',
72	right_index',
73	right_middle',
74	right_ring',
75	right_pinky',
76	right_eye_brow1',
77	right_eye_brow2',
78	right_eye_brow3',
79	right_eye_brow4',
80	right_eye_brow5',
81	left_eye_brow5',
82	left_eye_brow4',
83	left_eye_brow3',
84	left_eye_brow2',
85	left_eye_brow1',
86	nose1',
87	nose2',
88	nose3',
89	nose4',
90	right_nose_2',
91	right_nose_1',
92	nose_middle',
93	left_nose_1',
94	left_nose_2',
95	right_eye1',
96	right_eye2',
97	right_eye3',
98	right_eye4',
99	right_eye5',
100	right_eye6',
101	left_eye4',
102	left_eye3',
103	left_eye2',
104	left_eye1',
105	left_eye6',
106	left_eye5',
107	right_mouth_1',
108	right_mouth_2',
109	right_mouth_3',
110	mouth_top',
111	left_mouth_3',
112	left_mouth_2',
113	left_mouth_1',
114	left_mouth_5', # 59 in OpenPose output
115	left_mouth_4', # 58 in OpenPose output
116	mouth_bottom',
117	right_mouth_4',
118	right_mouth_5',
119	right_lip_1',
120	right_lip_2',
121	lip_top',
122	left_lip_2',
123	left_lip_1',
124	left_lip_3',
125	lip_bottom',
126	right_lip_3',
127	right_contour_1',
128	right_contour_2',
129	right_contour_3',
130	right_contour_4',
131	right_contour_5',
132	right_contour_6',
133	right_contour_7',
134	right_contour_8',
135	contour_middle',
136	left_contour_8',
137	left_contour_7',
138	left_contour_6',
139	left_contour_5',
140	left_contour_4',
141	left_contour_3',
142	left_contour_2',
143	left_contour_1'
"""


# SMPL joints (45-entry layout; SMPLX.forward builds it from SMPL-X joints via smplx_to_smpl):
"""
0	pelvis',
1	left_hip',
2	right_hip',
3	spine1',
4	left_knee',
5	right_knee',
6	spine2',
7	left_ankle',
8	right_ankle',
9	spine3',
10	left_foot',
11	right_foot',
12	neck',
13	left_collar',
14	right_collar',
15	head',
16	left_shoulder',
17	right_shoulder',
18	left_elbow',
19	right_elbow',
20	left_wrist',
21	right_wrist',
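22	(SMPLX.forward fills this slot with SMPL-X joint 28, left_middle1)
23	(SMPLX.forward fills this slot with SMPL-X joint 43, right_middle1)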
24	nose',
25	right_eye',
26	left_eye',
27	right_ear',
28	left_ear',
29	left_big_toe',
30	left_small_toe',
31	left_heel',
32	right_big_toe',
33	right_small_toe',
34	right_heel',
35	left_thumb',
36	left_index',
37	left_middle',
38	left_ring',
39	left_pinky',
40	right_thumb',
41	right_index',
42	right_middle',
43	right_ring',
44	right_pinky',
"""