import imp
import os
from pickle import NONE
# os.environ['PYOPENGL_PLATFORM'] = 'osmesa'
import torch
import trimesh
import numpy as np
# import neural_renderer as nr
from skimage.transform import resize
from torchvision.utils import make_grid
import torch.nn.functional as F

from models.smpl import get_smpl_faces, get_model_faces, get_model_tpose
from utils.densepose_methods import DensePoseMethods
from core import constants, path_config
import json
from .geometry import convert_to_full_img_cam
from utils.imutils import crop

try:
    import math
    import pyrender
    from pyrender.constants import RenderFlags
except:
    pass
try:
    from opendr.renderer import ColoredRenderer
    from opendr.lighting import LambertianPointLight, SphericalHarmonics
    from opendr.camera import ProjectPoints
except:
    pass

from pytorch3d.structures.meshes import Meshes
# from pytorch3d.renderer.mesh.renderer import MeshRendererWithFragments

from pytorch3d.renderer import (
    look_at_view_transform, FoVPerspectiveCameras, PerspectiveCameras, AmbientLights, PointLights,
    RasterizationSettings, BlendParams, MeshRenderer, MeshRasterizer, SoftPhongShader,
    SoftSilhouetteShader, HardPhongShader, HardGouraudShader, HardFlatShader, TexturesVertex
)

import logging

logger = logging.getLogger(__name__)


class WeakPerspectiveCamera(pyrender.Camera):
    """pyrender camera whose projection is a per-axis scale plus a 2D shift."""

    def __init__(
        self, scale, translation, znear=pyrender.camera.DEFAULT_Z_NEAR, zfar=None, name=None
    ):
        """Store the projection parameters.

        scale: indexable, ``scale[0]`` / ``scale[1]`` are the x / y scale.
        translation: indexable, ``translation[0]`` / ``translation[1]`` are
            the x / y shift.
        znear, zfar, name: forwarded unchanged to ``pyrender.Camera``.
        """
        super().__init__(znear=znear, zfar=zfar, name=name)
        self.scale = scale
        self.translation = translation

    def get_projection_matrix(self, width=None, height=None):
        """Return the 4x4 projection matrix; ``width`` / ``height`` are ignored."""
        scale_x, scale_y = self.scale[0], self.scale[1]
        shift_x, shift_y = self.translation[0], self.translation[1]

        projection = np.eye(4)
        # Diagonal: per-axis scale, with the z axis negated.
        projection[0, 0] = scale_x
        projection[1, 1] = scale_y
        projection[2, 2] = -1
        # Last column: scaled shift; the y component is sign-flipped.
        projection[0, 3] = shift_x * scale_x
        projection[1, 3] = -shift_y * scale_y
        return projection


