import pickle
import sys
import os

sys.path.append(os.getcwd())

import json
from glob import glob
from data_utils.utils import *
import torch.utils.data as data
from data_utils.consts import speaker_id
from data_utils.lower_body import count_part
import random
from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d

# Load the per-hand component matrices consumed by `to3d`.  The path is
# relative, so the process must be started from the repository root (the same
# assumption the `sys.path.append(os.getcwd())` line above makes).
# Presumably these are PCA components of the hand pose space -- confirm.
with open('data_utils/hand_component.json') as file_obj:
    comp = json.load(file_obj)

# The parsed dict is all that is needed from here on, so the arrays are built
# after the file has been closed.
left_hand_c = np.asarray(comp['left'])
right_hand_c = np.asarray(comp['right'])


def to3d(data):
    '''
    Expand the 12 hand coefficients per hand through the hand component matrices.

    `data` is indexed as a 2-D array (rows x columns).  Columns 0-74 are
    copied through unchanged; columns 75-86 (left hand) and 87-98 (right hand)
    are each multiplied by the first 12 rows of `left_hand_c` /
    `right_hand_c`.  The three parts are concatenated along the last axis
    and returned as a new array.
    '''
    untouched = data[:, :75]
    left_coeffs = data[:, 75:87]
    right_coeffs = data[:, 87:99]

    # (rows, 12) x (12, k) -> (rows, k) for each hand.
    left_hand_pose = np.einsum('bi,ij->bj', left_coeffs, left_hand_c[:12, :])
    right_hand_pose = np.einsum('bi,ij->bj', right_coeffs, right_hand_c[:12, :])

    return np.concatenate((untouched, left_hand_pose, right_hand_pose), axis=-1)


