import sys
import os
import h5py
import torch
import pytorch3d.transforms as tra3d

from StructDiffusion.utils.rearrangement import show_pcs_color_order
from StructDiffusion.utils.pointnet import random_point_sample, index_points


def switch_stdout(stdout_filename=None):
    """Redirect ``sys.stdout`` to a file, or restore the interpreter's original stdout.

    Args:
        stdout_filename: path of the file to append to. If falsy, ``sys.stdout`` is
            reset to ``sys.__stdout__``.

    Any file previously opened by this function is closed once the new stream is in
    place. This flushes its buffered output and avoids leaking one handle per call.
    """
    if stdout_filename:
        print("setting stdout to {}".format(stdout_filename))
        # append mode creates the file when it does not exist yet, so one mode suffices
        new_stdout = open(stdout_filename, 'a')
    else:
        new_stdout = sys.__stdout__

    # the file object this function opened on its previous call (if any)
    previous = getattr(switch_stdout, "_redirected_file", None)
    sys.stdout = new_stdout
    switch_stdout._redirected_file = new_stdout if stdout_filename else None

    if previous is not None and not previous.closed:
        previous.close()


def visualize_batch_pcs(obj_xyzs, B, N, P, verbose=True, limit_B=None):
    """Display the per-object point clouds of each example in a batch.

    Args:
        obj_xyzs: numpy array or torch tensor reshapeable to (B, N, P, C) with C >= 3.
            Only the first three channels of every point are displayed.
        B: batch size.
        N: number of objects per example.
        P: number of points per object.
        verbose: print the example index and its array shape before displaying it.
        limit_B: only display the first ``limit_B`` examples (default: all ``B``).
    """
    if limit_B is None:
        limit_B = B

    vis_obj_xyzs = obj_xyzs.reshape(B, N, P, -1)[:limit_B]

    if isinstance(vis_obj_xyzs, torch.Tensor):
        # Always detach: .numpy() also refuses CPU tensors that require grad.
        vis_obj_xyzs = vis_obj_xyzs.detach().cpu().numpy()

    for bi, vis_obj_xyz in enumerate(vis_obj_xyzs):
        if verbose:
            print("example {}".format(bi))
            print(vis_obj_xyz.shape)
        show_pcs_color_order([xyz[:, :3] for xyz in vis_obj_xyz], None, visualize=True, add_coordinate_frame=True, add_table=False)


def convert_bool(d):
    """Coerce every value of ``d`` to ``bool`` in place and return ``d``.

    A value whose type is exactly ``list`` is converted element by element.
    Every other value is passed through ``bool()`` as a whole.
    """
    for key in d:
        value = d[key]
        if type(value) is list:
            converted = list(map(bool, value))
        else:
            converted = bool(value)
        d[key] = converted
    return d


def save_dict_to_h5(dict_data, filename):
    """Write every entry of ``dict_data`` as a dataset in a new HDF5 file.

    The file at ``filename`` is opened with mode ``'w'`` and each key becomes a
    dataset name. The file is closed even when an error is raised part-way through.

    Raises:
        RuntimeError: if a value is ``None``. Keys processed before it have already been written.
        TypeError: re-raised from ``create_dataset`` when a value cannot be stored.
            The offending key, its value and the error are printed first.
    """
    with h5py.File(filename, 'w') as fh:
        for k in dict_data:
            key_data = dict_data[k]
            if key_data is None:
                raise RuntimeError('data was not properly populated')
            try:
                fh.create_dataset(k, data=key_data)
            except TypeError as e:
                print("Failure on key", k)
                print(key_data)
                print(e)
                raise