class PyRenderer:
    """Offscreen pyrender renderer that overlays a mesh on one or more images.

    A single scene with a fixed set of lights is built at construction time.
    Every call temporarily adds the mesh and a camera to that scene, renders
    it, removes both nodes again and composites the rendering over the
    supplied image(s).
    """

    def __init__(
        self, resolution=(224, 224), orig_img=False, wireframe=False, scale_ratio=1., vis_ratio=1.
    ):
        """Create the offscreen renderer, the scene and its lights.

        resolution: pair whose two entries are multiplied by ``scale_ratio``;
            entry 0 becomes the initial viewport width, entry 1 the height.
        orig_img: stored on the instance; not read anywhere in this class.
        wireframe: when true, rendering uses the ALL_WIREFRAME flag.
        scale_ratio: multiplier applied to ``resolution``.
        vis_ratio: blend weight of the rendered mesh over the background.
        """
        self.resolution = (resolution[0] * scale_ratio, resolution[1] * scale_ratio)

        self.faces = {
            'smplx': get_model_faces('smplx'),
            'smpl': get_model_faces('smpl'),
        }
        self.orig_img = orig_img
        self.wireframe = wireframe
        self.renderer = pyrender.OffscreenRenderer(
            viewport_width=self.resolution[0], viewport_height=self.resolution[1], point_size=1.0
        )

        self.vis_ratio = vis_ratio

        # set the scene
        self.scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.3, 0.3, 0.3))

        # NOTE(review): one pose array is deliberately reused and mutated in
        # place for every light, exactly as before. If pyrender keeps a
        # reference to the array instead of copying it, all lights share the
        # last position written here - confirm before "cleaning this up".
        light_pose = np.eye(4)

        light = pyrender.PointLight(color=np.array([1.0, 1.0, 1.0]) * 0.2, intensity=1)
        for position in ([0, -1, 1], [0, 1, 1], [1, 1, 2]):
            light_pose[:3, 3] = position
            self.scene.add(light, pose=light_pose)

        spot_l = pyrender.SpotLight(
            color=np.ones(3), intensity=15.0, innerConeAngle=np.pi / 3, outerConeAngle=np.pi / 2
        )
        for position in ([1, 2, 2], [-1, 2, 2]):
            light_pose[:3, 3] = position
            self.scene.add(spot_l, pose=light_pose)

        self.colors_dict = {
            'red': np.array([0.5, 0.2, 0.2]),
            'pink': np.array([0.7, 0.5, 0.5]),
            'neutral': np.array([0.7, 0.7, 0.6]),
            'purple': np.array([0.55, 0.4, 0.9]),
            'green': np.array([0.5, 0.55, 0.3]),
            'sky': np.array([0.3, 0.5, 0.55]),
            'white': np.array([1.0, 0.98, 0.94]),
        }

    def __call__(
        self,
        verts,
        faces=None,
        img=np.zeros((224, 224, 3)),
        cam=np.array([1, 0, 0]),
        focal_length=[5000, 5000],
        camera_rotation=np.eye(3),
        crop_info=None,
        angle=None,
        axis=None,
        mesh_filename=None,
        color_type=None,
        color=[1.0, 1.0, 0.9],
        iwp_mode=True,
        crop_img=True,
        mesh_type='smpl',
        scale_ratio=1.,
        rgba_mode=False
    ):
        """Render ``verts`` and composite the result over ``img``.

        verts: mesh vertices handed to ``trimesh.Trimesh``.
        faces: mesh faces; defaults to ``self.faces[mesh_type]``.
        img: background image, or a list of images (a list is only supported
            on code paths that do not read ``img.shape``).
        cam: with ``iwp_mode`` a length-3 ``(s, tx, ty)`` or length-4
            ``(sx, sy, tx, ty)`` sequence; otherwise it is passed to
            ``convert_to_full_img_cam``.
        focal_length: ``[fx, fy]`` of the intrinsics camera.
        camera_rotation: 3x3 matrix; transposed before use when
            ``iwp_mode`` is false.
        crop_info: dict read with the keys ``opt_cam_t``, ``bbox_scale``,
            ``bbox_center``, ``img_w`` and ``img_h``; required when
            ``iwp_mode`` is false.
        angle, axis: optional extra rotation (degrees, axis) of the mesh;
            applied only when both are truthy.
        mesh_filename: if given, the mesh is exported there.
        color_type: key into ``self.colors_dict``; overrides ``color``.
        color: RGB base colour of the mesh material.
        crop_img: crop the rendering to the bbox when ``crop_info`` is given.
        scale_ratio: scale applied to the image resolution / background.
        rgba_mode: return RGBA images with alpha 255 on rendered pixels.

        Returns the first composited uint8 image when ``img`` is a single
        image; for a list input, a list that alternates each composited
        image with its (possibly resized) background.

        Raises ValueError if ``iwp_mode`` is set and ``cam`` does not have
        3 or 4 entries.
        """
        cam = cam.copy()
        # Reject unsupported cameras before any file is exported or any node
        # is added; previously this surfaced late as an UnboundLocalError
        # with the mesh already left behind in the scene.
        if iwp_mode and len(cam) not in (3, 4):
            raise ValueError(
                'cam must have 3 or 4 entries in iwp_mode, got {}'.format(len(cam))
            )

        img_is_list = isinstance(img, list)

        if faces is None:
            faces = self.faces[mesh_type]
        mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False)

        # Rotate the mesh by 180 degrees about the x axis before rendering.
        Rx = trimesh.transformations.rotation_matrix(math.radians(180), [1, 0, 0])
        mesh.apply_transform(Rx)

        if mesh_filename is not None:
            mesh.export(mesh_filename)

        if angle and axis:
            R = trimesh.transformations.rotation_matrix(math.radians(angle), axis)
            mesh.apply_transform(R)

        if iwp_mode:
            resolution = np.array(img.shape[:2]) * scale_ratio
            # NOTE(review): only the 3-parameter form negates tx; kept as is.
            if len(cam) == 4:
                sx, sy, tx, ty = cam
                camera_translation = np.array(
                    [tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)]
                )
            else:
                sx, tx, ty = cam
                sy = sx
                camera_translation = np.array(
                    [-tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)]
                )
            render_res = resolution
            self.renderer.viewport_width = render_res[1]
            self.renderer.viewport_height = render_res[0]
        else:
            if crop_info['opt_cam_t'] is None:
                camera_translation = convert_to_full_img_cam(
                    pare_cam=cam[None],
                    bbox_height=crop_info['bbox_scale'] * 200.,
                    bbox_center=crop_info['bbox_center'],
                    img_w=crop_info['img_w'],
                    img_h=crop_info['img_h'],
                    focal_length=focal_length[0],
                )
            else:
                camera_translation = crop_info['opt_cam_t']
            if torch.is_tensor(camera_translation):
                camera_translation = camera_translation[0].cpu().numpy()
            # Copy so the sign flip below never touches the caller's data.
            camera_translation = camera_translation.copy()
            camera_translation[0] *= -1
            if 'img_h' in crop_info and 'img_w' in crop_info:
                render_res = (int(crop_info['img_h'][0]), int(crop_info['img_w'][0]))
            else:
                render_res = img[0].shape[:2] if img_is_list else img.shape[:2]
            self.renderer.viewport_width = render_res[1]
            self.renderer.viewport_height = render_res[0]
            camera_rotation = camera_rotation.T

        camera = pyrender.IntrinsicsCamera(
            fx=focal_length[0], fy=focal_length[1], cx=render_res[1] / 2., cy=render_res[0] / 2.
        )

        if color_type is not None:
            color = self.colors_dict[color_type]

        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.2,
            roughnessFactor=0.6,
            alphaMode='OPAQUE',
            baseColorFactor=(color[0], color[1], color[2], 1.0)
        )

        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)

        camera_pose = np.eye(4)
        camera_pose[:3, :3] = camera_rotation
        camera_pose[:3, 3] = camera_rotation @ camera_translation

        if self.wireframe:
            render_flags = RenderFlags.RGBA | RenderFlags.ALL_WIREFRAME | RenderFlags.SHADOWS_SPOT
        else:
            render_flags = RenderFlags.RGBA | RenderFlags.SHADOWS_SPOT

        # The scene is shared across calls, so the per-call nodes must be
        # removed even when rendering fails; otherwise the stale mesh and
        # camera would leak into every later render.
        mesh_node = self.scene.add(mesh, 'mesh')
        cam_node = None
        try:
            cam_node = self.scene.add(camera, pose=camera_pose)
            rgb, _ = self.renderer.render(self.scene, flags=render_flags)
        finally:
            self.scene.remove_node(mesh_node)
            if cam_node is not None:
                self.scene.remove_node(cam_node)

        if crop_info is not None and crop_img:
            crop_res = img.shape[:2]
            rgb, _, _ = crop(rgb, crop_info['bbox_center'][0], crop_info['bbox_scale'][0], crop_res)

        # Pixels with non-zero alpha are the ones covered by the mesh.
        valid_mask = (rgb[:, :, -1] > 0)[:, :, np.newaxis]

        image_list = img if img_is_list else [img]

        return_img = []
        for item in image_list:
            if scale_ratio != 1:
                orig_size = item.shape[:2]
                item = resize(
                    item, (orig_size[0] * scale_ratio, orig_size[1] * scale_ratio),
                    anti_aliasing=True
                )
                item = (item * 255).astype(np.uint8)
            # Blend: mesh weighted by vis_ratio inside the mask, background elsewhere.
            output_img = rgb[:, :, :-1] * valid_mask * self.vis_ratio + (
                1 - valid_mask * self.vis_ratio
            ) * item
            if rgba_mode:
                output_img_rgba = np.zeros((output_img.shape[0], output_img.shape[1], 4))
                output_img_rgba[:, :, :3] = output_img
                output_img_rgba[:, :, 3][valid_mask[:, :, 0]] = 255
                output_img = output_img_rgba.astype(np.uint8)
            image = output_img.astype(np.uint8)
            return_img.append(image)
            # NOTE(review): the background is appended as well, so list
            # inputs yield [render, background, ...]; kept for compatibility.
            return_img.append(item)

        if not img_is_list:
            return_img = return_img[0]

        return return_img


