import sys
import os

# Fetch the official StyleSDF repository and make its modules importable
os.system("git clone https://github.com/royorel/StyleSDF.git")
sys.path.append("StyleSDF")

# Install fvcore (a PyTorch3D dependency)
os.system(f"{sys.executable} -m pip install -U fvcore")

import torch
pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
version_str="".join([
    f"py3{sys.version_info.minor}_cu",
    torch.version.cuda.replace(".",""),
    f"_pyt{pyt_version_str}"
])
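# version_str selects the matching prebuilt PyTorch3D wheel,
# e.g. "py310_cu118_pyt201" for Python 3.10 + CUDA 11.8 + PyTorch 2.0.1
# (the exact string depends on the local environment)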

os.system(f"{sys.executable} -m pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html")

# Download the pretrained StyleSDF checkpoints (loaded later from ./full_models)
from download_models import download_pretrained_models

download_pretrained_models()


import trimesh
import numpy as np
from munch import *
from PIL import Image
from tqdm import tqdm
from torch.nn import functional as F
from torch.utils import data
from torchvision import utils
from torchvision import transforms
from skimage.measure import marching_cubes
from scipy.spatial import Delaunay
from options import BaseOptions
from model import Generator
from utils import (
    generate_camera_params,
    align_volume,
    extract_mesh_with_marching_cubes,
    xyz2mesh,
    create_cameras,
    create_mesh_renderer,
    add_textures,
)
from pytorch3d.structures import Meshes
import skvideo.io

def generate(opt, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent):
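    """Render opt.identities random identities and save image grids (plus thumbnails).

    When surface renderings are enabled, a depth-map mesh is exported for each view and
    a marching-cubes mesh is extracted for the first view; the function then returns
    early with (depth_mesh, marching_cubes_mesh), which is what the Gradio demo uses.
    """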
    g_ema.eval()
    if not opt.no_surface_renderings:
        surface_g_ema.eval()

    # set camera angles
    if opt.fixed_camera_angles:
        # These can be changed to any other specific viewpoints.
        # You can add or remove viewpoints as you wish
        locations = torch.tensor([[0, 0],
                                  [-1.5 * opt.camera.azim, 0],
                                  [-1 * opt.camera.azim, 0],
                                  [-0.5 * opt.camera.azim, 0],
                                  [0.5 * opt.camera.azim, 0],
                                  [1 * opt.camera.azim, 0],
                                  [1.5 * opt.camera.azim, 0],
                                  [0, -1.5 * opt.camera.elev],
                                  [0, -1 * opt.camera.elev],
                                  [0, -0.5 * opt.camera.elev],
                                  [0, 0.5 * opt.camera.elev],
                                  [0, 1 * opt.camera.elev],
                                  [0, 1.5 * opt.camera.elev]], device=device)
        # For zooming in/out change the values of fov
        # (This can be defined for each view separately via a custom tensor
        # like the locations tensor above. Tensor shape should be [locations.shape[0],1])
        # reasonable values are [0.75 * opt.camera.fov, 1.25 * opt.camera.fov]
        fov = opt.camera.fov * torch.ones((locations.shape[0],1), device=device)
        num_viewdirs = locations.shape[0]
    else: # draw random camera angles
        locations = None
        # fov = None
        fov = opt.camera.fov
        num_viewdirs = opt.num_views_per_id

    # generate images
    for i in tqdm(range(opt.identities)):
        with torch.no_grad():
            chunk = 8
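            # one latent code per identity, repeated so every viewpoint shares the same identity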
            sample_z = torch.randn(1, opt.style_dim, device=device).repeat(num_viewdirs,1)
            sample_cam_extrinsics, sample_focals, sample_near, sample_far, sample_locations = \
            generate_camera_params(opt.renderer_output_size, device, batch=num_viewdirs,
                                   locations=locations, #input_fov=fov,
                                   uniform=opt.camera.uniform, azim_range=opt.camera.azim,
                                   elev_range=opt.camera.elev, fov_ang=fov,
                                   dist_radius=opt.camera.dist_radius)
            rgb_images = torch.Tensor(0, 3, opt.size, opt.size)
            rgb_images_thumbs = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size)
            for j in range(0, num_viewdirs, chunk):
                out = g_ema([sample_z[j:j+chunk]],
                            sample_cam_extrinsics[j:j+chunk],
                            sample_focals[j:j+chunk],
                            sample_near[j:j+chunk],
                            sample_far[j:j+chunk],
                            truncation=opt.truncation_ratio,
                            truncation_latent=mean_latent)

                rgb_images = torch.cat([rgb_images, out[0].cpu()], 0)
                rgb_images_thumbs = torch.cat([rgb_images_thumbs, out[1].cpu()], 0)

            utils.save_image(rgb_images,
                os.path.join(opt.results_dst_dir, 'images','{}.png'.format(str(i).zfill(7))),
                nrow=num_viewdirs,
                normalize=True,
                padding=0,
                value_range=(-1, 1),)

            utils.save_image(rgb_images_thumbs,
                os.path.join(opt.results_dst_dir, 'images','{}_thumb.png'.format(str(i).zfill(7))),
                nrow=num_viewdirs,
                normalize=True,
                padding=0,
                value_range=(-1, 1),)

            # this is done to fit to RTX2080 RAM size (11GB)
            del out
            torch.cuda.empty_cache()

            if not opt.no_surface_renderings:
                surface_chunk = 1
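                # rescale focal lengths to the surface renderer's output resolution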
                scale = surface_g_ema.renderer.out_im_res / g_ema.renderer.out_im_res
                surface_sample_focals = sample_focals * scale
                for j in range(0, num_viewdirs, surface_chunk):
                    surface_out = surface_g_ema([sample_z[j:j+surface_chunk]],
                                                sample_cam_extrinsics[j:j+surface_chunk],
                                                surface_sample_focals[j:j+surface_chunk],
                                                sample_near[j:j+surface_chunk],
                                                sample_far[j:j+surface_chunk],
                                                truncation=opt.truncation_ratio,
                                                truncation_latent=surface_mean_latent,
                                                return_sdf=True,
                                                return_xyz=True)

                    xyz = surface_out[2].cpu()
                    sdf = surface_out[3].cpu()

                    # this is done to fit to RTX2080 RAM size (11GB)
                    del surface_out
                    torch.cuda.empty_cache()

                    # mesh extractions are done one at a time
                    for k in range(surface_chunk):
                        curr_locations = sample_locations[j:j+surface_chunk]
                        loc_str = '_azim{}_elev{}'.format(int(curr_locations[k,0] * 180 / np.pi),
                                                          int(curr_locations[k,1] * 180 / np.pi))

                        # Save depth outputs as meshes
                        depth_mesh_filename = os.path.join(opt.results_dst_dir,'depth_map_meshes','sample_{}_depth_mesh{}.obj'.format(i, loc_str))
                        depth_mesh = xyz2mesh(xyz[k:k+surface_chunk])
                        if depth_mesh is not None:
                            with open(depth_mesh_filename, 'w') as f:
                                depth_mesh.export(f,file_type='obj')

                        # extract full geometry with marching cubes
                        if j == 0:
                            try:
                                frustum_aligned_sdf = align_volume(sdf)
                                marching_cubes_mesh = extract_mesh_with_marching_cubes(frustum_aligned_sdf[k:k+surface_chunk])
                            except ValueError:
                                marching_cubes_mesh = None
                                print('Marching cubes extraction failed.')
                                print('Please check whether the SDF values are all larger (or all smaller) than 0.')
                        # the demo extracts a single mesh pair, so return after the first sample
                        return depth_mesh, marching_cubes_mesh


def get_generate_vars(model_type):
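  """Build inference options and load the pretrained StyleSDF generators.

  Loads the full image generator (g_ema) and, unless surface renderings are disabled,
  a second generator (surface_g_ema) that renders at 128x128x128 for surface extraction.
  Returns (inference options, g_ema, surface_g_ema, mean_latent, surface_mean_latent,
  results directory).
  """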

  opt = BaseOptions().parse()
  opt.camera.uniform = True
  opt.model.is_test = True
  opt.model.freeze_renderer = False
  opt.rendering.offset_sampling = True
  opt.rendering.static_viewdirs = True
  opt.rendering.force_background = True
  opt.rendering.perturb = 0
  opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim
  opt.inference.style_dim = opt.model.style_dim
  opt.inference.project_noise = opt.model.project_noise

  # User options
  opt.inference.no_surface_renderings = False # When true, only RGB images will be created
  opt.inference.fixed_camera_angles = False # When true, each identity will be rendered from a specific set of 13 viewpoints. Otherwise, random views are generated
  opt.inference.identities = 1 # Number of identities to generate
  opt.inference.num_views_per_id = 1 # Number of viewpoints generated per identity. This option is ignored if opt.inference.fixed_camera_angles is true.
  opt.inference.camera = opt.camera

  # Load saved model
  if model_type == 'ffhq':
      model_path = 'ffhq1024x1024.pt'
      opt.model.size = 1024
      opt.experiment.expname = 'ffhq1024x1024'
  else:
      opt.inference.camera.azim = 0.15
      model_path = 'afhq512x512.pt'
      opt.model.size = 512
      opt.experiment.expname = 'afhq512x512'

  # Create results directory
  result_model_dir = 'final_model'
  results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname)
  opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir)
  if opt.inference.fixed_camera_angles:
      opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'fixed_angles')
  else:
      opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'random_angles')

  os.makedirs(opt.inference.results_dst_dir, exist_ok=True)
  os.makedirs(os.path.join(opt.inference.results_dst_dir, 'images'), exist_ok=True)


  if not opt.inference.no_surface_renderings:
      os.makedirs(os.path.join(opt.inference.results_dst_dir, 'depth_map_meshes'), exist_ok=True)
      os.makedirs(os.path.join(opt.inference.results_dst_dir, 'marching_cubes_meshes'), exist_ok=True)

  opt.inference.size = opt.model.size
  checkpoint_path = os.path.join('full_models', model_path)
  checkpoint = torch.load(checkpoint_path, map_location=device)

  # Load image generation model
  g_ema = Generator(opt.model, opt.rendering).to(device)
  pretrained_weights_dict = checkpoint["g_ema"]
  model_dict = g_ema.state_dict()
  for k, v in pretrained_weights_dict.items():
      if v.size() == model_dict[k].size():
          model_dict[k] = v

  g_ema.load_state_dict(model_dict)

  # Load a second volume renderer that extracts surfaces at 128x128x128 (or higher) for better surface resolution
  if not opt.inference.no_surface_renderings:
      opt['surf_extraction'] = Munch()
      opt.surf_extraction.rendering = opt.rendering
      opt.surf_extraction.model = opt.model.copy()
      opt.surf_extraction.model.renderer_spatial_output_dim = 128
      opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim
      opt.surf_extraction.rendering.return_xyz = True
      opt.surf_extraction.rendering.return_sdf = True
      surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device)


      # Load weights to surface extractor
      surface_extractor_dict = surface_g_ema.state_dict()
      for k, v in pretrained_weights_dict.items():
          if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size():
              surface_extractor_dict[k] = v

      surface_g_ema.load_state_dict(surface_extractor_dict)
  else:
      surface_g_ema = None

  # Get the mean latent vector for g_ema
  if opt.inference.truncation_ratio < 1:
      with torch.no_grad():
          mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device)
  else:
      mean_latent = None

  # Get the mean latent vector for surface_g_ema
  if not opt.inference.no_surface_renderings:
      surface_mean_latent = mean_latent[0]
  else:
      surface_mean_latent = None

  return opt.inference, g_ema, surface_g_ema, mean_latent, surface_mean_latent, opt.inference.results_dst_dir



def get_rendervideo_vars(model_type,number_frames):
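    """Build inference options and load the StyleSDF generators for video rendering.

    Mirrors get_generate_vars, but writes results to a 'videos' directory and configures
    the renderer for video trajectories (64 samples per ray, SDF output enabled).
    number_frames is accepted for symmetry with the caller; the frame count itself is
    consumed by render_video.
    """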
    opt = BaseOptions().parse()
    opt.model.is_test = True
    opt.model.style_dim = 256
    opt.model.freeze_renderer = False
    opt.inference.size = opt.model.size
    opt.inference.camera = opt.camera
    opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim
    opt.inference.style_dim = opt.model.style_dim
    opt.inference.project_noise = opt.model.project_noise
    opt.rendering.perturb = 0
    opt.rendering.force_background = True
    opt.rendering.static_viewdirs = True
    opt.rendering.return_sdf = True
    opt.rendering.N_samples = 64
    opt.inference.identities = 1

    # Load saved model
    if model_type == 'ffhq':
        model_path = 'ffhq1024x1024.pt'
        opt.model.size = 1024
        opt.experiment.expname = 'ffhq1024x1024'
    else:
        opt.inference.camera.azim = 0.15
        model_path = 'afhq512x512.pt'
        opt.model.size = 512
        opt.experiment.expname = 'afhq512x512'

    opt.inference.size = opt.model.size

    # Create results directory (videos are written under <results_dir>/<expname>/final_model/videos)
    result_model_dir = 'final_model'
    results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname)
    opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir, 'videos')
    if opt.model.project_noise:
        opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'with_noise_projection')

    os.makedirs(opt.inference.results_dst_dir, exist_ok=True)

    # load saved model
    checkpoint_path = os.path.join('full_models', model_path)
    print(checkpoint_path)
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # load image generation model
    g_ema = Generator(opt.model, opt.rendering).to(device)

    # temp fix because of wrong noise sizes
    pretrained_weights_dict = checkpoint["g_ema"]
    model_dict = g_ema.state_dict()
    for k, v in pretrained_weights_dict.items():
        if v.size() == model_dict[k].size():
            model_dict[k] = v

    g_ema.load_state_dict(model_dict)

    # load a second volume renderer that extracts surfaces at 128x128x128 for better surface resolution
    if not opt.inference.no_surface_renderings or opt.model.project_noise:
        opt['surf_extraction'] = Munch()
        opt.surf_extraction.rendering = opt.rendering
        opt.surf_extraction.model = opt.model.copy()
        opt.surf_extraction.model.renderer_spatial_output_dim = 128
        opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim
        opt.surf_extraction.rendering.return_xyz = True
        opt.surf_extraction.rendering.return_sdf = True
        opt.inference.surf_extraction_output_size = opt.surf_extraction.model.renderer_spatial_output_dim
        surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device)


        # Load weights to surface extractor
        surface_extractor_dict = surface_g_ema.state_dict()
        for k, v in pretrained_weights_dict.items():
            if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size():
                surface_extractor_dict[k] = v

        surface_g_ema.load_state_dict(surface_extractor_dict)
    else:
        surface_g_ema = None

    # get the mean latent vector for g_ema
    if opt.inference.truncation_ratio < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device)
    else:
        mean_latent = None

    # get the mean latent vector for surface_g_ema
    if not opt.inference.no_surface_renderings or opt.model.project_noise:
        surface_mean_latent = mean_latent[0]
    else:
        surface_mean_latent = None

    return opt.inference, g_ema, surface_g_ema, mean_latent, surface_mean_latent, opt.inference.results_dst_dir




def render_video(opt, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent,numberofframes):
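    """Render an RGB video (and, optionally, a depth-mesh video) along a camera trajectory.

    The trajectory either sweeps azimuth only (opt.azim_video) or follows an ellipsoidal
    path over azimuth and elevation. Frames are generated with g_ema and written via
    skvideo's FFmpegWriter; when surface renderings are enabled, a depth mesh is rendered
    with PyTorch3D for a matching depth video. Returns the RGB video filename.
    """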
    g_ema.eval()
    if not opt.no_surface_renderings or opt.project_noise:
        surface_g_ema.eval()

    images = torch.Tensor(0, 3, opt.size, opt.size)
    num_frames = numberofframes
    # Generate video trajectory
    trajectory = np.zeros((num_frames,3), dtype=np.float32)
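    # trajectory columns: [azimuth, elevation, fov] for each frame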

    # set camera trajectory
    # sweep azimuth angles (4 seconds)
    if opt.azim_video:
        t = np.linspace(0, 1, num_frames)
        elev = 0
        fov = opt.camera.fov
        if opt.camera.uniform:
            azim = opt.camera.azim * np.cos(t * 2 * np.pi)
        else:
            azim = 1.5 * opt.camera.azim * np.cos(t * 2 * np.pi)

        trajectory[:num_frames,0] = azim
        trajectory[:num_frames,1] = elev
        trajectory[:num_frames,2] = fov

    # ellipsoid sweep (4 seconds)
    else:
        t = np.linspace(0, 1, num_frames)
        fov = opt.camera.fov #+ 1 * np.sin(t * 2 * np.pi)
        if opt.camera.uniform:
            elev = opt.camera.elev / 2 + opt.camera.elev / 2  * np.sin(t * 2 * np.pi)
            azim = opt.camera.azim  * np.cos(t * 2 * np.pi)
        else:
            elev = 1.5 * opt.camera.elev * np.sin(t * 2 * np.pi)
            azim = 1.5 * opt.camera.azim * np.cos(t * 2 * np.pi)

        trajectory[:num_frames,0] = azim
        trajectory[:num_frames,1] = elev
        trajectory[:num_frames,2] = fov

    trajectory = torch.from_numpy(trajectory).to(device)

    # generate input parameters for the camera trajectory
    sample_cam_extrinsics, sample_focals, sample_near, sample_far, _ = \
    generate_camera_params(opt.renderer_output_size, device, locations=trajectory[:,:2],
                           fov_ang=trajectory[:,2:], dist_radius=opt.camera.dist_radius)


    # In case of noise projection, generate input parameters for the frontal position.
    # The reference mesh for the noise projection is extracted from the frontal position.
    # For more details see section C.1 in the supplementary material.
    if opt.project_noise:
        frontal_pose = torch.tensor([[0.0,0.0,opt.camera.fov]]).to(device)
        frontal_cam_pose, frontal_focals, frontal_near, frontal_far, _ = \
        generate_camera_params(opt.surf_extraction_output_size, device, locations=frontal_pose[:,:2],
                               fov_ang=frontal_pose[:,2:], dist_radius=opt.camera.dist_radius)

    # create geometry renderer (renders the depth maps)
    cameras = create_cameras(azim=np.rad2deg(trajectory[0,0].cpu().numpy()),
                             elev=np.rad2deg(trajectory[0,1].cpu().numpy()),
                             dist=1, device=device)
    renderer = create_mesh_renderer(cameras, image_size=512, specular_color=((0,0,0),),
                    ambient_color=((0.1,.1,.1),), diffuse_color=((0.75,.75,.75),),
                    device=device)

    suffix = '_azim' if opt.azim_video else '_elipsoid'

    # generate videos
    for i in range(opt.identities):
        print('Processing identity {}/{}...'.format(i+1, opt.identities))
        chunk = 1
        sample_z = torch.randn(1, opt.style_dim, device=device).repeat(chunk,1)
        video_filename = 'sample_video_{}{}.mp4'.format(i,suffix)
        writer = skvideo.io.FFmpegWriter(os.path.join(opt.results_dst_dir, video_filename),
                                         outputdict={'-pix_fmt': 'yuv420p', '-crf': '10'})
        if not opt.no_surface_renderings:
            depth_video_filename = 'sample_depth_video_{}{}.mp4'.format(i,suffix)
            depth_writer = skvideo.io.FFmpegWriter(os.path.join(opt.results_dst_dir, depth_video_filename),
                                             outputdict={'-pix_fmt': 'yuv420p', '-crf': '1'})


        ####################### Extract initial surface mesh from the frontal viewpoint #############
        # For more details see section C.1 in the supplementary material.
        if opt.project_noise:
            with torch.no_grad():
                frontal_surface_out = surface_g_ema([sample_z],
                                                    frontal_cam_pose,
                                                    frontal_focals,
                                                    frontal_near,
                                                    frontal_far,
                                                    truncation=opt.truncation_ratio,
                                                    truncation_latent=surface_mean_latent,
                                                    return_sdf=True)
                frontal_sdf = frontal_surface_out[2].cpu()

            print('Extracting Identity {} Frontal view Marching Cubes for consistent video rendering'.format(i))

            frustum_aligned_frontal_sdf = align_volume(frontal_sdf)
            del frontal_sdf

            try:
                frontal_marching_cubes_mesh = extract_mesh_with_marching_cubes(frustum_aligned_frontal_sdf)
            except ValueError:
                frontal_marching_cubes_mesh = None

            if frontal_marching_cubes_mesh is not None:
                frontal_marching_cubes_mesh_filename = os.path.join(opt.results_dst_dir,'sample_{}_frontal_marching_cubes_mesh{}.obj'.format(i,suffix))
                with open(frontal_marching_cubes_mesh_filename, 'w') as f:
                    frontal_marching_cubes_mesh.export(f,file_type='obj')

            del frontal_surface_out
            torch.cuda.empty_cache()
        #############################################################################################

        for j in tqdm(range(0, num_frames, chunk)):
            with torch.no_grad():
                out = g_ema([sample_z],
                            sample_cam_extrinsics[j:j+chunk],
                            sample_focals[j:j+chunk],
                            sample_near[j:j+chunk],
                            sample_far[j:j+chunk],
                            truncation=opt.truncation_ratio,
                            truncation_latent=mean_latent,
                            randomize_noise=False,
                            project_noise=opt.project_noise,
                            mesh_path=frontal_marching_cubes_mesh_filename if opt.project_noise else None)

                rgb = out[0].cpu()
                utils.save_image(rgb,
                    os.path.join(opt.results_dst_dir, '{}.png'.format(str(i).zfill(7))),
                    nrow= trajectory[:,:2].shape[0],
                    normalize=True,
                    padding=0,
                    value_range=(-1, 1),)

                # this is done to fit to RTX2080 RAM size (11GB)
                del out
                torch.cuda.empty_cache()

                # Convert RGB from [-1, 1] to [0,255]
                rgb = 127.5 * (rgb.clamp(-1,1).permute(0,2,3,1).cpu().numpy() + 1)

                # Add the RGB frame to the video
                for k in range(chunk):
                    writer.writeFrame(rgb[k])

                ########## Extract surface ##########
                if not opt.no_surface_renderings:
                    scale = surface_g_ema.renderer.out_im_res / g_ema.renderer.out_im_res
                    surface_sample_focals = sample_focals * scale
                    surface_out = surface_g_ema([sample_z],
                                                sample_cam_extrinsics[j:j+chunk],
                                                surface_sample_focals[j:j+chunk],
                                                sample_near[j:j+chunk],
                                                sample_far[j:j+chunk],
                                                truncation=opt.truncation_ratio,
                                                truncation_latent=surface_mean_latent,
                                                return_xyz=True)
                    xyz = surface_out[2].cpu()

                    # this is done to fit to RTX2080 RAM size (11GB)
                    del surface_out
                    torch.cuda.empty_cache()

                    # Render mesh for video
                    depth_mesh = xyz2mesh(xyz)
                    mesh = Meshes(
                        verts=[torch.from_numpy(np.asarray(depth_mesh.vertices)).to(torch.float32).to(device)],
                        faces = [torch.from_numpy(np.asarray(depth_mesh.faces)).to(torch.float32).to(device)],
                        textures=None,
                        verts_normals=[torch.from_numpy(np.copy(np.asarray(depth_mesh.vertex_normals))).to(torch.float32).to(device)],
                    )
                    mesh = add_textures(mesh)
                    cameras = create_cameras(azim=np.rad2deg(trajectory[j,0].cpu().numpy()),
                                             elev=np.rad2deg(trajectory[j,1].cpu().numpy()),
                                             fov=2*trajectory[j,2].cpu().numpy(),
                                             dist=1, device=device)
                    renderer = create_mesh_renderer(cameras, image_size=512,
                                                    light_location=((0.0,1.0,5.0),), specular_color=((0.2,0.2,0.2),),
                                                    ambient_color=((0.1,0.1,0.1),), diffuse_color=((0.65,.65,.65),),
                                                    device=device)

                    mesh_image = 255 * renderer(mesh).cpu().numpy()
                    mesh_image = mesh_image[...,:3]

                    # Add depth frame to video
                    for k in range(chunk):
                        depth_writer.writeFrame(mesh_image[k])

        # Close video writers
        writer.close()
        if not opt.no_surface_renderings:
            depth_writer.close()

        # the demo renders a single identity, so return its video filename after the first pass
        return video_filename
    
    
import gradio as gr
import plotly.graph_objects as go
from PIL import Image

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_video(model_type,numberofframes,mesh_type):
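    """Gradio callback: render a video for the chosen model and return (video path, preview image)."""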
    options, g_ema, surface_g_ema, mean_latent, surface_mean_latent, result_filename = get_rendervideo_vars(model_type, numberofframes)
    render_video(options, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent, numberofframes)
    torch.cuda.empty_cache()
    del options, g_ema, surface_g_ema, mean_latent, surface_mean_latent
    path_img = os.path.join(result_filename, "0000000.png")
    image = Image.open(path_img)

    if mesh_type == "DepthMesh":
        path = os.path.join(result_filename, "sample_depth_video_0_elipsoid.mp4")
    else:
        path = os.path.join(result_filename, "sample_video_0_elipsoid.mp4")

    return path, image

def get_mesh(model_type,mesh_type):
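    """Gradio callback: generate a mesh for the chosen model and return (Plotly figure, rendered image)."""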
    options, g_ema, surface_g_ema, mean_latent, surface_mean_latent, result_filename = get_generate_vars(model_type)
    depth_mesh, mc_mesh = generate(options, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)
    torch.cuda.empty_cache()
    del options, g_ema, surface_g_ema, mean_latent, surface_mean_latent

    if mesh_type == "DepthMesh":
        mesh = depth_mesh
    else:
        mesh = mc_mesh

    # Plotly's Mesh3d takes vertex coordinates (x, y, z) and per-triangle vertex indices (i, j, k)
    x, y, z = np.asarray(mesh.vertices).T
    i, j, k = np.asarray(mesh.faces).T
    fig = go.Figure(go.Mesh3d(
        x=x, y=y, z=z,
        i=i, j=j, k=k,
        colorscale="Viridis",
        colorbar_len=0.75,
        flatshading=True,
        lighting=dict(
            ambient=0.5,
            diffuse=1,
            fresnel=4,
            specular=0.5,
            roughness=0.05,
            facenormalsepsilon=0,
            vertexnormalsepsilon=0,
        ),
        lightposition=dict(x=100, y=100, z=1000),
    ))
    path = os.path.join(result_filename, "images", "0000000.png")
    image = Image.open(path)

    return fig, image
    
markdown = f'''
  # StyleSDF: High-Resolution 3D-Consistent Image and Geometry Generation

  [Demo Space for the CVPR 2022 paper "StyleSDF: High-Resolution 3D-Consistent Image and Geometry Generation".](https://arxiv.org/abs/2112.11427)

  [Official implementation on GitHub.](https://github.com/royorel/StyleSDF)

  ### Future work based on interest
  - Adding new models for new object types
  - New customization options

  Running on {device}.

  Generation can take a long time, especially for videos; the runtime depends on the number of frames and on the device in use.

  Note: for the RGB video, choose the "Marching Cubes" mesh type.
'''
with gr.Blocks() as demo:
    with gr.Row():
      with gr.Column():
        with gr.Row():
            with gr.Column():
              gr.Markdown(markdown)
            with gr.Column():
              with gr.Row():
                with gr.Column():
                      image=gr.Image(type="pil",shape=(512,512))
                with gr.Column():
                      mesh = gr.Plot()
                with gr.Column():
                      video=gr.Video()
    with gr.Row():
      number_of_frames = gr.Slider(minimum=30, maximum=250, label='Number of Frames for Video Generation')
      model_name = gr.Dropdown(choices=["ffhq", "afhq"], label="Choose Model Type")
      mesh_type = gr.Dropdown(choices=["DepthMesh", "Marching Cubes"], label="Choose Mesh Type")

    with gr.Row():
      btn = gr.Button(value="Generate Mesh")
      btn_2=gr.Button(value="Generate Video")

    btn.click(get_mesh, [model_name, mesh_type], [mesh, image])
    btn_2.click(get_video, [model_name, number_of_frames, mesh_type], [video, image])

demo.launch(debug=True)