class SmplxDataset():
    '''
    Create a dataset for one motion/audio segment (one instance per segment;
    presumably the caller concatenates the per-segment datasets -- confirm).

    Depending on ``config.dataset_load_mode`` the constructor either reads a
    pickled archive at ``data_root`` ('csv') or lists the ``*pkl`` files found
    under ``data_root`` ('json').  With ``load_all=True`` the pose sequence and
    the audio features are loaded eagerly into ``self.complete_data`` and
    ``self.audio_feat``.  ``get_dataset`` then builds a
    ``torch.utils.data.Dataset`` and stores it in ``self.all_dataset``.
    '''

    def __init__(self,
                 data_root,
                 speaker,
                 motion_fn,
                 audio_fn,
                 audio_sr,
                 fps,
                 feat_method='mel_spec',
                 audio_feat_dim=64,
                 audio_feat_win_size=None,

                 train=True,
                 load_all=False,
                 split_trans_zero=False,
                 limbscaling=False,
                 num_frames=25,
                 num_pre_frames=25,
                 num_generate_length=25,
                 context_info=False,
                 convert_to_6d=False,
                 expression=False,
                 config=None,
                 am=None,
                 am_sr=None,
                 whole_video=False
                 ):
        '''
        Store the configuration and, when ``load_all`` is set, load the data.

        data_root: pickle file path ('csv' mode) or directory holding
            ``*pkl`` files ('json' mode).
        speaker: key into ``speaker_id``; emitted with every sample.
        motion_fn: pickle file with the motion data (used in 'json' mode).
        audio_fn: audio file handed to the feature extractors.
        audio_sr, fps, feat_method, audio_feat_dim, audio_feat_win_size:
            forwarded to the audio feature extractors.
        num_pre_frames, num_generate_length: their sum is the length of a
            training window.
        context_info: when true, samples carry ``num_pre_frames`` extra audio
            frames.
        convert_to_6d: convert the axis-angle poses to the 6-D rotation
            representation.
        expression: append the expression coefficients to each frame.
        config: must expose ``dataset_load_mode`` (and
            ``Model.encoder_choice`` in 'json' mode); ``None`` is not usable.
        am, am_sr: forwarded to ``get_mfcc_ta`` in 'json' mode.
        whole_video: always emit the whole sequence as a single sample.

        Raises NotImplementedError for load mode 'pickle' and
        FileNotFoundError when 'json' mode finds no ``*pkl`` file.
        '''

        self.data_root = data_root
        self.speaker = speaker

        self.feat_method = feat_method
        self.audio_fn = audio_fn
        self.audio_sr = audio_sr
        self.fps = fps
        self.audio_feat_dim = audio_feat_dim
        self.audio_feat_win_size = audio_feat_win_size
        self.context_info = context_info  # for aud feat
        self.convert_to_6d = convert_to_6d
        self.expression = expression

        self.train = train
        self.load_all = load_all
        self.split_trans_zero = split_trans_zero
        self.limbscaling = limbscaling
        self.num_frames = num_frames
        self.num_pre_frames = num_pre_frames
        self.num_generate_length = num_generate_length

        self.config = config
        self.am_sr = am_sr
        self.whole_video = whole_video
        load_mode = self.config.dataset_load_mode

        if load_mode == 'pickle':
            raise NotImplementedError

        elif load_mode == 'csv':
            # NOTE: unpickling executes arbitrary code from the file; only use
            # this with trusted data.
            with open(data_root, 'rb') as f:
                data = pickle.load(f)
            self.data = data[0]
            if self.load_all:
                self._load_npz_all()

        elif load_mode == 'json':
            self.annotations = sorted(glob(data_root + '/*pkl'))
            if not self.annotations:
                raise FileNotFoundError(data_root + ' are empty')
            self.img_name_list = self.annotations

            if self.load_all:
                self._load_them_all(am, am_sr, motion_fn)

    def _load_npz_all(self):
        '''
        Build the per-frame pose vectors from ``self.data`` and load the audio
        features.

        Fills ``self.loaded_data`` (frame name -> flat vector),
        ``self.complete_data`` (frames x features), ``self.img_name_list``,
        ``self.betas`` and ``self.audio_feat``.
        '''
        self.loaded_data = {}
        self.complete_data = []
        data = self.data
        shape = data['body_pose_axis'].shape[0]
        self.betas = data['betas']
        self.img_name_list = []
        for index in range(shape):
            img_name = f'{index:6d}'
            self.img_name_list.append(img_name)

            full_body = np.concatenate((
                data['jaw_pose'][index],
                data['leye_pose'][index],
                data['reye_pose'][index],
                data['global_orient'][index],
                data['body_pose_axis'][index],
                data['left_hand_pose'][index],
                data['right_hand_pose'][index],
            ))
            assert full_body.shape[0] == 99

            # `to3d` indexes its input as (frames, 99); a single frame is a
            # flat 99-vector, so give it a leading axis first.
            full_body = to3d(full_body[np.newaxis, :])
            if self.convert_to_6d:
                # Same conversion as in `_load_them_all`: 55 joints x 3
                # axis-angle values -> 55 x 6 = 330 values per frame.
                full_body = torch.from_numpy(full_body)
                full_body = matrix_to_rotation_6d(
                    axis_angle_to_matrix(full_body.reshape(-1, 55, 3))).reshape(-1, 330)
                full_body = np.asarray(full_body)
            full_body = full_body.reshape(-1)

            if self.expression:
                expression = data['expression'][index]
                full_body = np.concatenate((full_body, expression))

            self.loaded_data[img_name] = full_body
            self.complete_data.append(full_body)

        self.complete_data = np.array(self.complete_data)

        if self.audio_feat_win_size is not None:
            self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
        else:
            # NOTE(review): any other `feat_method` leaves `self.audio_feat`
            # unset -- confirm that only these two values are expected here.
            if self.feat_method == 'mel_spec':
                self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
            elif self.feat_method == 'mfcc':
                self.audio_feat = get_mfcc(self.audio_fn,
                                           smlpx=True,
                                           sr=self.audio_sr,
                                           n_mfcc=self.audio_feat_dim,
                                           win_size=self.audio_feat_win_size
                                           )

    def _load_them_all(self, am, am_sr, motion_fn):
        '''
        Load the whole motion sequence from the pickle ``motion_fn`` and the
        matching audio features.

        Fills ``self.complete_data`` (frames x features), ``self.betas`` and
        ``self.audio_feat``; ``self.loaded_data`` is reset to an empty dict.
        ``am`` and ``am_sr`` are forwarded to ``get_mfcc_ta``.
        '''
        self.loaded_data = {}
        self.complete_data = []
        # Read-only access is enough, and the context manager closes the
        # handle.  NOTE: unpickling executes arbitrary code from the file;
        # only use this with trusted data.
        with open(motion_fn, 'rb') as f:
            data = pickle.load(f)

        self.betas = np.array(data['betas'])

        jaw_pose = np.array(data['jaw_pose'])
        leye_pose = np.array(data['leye_pose'])
        reye_pose = np.array(data['reye_pose'])
        global_orient = np.array(data['global_orient']).squeeze()
        body_pose = np.array(data['body_pose_axis'])
        left_hand_pose = np.array(data['left_hand_pose'])
        right_hand_pose = np.array(data['right_hand_pose'])

        full_body = np.concatenate(
            (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
        assert full_body.shape[1] == 99

        full_body = to3d(full_body)
        if self.convert_to_6d:
            # 55 joints x 3 axis-angle values -> 55 x 6 = 330 values per frame.
            full_body = torch.from_numpy(full_body)
            full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body.reshape(-1, 55, 3))).reshape(-1, 330)
            full_body = np.asarray(full_body)
            if self.expression:
                expression = np.array(data['expression'])
                full_body = np.concatenate((full_body, expression), axis=1)
        else:
            # NOTE(review): the expression is appended here even when
            # `self.expression` is false, unlike the 6-D branch -- confirm
            # this is intended.
            expression = np.array(data['expression'])
            full_body = np.concatenate((full_body, expression), axis=1)

        self.complete_data = np.array(full_body)

        if self.audio_feat_win_size is not None:
            self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
        else:
            # `feat_method` is forwarded as `type`.
            # NOTE(review): fps is fixed to 30 here rather than `self.fps`.
            self.audio_feat = get_mfcc_ta(self.audio_fn,
                                          smlpx=True,
                                          fps=30,
                                          sr=self.audio_sr,
                                          n_mfcc=self.audio_feat_dim,
                                          win_size=self.audio_feat_win_size,
                                          type=self.feat_method,
                                          am=am,
                                          am_sr=am_sr,
                                          encoder_choice=self.config.Model.encoder_choice,
                                          )

    def get_dataset(self, normalization=False, normalize_stats=None, split='train'):
        '''
        Build the sample-level dataset and store it in ``self.all_dataset``.

        normalization: when true, the first 330 columns of each sample are
            standardised with ``normalize_stats``.
        normalize_stats: mapping with 'mean' and 'std' arrays (required when
            ``normalization`` is true).
        split: 'train' yields one window start every 6 frames; 'val', 'pre'
            and 'test' (or ``whole_video``) yield the whole sequence as a
            single sample.

        Returns nothing.  Raises ValueError for an unknown ``split`` (unless
        ``whole_video`` is set).
        '''

        class __Worker__(data.Dataset):
            def __init__(child, index_list, normalization, normalize_stats, split='train') -> None:
                super().__init__()
                child.index_list = index_list
                child.normalization = normalization
                child.normalize_stats = normalize_stats
                child.split = split

            def __getitem__(child, index):
                num_generate_length = self.num_generate_length
                num_pre_frames = self.num_pre_frames
                seq_len = num_generate_length + num_pre_frames

                def as_feature(array):
                    # (frames, dims) -> (dims, frames) in double precision.
                    # `np.float` was only an alias of the builtin float
                    # (float64) and no longer exists in NumPy >= 1.24.
                    return array.astype(np.float64).transpose(1, 0)

                index = child.index_list[index]
                # Random start offset of 0 or 3 frames, dropped when the
                # shifted window would run past the end of the sequence.
                index_new = index + random.randrange(0, 5, 3)
                if index_new + seq_len > self.complete_data.shape[0]:
                    index_new = index
                index = index_new

                if child.split in ['val', 'pre', 'test'] or self.whole_video:
                    # Evaluation uses the complete sequence as one sample.
                    index = 0
                    seq_len = self.complete_data.shape[0]
                assert index + seq_len <= self.complete_data.shape[0]
                # np.array copies, so the in-place normalisation below never
                # touches `self.complete_data`.
                seq_data = np.array(self.complete_data[index:(index + seq_len), :])

                # audio feature
                if not self.context_info:
                    if not self.whole_video:
                        audio_feat = self.audio_feat[index:index + seq_len, ...]
                        if audio_feat.shape[0] < seq_len:
                            audio_feat = np.pad(audio_feat, [[0, seq_len - audio_feat.shape[0]], [0, 0]],
                                                mode='reflect')

                        assert audio_feat.shape[0] == seq_len and audio_feat.shape[1] == self.audio_feat_dim
                    else:
                        audio_feat = self.audio_feat

                else:  # including feature and history
                    if self.audio_feat_win_size is None:
                        audio_feat = self.audio_feat[index:index + seq_len + num_pre_frames, ...]
                        # NOTE(review): the slice and the length test use
                        # `num_pre_frames` while the padding and the assert
                        # use `self.num_frames`; they only agree when both
                        # are equal (the defaults are).
                        if audio_feat.shape[0] < seq_len + num_pre_frames:
                            audio_feat = np.pad(audio_feat,
                                                [[0, seq_len + self.num_frames - audio_feat.shape[0]], [0, 0]],
                                                mode='constant')

                        assert audio_feat.shape[0] == self.num_frames + seq_len and audio_feat.shape[
                            1] == self.audio_feat_dim
                    else:
                        # No audio feature was ever assigned on this path;
                        # fail with a clear message instead of an
                        # UnboundLocalError further down.
                        raise NotImplementedError(
                            'context_info=True is not supported together with audio_feat_win_size')

                if child.normalization:
                    # NOTE(review): the 330-column width matches the 6-D pose
                    # layout only -- confirm before normalising axis-angle
                    # (165-column) data.
                    data_mean = child.normalize_stats['mean'].reshape(1, -1)
                    data_std = child.normalize_stats['std'].reshape(1, -1)
                    seq_data[:, :330] = (seq_data[:, :330] - data_mean) / data_std

                if child.split in ['train', 'test']:
                    if self.convert_to_6d:
                        if self.expression:
                            data_sample = {
                                'poses': as_feature(seq_data[:, :330]),
                                'expression': as_feature(seq_data[:, 330:]),
                                'aud_feat': as_feature(audio_feat),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas,
                                'aud_file': self.audio_fn,
                            }
                        else:
                            data_sample = {
                                'poses': as_feature(seq_data[:, :330]),
                                'nzero': as_feature(seq_data[:, 330:]),
                                'aud_feat': as_feature(audio_feat),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas
                            }
                    else:
                        if self.expression:
                            data_sample = {
                                'poses': as_feature(seq_data[:, :165]),
                                'expression': as_feature(seq_data[:, 165:]),
                                'aud_feat': as_feature(audio_feat),
                                'speaker': speaker_id[self.speaker],
                                'aud_file': self.audio_fn,
                                'betas': self.betas
                            }
                        else:
                            data_sample = {
                                'poses': as_feature(seq_data),
                                'aud_feat': as_feature(audio_feat),
                                'speaker': speaker_id[self.speaker],
                                'betas': self.betas
                            }
                    return data_sample
                else:
                    data_sample = {
                        'poses': as_feature(seq_data[:, :330]),
                        'expression': as_feature(seq_data[:, 330:]),
                        'aud_feat': as_feature(audio_feat),
                        'aud_file': self.audio_fn,
                        'speaker': speaker_id[self.speaker],
                        'betas': self.betas
                    }
                    return data_sample

            def __len__(child):
                return len(child.index_list)

        if self.whole_video or split in ('val', 'pre', 'test'):
            # One sample covering the whole sequence.  'pre' is included
            # because `__getitem__` treats it like 'val' and 'test'.
            index_list = [0]
        elif split == 'train':
            usable = min(self.complete_data.shape[0], self.audio_feat.shape[0])
            index_list = list(
                range(0, usable - self.num_generate_length - self.num_pre_frames, 6))
        else:
            raise ValueError(f'unknown split: {split!r}')
        self.all_dataset = __Worker__(index_list, normalization, normalize_stats, split)

    def __len__(self):
        '''Return the number of entries in ``self.img_name_list``.'''
        return len(self.img_name_list)