class OpenDRenderer:
    """Mesh renderer built on OpenDR's ``ColoredRenderer``."""

    def __init__(self, resolution=(224, 224), ratio=1):
        """Set up intrinsics, the colour table, the renderer and default faces.

        resolution: pair that is multiplied by ``ratio``; unpacked as
            ``(h, w)`` when rendering.
        ratio: up-scaling factor applied to ``resolution`` and to the
            background images.
        """
        self.resolution = (resolution[0] * ratio, resolution[1] * ratio)
        self.ratio = ratio
        self.focal_length = 5000.
        self.K = self._build_intrinsics()
        self.colors_dict = {
            'red': np.array([0.5, 0.2, 0.2]),
            'pink': np.array([0.7, 0.5, 0.5]),
            'neutral': np.array([0.7, 0.7, 0.6]),
            'purple': np.array([0.5, 0.5, 0.7]),
            'green': np.array([0.5, 0.55, 0.3]),
            'sky': np.array([0.3, 0.5, 0.55]),
            'white': np.array([1.0, 0.98, 0.94]),
        }
        self.renderer = ColoredRenderer()
        self.faces = get_smpl_faces()

    def _build_intrinsics(self):
        """Return the 3x3 intrinsics for the current focal length and resolution."""
        return np.array(
            [
                [self.focal_length, 0., self.resolution[1] / 2.],
                [0., self.focal_length, self.resolution[0] / 2.], [0., 0., 1.]
            ]
        )

    def reset_res(self, resolution):
        """Change the rendering resolution (scaled by ``self.ratio``) and rebuild K."""
        self.resolution = (resolution[0] * self.ratio, resolution[1] * self.ratio)
        self.K = self._build_intrinsics()

    def __call__(
        self,
        verts,
        faces=None,
        color=None,
        color_type='white',
        R=None,
        mesh_filename=None,
        img=np.zeros((224, 224, 3)),
        cam=np.array([1, 0, 0]),
        rgba=False,
        addlight=True
    ):
        '''Render mesh using OpenDR
        verts: shape - (V, 3)
        faces: shape - (F, 3); defaults to self.faces
        color: optional colour; used for all three lights and passed as the
            vertex colour
        color_type: key of self.colors_dict, used when color is None
        R: rotation matrix (used to manipulate verts) shape - [3, 3]
        mesh_filename: accepted for interface compatibility; not used here
        img: background image or list of images, range - [0, 255]
        cam: length-3 (s, tx, ty) or length-4 (sx, sy, tx, ty) camera
        rgba: return RGBA images with alpha 255 on rendered pixels
        addlight: shade the mesh with three Lambertian point lights
        Return:
            rendered img (or list of them), range - [0, 255] (np.uint8)
        Raises:
            ValueError if cam does not have 3 or 4 entries
        '''
        ## Create OpenDR renderer
        rn = self.renderer
        h, w = self.resolution
        K = self.K

        f = np.array([K[0, 0], K[1, 1]])
        c = np.array([K[0, 2], K[1, 2]])

        if faces is None:
            faces = self.faces
        if len(cam) == 4:
            t = np.array([cam[2], cam[3], 2 * K[0, 0] / (w * cam[0] + 1e-9)])
        elif len(cam) == 3:
            t = np.array([cam[1], cam[2], 2 * K[0, 0] / (w * cam[0] + 1e-9)])
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``t``.
            raise ValueError('cam must have 3 or 4 entries, got {}'.format(len(cam)))

        rn.camera = ProjectPoints(rt=np.array([0, 0, 0]), t=t, f=f, c=c, k=np.zeros(5))
        rn.frustum = {'near': 1., 'far': 1000., 'width': w, 'height': h}

        albedo = np.ones_like(verts) * .9

        if color is not None:
            color0 = np.array(color)
            color1 = np.array(color)
            color2 = np.array(color)
        elif color_type == 'white':
            color0 = np.array([1., 1., 1.])
            color1 = np.array([1., 1., 1.])
            color2 = np.array([0.7, 0.7, 0.7])
            color = np.ones_like(verts) * self.colors_dict[color_type][None, :]
        else:
            color0 = self.colors_dict[color_type] * 1.2
            color1 = self.colors_dict[color_type] * 1.2
            color2 = self.colors_dict[color_type] * 1.2
            color = np.ones_like(verts) * self.colors_dict[color_type][None, :]

        if R is not None:
            assert R.shape == (3, 3), "Shape of rotation matrix should be (3, 3)"
            verts = np.dot(verts, R)

        rn.set(v=verts, f=faces, vc=color, bgcolor=np.zeros(3))

        if addlight:
            yrot = np.radians(120)    # angle of lights
            rn.vc = LambertianPointLight(
                f=rn.f,
                v=rn.v,
                num_verts=len(rn.v),
                light_pos=rotateY(np.array([-200, -100, -100]), yrot),
                vc=albedo,
                light_color=color0
            )

            # Construct Left Light
            rn.vc += LambertianPointLight(
                f=rn.f,
                v=rn.v,
                num_verts=len(rn.v),
                light_pos=rotateY(np.array([800, 10, 300]), yrot),
                vc=albedo,
                light_color=color1
            )

            # Construct Right Light
            rn.vc += LambertianPointLight(
                f=rn.f,
                v=rn.v,
                num_verts=len(rn.v),
                light_pos=rotateY(np.array([-500, 500, 1000]), yrot),
                vc=albedo,
                light_color=color2
            )

        rendered_image = rn.r
        visibility_image = rn.visibility_image
        # 2**32 - 1 is the value the code treats as "no mesh at this pixel";
        # the mask is identical for every background, so compute it once.
        mesh_mask = visibility_image != (2**32 - 1)

        img_is_list = isinstance(img, list)
        image_list = img if img_is_list else [img]

        return_img = []
        for item in image_list:
            if self.ratio != 1:
                img_resized = resize(
                    item, (item.shape[0] * self.ratio, item.shape[1] * self.ratio),
                    anti_aliasing=True
                )
            else:
                img_resized = item / 255.

            # Best effort: on failure the background is returned without the
            # mesh. Only Exception is swallowed so that KeyboardInterrupt and
            # SystemExit still propagate.
            try:
                img_resized[mesh_mask] = rendered_image[mesh_mask]
            except Exception:
                logger.warning('Can not render mesh.')

            img_resized = (img_resized * 255).astype(np.uint8)
            res = img_resized

            if rgba:
                img_resized_rgba = np.zeros((img_resized.shape[0], img_resized.shape[1], 4))
                img_resized_rgba[:, :, :3] = img_resized
                img_resized_rgba[:, :, 3][mesh_mask] = 255
                res = img_resized_rgba.astype(np.uint8)
            return_img.append(res)

        if not img_is_list:
            return_img = return_img[0]

        return return_img