def move_pc_and_create_scene_new(obj_xyzs, obj_params, struct_pose, current_pc_pose, target_object_inds, device,
                                 return_scene_pts=False, return_scene_pts_and_pc_idxs=False, num_scene_pts=None, normalize_pc=False,
                                 return_pair_pc=False, num_pair_pc_pts=None, normalize_pair_pc=False):
    """Move object point clouds to their goal poses and optionally build scene / pair inputs.

    Shape notation: B = batch size, N = objects per scene, P = points per object.

    Args:
        obj_xyzs: (N, P, 3) object point clouds in their current poses.
        obj_params: (B, N, 6) goal pose of each object in its structure frame.
            Translation xyz is in [:3] and "XYZ" Euler angles are in [3:].
        struct_pose: (B * N, 4, 4) structure poses.
        current_pc_pose: (B * N, 4, 4) current point-cloud poses.
        target_object_inds: (1, N). Its sum is taken as the number of valid objects.
            Only the first ``sum(target_object_inds[0])`` objects are sampled or paired.
        device: device for tensors created here.
        return_scene_pts: append a one-hot object indicator (N channels) to every point.
            Then subsample ``num_scene_pts`` points from the valid objects of each scene.
        return_scene_pts_and_pc_idxs: subsample ``num_scene_pts`` xyz points from the
            valid objects of each scene, together with the object index of each point.
            If ``return_scene_pts`` is also set, this output replaces its ``subsampled_scene_xyz``.
        num_scene_pts: points per subsampled scene (required by the two flags above).
        normalize_pc: center the subsampled scene xyz and scale it into the unit ball.
        return_pair_pc: build point clouds for every pair of valid objects.
        num_pair_pc_pts: points sampled per pair (required when ``return_pair_pc``).
        normalize_pair_pc: center/scale each pair's xyz into the unit ball.

    Returns:
        new_obj_xyzs: (B, N, P, 3), or (B, N, P, 3 + N) when ``return_scene_pts`` is set.
        goal_pc_pose: (B, N, 4, 4).
        subsampled_scene_xyz: (B, num_scene_pts, 3 + N) or (B, num_scene_pts, 3), or None.
        subsampled_pc_idxs: (B, num_scene_pts) long tensor, or None.
        obj_pair_xyzs: (B, num_comb, num_pair_pc_pts, 5) as xyz plus a 2-channel pair
            indicator, or None.
    """
    B, N, _ = obj_params.shape
    _, P, _ = obj_xyzs.shape

    # goal pose of every object expressed in its structure frame: B x N, 4, 4
    flat_obj_params = obj_params.reshape(B * N, -1)
    goal_pc_pose_in_struct = torch.eye(4).repeat(B * N, 1, 1).to(device)
    goal_pc_pose_in_struct[:, :3, :3] = tra3d.euler_angles_to_matrix(flat_obj_params[:, 3:], "XYZ")
    goal_pc_pose_in_struct[:, :3, 3] = flat_obj_params[:, :3]

    goal_pc_pose = struct_pose @ goal_pc_pose_in_struct
    goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_pose)  # B x N, 4, 4

    # PyTorch3D's Transform3d treats points as row vectors (points @ M), so the usual
    # column-vector homogeneous matrices must be transposed.
    transform = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))

    # obj_xyzs (N, P, 3) is tiled over the batch to line up with the B x N transforms
    new_obj_xyzs = transform.transform_points(obj_xyzs.repeat(B, 1, 1))
    new_obj_xyzs = new_obj_xyzs.reshape(B, N, P, -1)  # B, N, P, 3

    subsampled_scene_xyz = None
    subsampled_pc_idxs = None
    obj_pair_xyzs = None

    if return_scene_pts or return_scene_pts_and_pc_idxs:
        # points of the valid objects occupy the first num_valid_pts rows of each flattened scene
        num_valid_pts = int(torch.sum(target_object_inds[0])) * P

    # ===================================
    # Scene input with one-hot object indicators
    if return_scene_pts:
        indicator_variables = torch.eye(N).repeat(B, 1, 1, P).reshape(B, N, P, N).to(device)  # B, N, P, N
        new_obj_xyzs = torch.cat([new_obj_xyzs, indicator_variables], dim=-1)  # B, N, P, 3 + N

        scene_xyzs = new_obj_xyzs.reshape(B, N * P, 3 + N)
        subsampled_scene_xyz = torch.empty(B, num_scene_pts, 3 + N, dtype=torch.float32, device=device)
        for si, scene_xyz in enumerate(scene_xyzs):
            subsample_idx = torch.randint(0, num_valid_pts, (num_scene_pts,)).to(device)
            subsampled_scene_xyz[si] = scene_xyz[subsample_idx]

        if normalize_pc:
            subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])

    # ===================================
    # Scene input with per-point object indices
    if return_scene_pts_and_pc_idxs:
        pc_idxs = torch.arange(0, N)[:, None].repeat(B, 1, P).reshape(B, N * P).to(device)  # B, N * P
        # keep only xyz so this also works when indicator channels were appended above
        scene_xyzs = new_obj_xyzs[..., :3].reshape(B, N * P, 3)

        subsampled_scene_xyz = torch.empty(B, num_scene_pts, 3, dtype=torch.float32, device=device)
        subsampled_pc_idxs = torch.empty(B, num_scene_pts, dtype=torch.long, device=device)
        for si, (scene_xyz, pc_idx) in enumerate(zip(scene_xyzs, pc_idxs)):
            subsample_idx = torch.randint(0, num_valid_pts, (num_scene_pts,)).to(device)
            subsampled_scene_xyz[si] = scene_xyz[subsample_idx]
            subsampled_pc_idxs[si] = pc_idx[subsample_idx]

        if normalize_pc:
            subsampled_scene_xyz[:, :, 0:3] = pc_normalize_batch(subsampled_scene_xyz[:, :, 0:3])

    # ===================================
    # Input for the pairwise collision detector
    if return_pair_pc:
        assert num_pair_pc_pts is not None

        # ignore padding objects
        num_objs = torch.sum(target_object_inds[0])
        obj_pair_idxs = torch.combinations(torch.arange(num_objs), r=2)  # num_comb, 2

        # [..., :3] drops any object-wise indicator channels
        obj_pair_xyzs = new_obj_xyzs[:, :, :, :3][:, obj_pair_idxs]  # B, num_comb, 2, P, 3
        num_comb = obj_pair_xyzs.shape[1]
        pair_indicator_variables = torch.eye(2).repeat(B, num_comb, 1, 1, P).reshape(B, num_comb, 2, P, 2).to(device)
        obj_pair_xyzs = torch.cat([obj_pair_xyzs, pair_indicator_variables], dim=-1)  # B, num_comb, 2, P, 5
        obj_pair_xyzs = obj_pair_xyzs.reshape(B * num_comb, P * 2, 5)

        rand_idxs = random_point_sample(obj_pair_xyzs, num_pair_pc_pts)  # B * num_comb, num_pair_pc_pts
        obj_pair_xyzs = index_points(obj_pair_xyzs, rand_idxs)  # B * num_comb, num_pair_pc_pts, 5

        if normalize_pair_pc:
            obj_pair_xyzs[:, :, 0:3] = pc_normalize_batch(obj_pair_xyzs[:, :, 0:3])

        # always restore the batch dimension, not only when normalizing
        obj_pair_xyzs = obj_pair_xyzs.reshape(B, num_comb, num_pair_pc_pts, 5)

    goal_pc_pose = goal_pc_pose.reshape(B, N, 4, 4)

    return new_obj_xyzs, goal_pc_pose, subsampled_scene_xyz, subsampled_pc_idxs, obj_pair_xyzs