#  https://github.com/classner/up/blob/master/up_tools/camera.py
def rotateY(points, angle):
    """Rotate points about the y axis.

    ``points`` is right-multiplied (``np.dot(points, M)``) by the y-axis
    rotation matrix built from ``angle``, which is given in radians.
    """
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rot_y = np.array(
        [
            [cos_a, 0., sin_a],
            [0., 1., 0.],
            [-sin_a, 0., cos_a],
        ]
    )
    return np.dot(points, rot_y)


def rotateX(points, angle):
    """Rotate points about the x axis.

    ``points`` is right-multiplied (``np.dot(points, M)``) by the x-axis
    rotation matrix built from ``angle``, which is given in radians.
    """
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rot_x = np.array(
        [
            [1., 0., 0.],
            [0., cos_a, -sin_a],
            [0., sin_a, cos_a],
        ]
    )
    return np.dot(points, rot_x)


def rotateZ(points, angle):
    """Rotate points about the z axis.

    ``points`` is right-multiplied (``np.dot(points, M)``) by the z-axis
    rotation matrix built from ``angle``, which is given in radians.
    """
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rot_z = np.array(
        [
            [cos_a, -sin_a, 0.],
            [sin_a, cos_a, 0.],
            [0., 0., 1.],
        ]
    )
    return np.dot(points, rot_z)