def move_pc(obj_xyzs, obj_params, struct_pose, current_pc_pose, device):
    """Rigidly move each object's point cloud to its goal pose.

    Shape notation: B = batch size, N = objects per scene, P = points per object.

    Args:
        obj_xyzs: (N, P, 3) point clouds in their current poses.
        obj_params: (B, N, 6) goal pose of each object in its structure frame.
            Translation xyz is in [:3] and "XYZ" Euler angles are in [3:].
        struct_pose: (B * N, 4, 4) structure poses.
        current_pc_pose: (B * N, 4, 4) current point-cloud poses.
        device: device on which the identity template for the goal poses is built.

    Returns:
        Tuple of the moved point clouds, shape (B, N, P, 3), and the goal poses,
        shape (B, N, 4, 4).
    """
    batch_size, num_objs, _ = obj_params.shape
    _, num_pts, _ = obj_xyzs.shape
    num_poses = batch_size * num_objs

    params = obj_params.reshape(num_poses, -1)
    local_goal = torch.eye(4).repeat(num_poses, 1, 1).to(device)
    local_goal[:, :3, :3] = tra3d.euler_angles_to_matrix(params[:, 3:], "XYZ")
    local_goal[:, :3, 3] = params[:, :3]

    goal_pose = struct_pose @ local_goal
    relative_tf = goal_pose @ torch.inverse(current_pc_pose)

    # Transform3d multiplies row-vector points on the right (points @ M), so it
    # needs the transpose of the usual column-vector homogeneous matrix.
    row_tf = tra3d.Transform3d(matrix=relative_tf.transpose(1, 2))
    tiled_xyzs = obj_xyzs.repeat(batch_size, 1, 1)  # B * N, P, 3
    moved_xyzs = row_tf.transform_points(tiled_xyzs).reshape(batch_size, num_objs, num_pts, -1)

    return moved_xyzs, goal_pose.reshape(batch_size, num_objs, 4, 4)


def sample_gaussians(mus, sigmas, sample_size):
    """Draw ``sample_size`` samples from independent Gaussians.

    Args:
        mus: means, one per Gaussian.
        sigmas: standard deviations, broadcastable with ``mus``.
        sample_size: number of draws.

    Returns:
        Tensor of shape (sample_size, number of Gaussians).
    """
    return torch.distributions.Normal(mus, sigmas).sample((sample_size,))

def fit_gaussians(samples, sigma_eps=0.01):
    """Fit one independent Gaussian to each column of ``samples``.

    Args:
        samples: (sample_size, number of Gaussians) tensor.
        sigma_eps: added to every standard deviation. For positive values this keeps
            the fitted sigmas strictly positive.

    Returns:
        (mus, sigmas), each of shape (number of Gaussians,) on ``samples.device``.
        With a single sample the spread is taken as zero, so ``sigmas == sigma_eps``.
        Without this, the unbiased std would be NaN.
    """
    mus = torch.mean(samples, dim=0)
    if samples.shape[0] > 1:
        stds = torch.std(samples, dim=0)
    else:
        # unbiased std of one sample is NaN, which would poison downstream Normal()s
        stds = torch.zeros_like(mus)
    sigmas = stds + sigma_eps
    return mus, sigmas


def pc_normalize_batch(pc):
    """Center each point cloud in a batch and scale it into the unit ball.

    Args:
        pc: (B, num_pts, C) tensor. All C channels are centered and used for the
            radius; callers in this file pass C = 3.

    Returns:
        New tensor of the same shape. Each cloud has zero mean and a maximum point
        norm of 1. A degenerate cloud (all points identical) is returned centered,
        i.e. all zeros, instead of NaN.
    """
    centroid = torch.mean(pc, dim=1, keepdim=True)  # B, 1, C
    centered = pc - centroid
    radius = torch.max(torch.sqrt(torch.sum(centered ** 2, dim=2)), dim=1)[0]  # B
    # avoid 0 / 0 for clouds whose points all coincide
    radius = torch.where(radius > 0, radius, torch.ones_like(radius))
    return centered / radius[:, None, None]