class IUV_Renderer(object):
    """Render per-vertex "texture" values of a body mesh with PyTorch3D.

    Depending on ``mode`` the per-vertex features are:
      * ``'iuv'``  - (part index / num_part, U, V) from ``DensePoseMethods``
      * ``'pncc'`` - the template vertex coordinates rescaled by their
        global min / range
      * ``'seg'``  - the part id / num_part, repeated on three channels
    """

    def __init__(
        self,
        focal_length=5000.,
        orig_size=224,
        output_size=56,
        mode='iuv',
        device=torch.device('cuda'),
        mesh_type='smpl'
    ):
        """Build faces, per-vertex features, camera intrinsics and the renderer.

        focal_length: focal length used for K and for the camera depth.
        orig_size: side length of the square image the camera refers to.
        output_size: side length of the rasterised output.
        mode: ``'iuv'``, ``'pncc'`` or ``'seg'``.
        device: device used for the lights and the shader.
        mesh_type: body model name; ``'iuv'`` mode only handles ``'smpl'``.

        NOTE(review): for any other mode, or ``'iuv'`` with a non-``'smpl'``
        mesh type, ``self.faces`` / ``self.textures_vts`` /
        ``self.vert_mapping`` are never set, so ``verts2iuvimg`` would fail
        with AttributeError - confirm whether that should be rejected here.
        """
        self.focal_length = focal_length
        self.orig_size = orig_size
        self.output_size = output_size

        if mode in ['iuv']:
            if mesh_type == 'smpl':
                DP = DensePoseMethods()

                # Shift the 1-based vertex ids to 0-based indices.
                vert_mapping = DP.All_vertices.astype('int64') - 1
                self.vert_mapping = torch.from_numpy(vert_mapping)

                faces = DP.FacesDensePose
                faces = faces[None, :, :]
                self.faces = torch.from_numpy(
                    faces.astype(np.int32)
                )    # [1, 13774, 3], torch.int32

                num_part = float(np.max(DP.FaceIndices))
                self.num_part = num_part

                # Part id per vertex, cached on disk relative to the cwd.
                dp_vert_pid_fname = 'data/dp_vert_pid.npy'
                if os.path.exists(dp_vert_pid_fname):
                    dp_vert_pid = list(np.load(dp_vert_pid_fname))
                else:
                    print('creating data/dp_vert_pid.npy')
                    dp_vert_pid = self._first_face_part_ids(
                        DP.FacesDensePose, DP.FaceIndices, len(vert_mapping)
                    )
                    # Make sure the cache directory exists before saving.
                    cache_dir = os.path.dirname(dp_vert_pid_fname)
                    if cache_dir:
                        os.makedirs(cache_dir, exist_ok=True)
                    np.save(dp_vert_pid_fname, np.array(dp_vert_pid))

                textures_vts = np.array(
                    [
                        (dp_vert_pid[i] / num_part, DP.U_norm[i], DP.V_norm[i])
                        for i in range(len(vert_mapping))
                    ]
                )
                self.textures_vts = torch.from_numpy(
                    textures_vts[None].astype(np.float32)
                )    # original author's note: (1, 7829, 3)
        elif mode == 'pncc':
            self.vert_mapping = None
            self.faces = torch.from_numpy(
                get_model_faces(mesh_type)[None].astype(np.int32)
            )    #  mano: torch.Size([1, 1538, 3])
            textures_vts = get_model_tpose(mesh_type).unsqueeze(
                0
            )    # mano: torch.Size([1, 778, 3])

            # Rescale by the global min / range, padded by 0.001.
            texture_min = torch.min(textures_vts) - 0.001
            texture_range = torch.max(textures_vts) - texture_min + 0.001
            self.textures_vts = (textures_vts - texture_min) / texture_range
        elif mode in ['seg']:
            self.vert_mapping = None
            body_model = 'smpl'

            self.faces = torch.from_numpy(get_smpl_faces().astype(np.int32)[None])

            seg_path = os.path.join(
                path_config.SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)
            )
            with open(seg_path, 'rb') as json_file:
                smpl_part_id = json.load(json_file)

            v_id = []
            for part_verts in smpl_part_id.values():
                v_id.extend(part_verts)

            v_id = torch.tensor(v_id)
            n_verts = len(torch.unique(v_id))
            num_part = len(constants.SMPL_PART_ID.keys())
            self.num_part = num_part

            seg_vert_pid = np.zeros(n_verts)
            for part_name, part_verts in smpl_part_id.items():
                seg_vert_pid[part_verts] = constants.SMPL_PART_ID[part_name]

            print('seg_vert_pid', seg_vert_pid.shape)
            textures_vts = seg_vert_pid[:, None].repeat(3, axis=1) / num_part
            print('textures_vts', textures_vts.shape)
            self.textures_vts = torch.from_numpy(textures_vts[None].astype(np.float32))

        K = np.array(
            [
                [self.focal_length, 0., self.orig_size / 2.],
                [0., self.focal_length, self.orig_size / 2.], [0., 0., 1.]
            ]
        )

        R = np.array([[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]])

        t = np.array([0, 0, 5])

        # NOTE(review): the principal point above already uses orig_size / 2,
        # so scaling it again here moves it off the image centre whenever
        # orig_size != 224. Left unchanged - confirm the intended intrinsics.
        if self.orig_size != 224:
            rander_scale = self.orig_size / float(224)
            K[0, 0] *= rander_scale
            K[1, 1] *= rander_scale
            K[0, 2] *= rander_scale
            K[1, 2] *= rander_scale

        self.K = torch.FloatTensor(K[None, :, :])
        self.R = torch.FloatTensor(R[None, :, :])
        self.t = torch.FloatTensor(t[None, None, :])

        # Pad K to 4x4 and move the homogeneous 1 so that
        # [2, 3] = [3, 2] = 1 and [2, 2] = 0.
        camK = F.pad(self.K, (0, 1, 0, 1), "constant", 0)
        camK[:, 2, 2] = 0
        camK[:, 3, 2] = 1
        camK[:, 2, 3] = 1

        self.K = camK

        self.device = device
        lights = AmbientLights(device=self.device)

        raster_settings = RasterizationSettings(
            image_size=output_size,
            blur_radius=0,
            faces_per_pixel=1,
        )
        self.renderer = MeshRenderer(
            rasterizer=MeshRasterizer(raster_settings=raster_settings),
            shader=HardFlatShader(
                device=self.device,
                lights=lights,
                blend_params=BlendParams(background_color=[0, 0, 0], sigma=0.0, gamma=0.0)
            )
        )

    @staticmethod
    def _first_face_part_ids(dp_faces, face_part_ids, num_verts):
        """Part id of the first face containing each vertex id in [0, num_verts).

        Vertex ids that occur in no face are skipped, so the result can be
        shorter than ``num_verts``. Replaces a pure-Python O(V * F) search.
        """
        dp_faces = np.asarray(dp_faces)
        # np.unique returns the sorted vertex ids together with the flat
        # position of their first occurrence; in row-major order that
        # position lies in the lowest-indexed face containing the vertex.
        vert_ids, first_pos = np.unique(dp_faces.reshape(-1), return_index=True)
        wanted = np.isin(vert_ids, np.arange(num_verts))
        first_face = first_pos[wanted] // dp_faces.shape[1]
        return list(np.asarray(face_part_ids)[first_face])

    def camera_matrix(self, cam):
        """Return batched ``(K, R, t)`` on ``cam.device``.

        cam: tensor indexed as ``cam[:, 0]`` (scale), ``cam[:, 1]`` and
            ``cam[:, 2]`` (x / y shift). ``t`` is
            ``(-tx, -ty, 2 * focal_length / (orig_size * s + 1e-9))``.
        """
        batch_size = cam.size(0)

        # Move the stored CPU tensors to wherever ``cam`` lives (a no-op when
        # the devices already match); ``t`` is derived from ``cam`` itself.
        K = self.K.repeat(batch_size, 1, 1).to(cam.device)
        R = self.R.repeat(batch_size, 1, 1).to(cam.device)
        t = torch.stack(
            [-cam[:, 1], -cam[:, 2], 2 * self.focal_length / (self.orig_size * cam[:, 0] + 1e-9)],
            dim=-1
        )

        return K, R, t

    def verts2iuvimg(self, verts, cam, iwp_mode=True):
        """Render the per-vertex features for a batch of meshes.

        verts: batched vertices; the batch size is ``verts.size(0)``.
        cam: batched camera parameters, see ``camera_matrix``.
        iwp_mode: accepted for interface compatibility; not used here.

        Returns the first three channels of the rendering, permuted to
        channel-first layout.
        """
        batch_size = verts.size(0)

        K, R, t = self.camera_matrix(cam)

        if self.vert_mapping is None:
            vertices = verts
        else:
            # Re-index the input vertices through the stored vertex mapping.
            vertices = verts[:, self.vert_mapping, :]

        mesh = Meshes(vertices, self.faces.to(verts.device).expand(batch_size, -1, -1))
        mesh.textures = TexturesVertex(
            verts_features=self.textures_vts.to(verts.device).expand(batch_size, -1, -1)
        )

        cameras = PerspectiveCameras(
            device=verts.device,
            R=R,
            T=t,
            K=K,
            in_ndc=False,
            image_size=[(self.orig_size, self.orig_size)]
        )

        iuv_image = self.renderer(mesh, cameras=cameras)
        iuv_image = iuv_image[..., :3].permute(0, 3, 1, 2)

        return iuv_image