import torch

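# Allow TF32 matmuls on Ampere+ GPUs for faster inference at negligible precision cost.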
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import os
import sys
try:
    import utils
    from diffusion import create_diffusion
except ImportError:
    # fall back to the repository root when the script is run from a subdirectory
    sys.path.append(os.path.split(sys.path[0])[0])
    import utils
    from diffusion import create_diffusion
import argparse
import torchvision
from PIL import Image
from einops import rearrange
from models import get_models
from diffusers.models import AutoencoderKL
from models.clip import TextEmbedder
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from utils import mask_generation_before
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from vlogger.videofusion import fusion
from vlogger.videocaption import captioning
from vlogger.videoaudio import make_audio, merge_video_audio, concatenate_videos
from vlogger.STEB.model_transform import ip_scale_set, ip_transform_model, tca_transform_model
from vlogger.planning_utils.gpt4_utils import (readscript, 
                                               readtimescript, 
                                               readprotagonistscript, 
                                               readreferencescript, 
                                               readzhscript)


def auto_inpainting(args, 
                    video_input, 
                    masked_video, 
                    mask, 
                    prompt, 
                    image, 
                    vae, 
                    text_encoder, 
                    image_encoder, 
                    diffusion, 
                    model, 
                    device,
                    ):
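    """Inpaint one 16-frame clip with masked DDIM sampling in VAE latent space.

    The clip is conditioned on a text prompt and, optionally, a CLIP-encoded
    reference image; known frames are injected through `mask` and `masked_video`.
    """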
    image_prompt_embeds = None
    if prompt is None:
        prompt = ""
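    # Optional image prompt: encode the reference image with CLIP and stack
    # conditional / unconditional embeddings for classifier-free guidance.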
    if image is not None:
        clip_image = CLIPImageProcessor()(images=image, return_tensors="pt").pixel_values
        clip_image_embeds = image_encoder(clip_image.to(device)).image_embeds
        uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds).to(device)
        image_prompt_embeds = torch.cat([clip_image_embeds, uncond_clip_image_embeds], dim=0)
        image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=2).contiguous()
        model = ip_scale_set(model, args.ref_cfg_scale)
        if args.use_fp16:
            image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)
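    # The VAE downsamples spatial dims by a factor of 8.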
    b, f, c, h, w = video_input.shape
    latent_h = h // 8
    latent_w = w // 8

    if args.use_fp16:
        z = torch.randn(1, 4, 16, latent_h, latent_w, dtype=torch.float16, device=device) # b,c,f,h,w
        masked_video = masked_video.to(dtype=torch.float16)
        mask = mask.to(dtype=torch.float16)
    else:
        z = torch.randn(1, 4, 16, latent_h, latent_w, device=device) # b,c,f,h,w

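    # Encode the masked video into VAE latents and resize the mask to latent resolution.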
    masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
    masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
    masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
    mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
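    # Duplicate latents, mask, and conditioning along the batch dim: one half for the
    # text prompt, one half for the negative prompt (classifier-free guidance).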
    masked_video = torch.cat([masked_video] * 2)
    mask = torch.cat([mask] * 2)
    z = torch.cat([z] * 2)
    prompt_all = [prompt] + [args.negative_prompt]

    text_prompt = text_encoder(text_prompts=prompt_all, train=False)
    model_kwargs = dict(encoder_hidden_states=text_prompt, 
                        class_labels=None, 
                        cfg_scale=args.cfg_scale,
                        use_fp16=args.use_fp16,
                        ip_hidden_states=image_prompt_embeds)
    
    # Sample video latents with masked DDIM, guided by the text/image prompts:
    samples = diffusion.ddim_sample_loop(model.forward_with_cfg, 
                                         z.shape, 
                                         z, 
                                         clip_denoised=False, 
                                         model_kwargs=model_kwargs, 
                                         progress=True, 
                                         device=device,
                                         mask=mask, 
                                         x_start=masked_video, 
                                         use_concat=True,
                                         )
    samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
    if args.use_fp16:
        samples = samples.to(dtype=torch.float16)

    video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
    video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
    return video_clip


def main(args):
    # Setup PyTorch:
    if args.seed:
        torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed_everything(args.seed)

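    # Build the base model and wrap it with the tca/ip transforms from vlogger.STEB
    # so it accepts the image-prompt conditioning used below.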
    model = get_models(args).to(device)
    model = tca_transform_model(model).to(device)
    model = ip_transform_model(model).to(device)
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            model.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")
    if args.use_compile:
        model = torch.compile(model)

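    # Load EMA weights from the checkpoint, keeping only keys that exist in the current model.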
    ckpt_path = args.ckpt 
    state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    model.eval()  # important!
    diffusion = create_diffusion(str(args.num_sampling_steps))
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
    text_encoder = TextEmbedder(args.pretrained_model_path).to(device)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path).to(device)
    if args.use_fp16:
        print('Warning: using half precision for inference!')
        vae.to(dtype=torch.float16)
        model.to(dtype=torch.float16)
        text_encoder.to(dtype=torch.float16)
    print("model ready!\n", flush=True)
    
    
    # load protagonist script
    character_places = readprotagonistscript(args.protagonist_file_path)
    print("protagonists ready!", flush=True)

    # load script
    video_list = readscript(args.script_file_path)
    print("video script ready!", flush=True)
    
    # load reference script
    reference_lists = readreferencescript(video_list, character_places, args.reference_file_path)
    print("reference script ready!", flush=True)
    
    # load zh script
    zh_video_list = readzhscript(args.zh_script_file_path)
    print("zh script ready!", flush=True)
    
    # collect protagonist keys and load time script
    key_list = list(character_places.keys())
    time_list = readtimescript(args.time_file_path)
    print("time script ready!", flush=True)
    

    # generation loop: synthesize each scene clip by clip until its scripted duration is covered
    sample_list = []
    for i, text_prompt in enumerate(video_list):
        sample_list.append([])
        for time in range(time_list[i]):
            if time == 0:
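                # First clip of this scene: generate from an empty (zero) video input.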
                print('Generating clip for prompt: {}'.format(text_prompt), flush=True)
                if reference_lists[i][0] == 0 or reference_lists[i][0] > len(key_list):
                    pil_image = None
                else:
                    pil_image = Image.open(args.reference_image_path[reference_lists[i][0] - 1])
                    pil_image = pil_image.resize((256, 256))
                video_input = torch.zeros([1, 16, 3, args.image_size[0], args.image_size[1]]).to(device)
                mask = mask_generation_before("first0", video_input.shape, video_input.dtype, device) # b,f,c,h,w
                masked_video = video_input * (mask == 0)
                samples = auto_inpainting(args, 
                                          video_input, 
                                          masked_video, 
                                          mask, 
                                          text_prompt, 
                                          pil_image, 
                                          vae, 
                                          text_encoder, 
                                          image_encoder, 
                                          diffusion, 
                                          model, 
                                          device,
                                          )
                sample_list[i].append(samples)
            else:
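                # Continuation clips: stop once enough frames cover the scripted duration.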
                if sum(video.shape[0] for video in sample_list[i]) / args.fps >= time_list[i]:
                    break
                print('Generating clip for prompt: {}'.format(text_prompt), flush=True)
                if reference_lists[i][0] == 0 or reference_lists[i][0] > len(key_list):
                    pil_image = None
                else:
                    pil_image = Image.open(args.reference_image_path[reference_lists[i][0] - 1])
                    pil_image = pil_image.resize((256, 256))
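                # Reuse the last `researve_frame` frames of the previous clip as known
                # frames and inpaint the remaining frames of the new clip.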
                pre_video = sample_list[i][-1][-args.researve_frame:]
                f, c, h, w = pre_video.shape
                lat_video = torch.zeros(args.num_frames - args.researve_frame, c, h, w).to(device)
                video_input = torch.concat([pre_video, lat_video], dim=0)
                video_input = video_input.to(device).unsqueeze(0)
                mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device)
                masked_video = video_input * (mask == 0)
                video_clip = auto_inpainting(args, 
                                             video_input, 
                                             masked_video, 
                                             mask, 
                                             text_prompt, 
                                             pil_image, 
                                             vae, 
                                             text_encoder, 
                                             image_encoder, 
                                             diffusion, 
                                             model, 
                                             device,
                                             )
                sample_list[i].append(video_clip[args.researve_frame:])
                print(video_clip[args.researve_frame:].shape)

        # transition: inpaint a short bridge between the previous scene's last frame and this scene's first frame
        if args.video_transition and i != 0:
            video_1 = sample_list[i - 1][-1][-1:]
            video_2 = sample_list[i][0][:1]
            f, c, h, w = video_1.shape
            video_middle = torch.zeros(args.num_frames - 2, c, h, w).to(device)
            video_input = torch.concat([video_1, video_middle, video_2], dim=0)
            video_input = video_input.to(device).unsqueeze(0)
            mask = mask_generation_before("onelast1", video_input.shape, video_input.dtype, device)
            masked_video = video_input * (mask == 0)
            video_clip = auto_inpainting(args, 
                                         video_input, 
                                         masked_video, 
                                         mask, 
                                         "smooth transition, slow motion, slow changing.", 
                                         pil_image, 
                                         vae, 
                                         text_encoder, 
                                         image_encoder, 
                                         diffusion, 
                                         model, 
                                         device,
                                         )
            sample_list[i].insert(0, video_clip[1:-1])

        # save this scene's video, trimmed to its scripted duration
        samples = torch.concat(sample_list[i], dim=0)
        samples = samples[0: time_list[i] * args.fps]
        os.makedirs(args.save_origin_video_path, exist_ok=True)
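        # Map samples from [-1, 1] to uint8 [0, 255] and reorder to (T, H, W, C) for write_video.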
        video_ = ((samples * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous()
        torchvision.io.write_video(os.path.join(args.save_origin_video_path, f"{i}.mp4"), video_, fps=args.fps)
    
    # post processing: fuse clips, add captions and audio, then concatenate the final video
    fusion(args.save_origin_video_path)
    captioning(args.script_file_path, args.zh_script_file_path, args.save_origin_video_path, args.save_caption_video_path)
    fusion(args.save_caption_video_path)
    make_audio(args.script_file_path, args.save_audio_path)
    merge_video_audio(args.save_caption_video_path, args.save_audio_path, args.save_audio_caption_video_path)
    concatenate_videos(args.save_audio_caption_video_path)
    print('final video save path {}'.format(args.save_audio_caption_video_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/vlog_read_script_sample.yaml")
    args = parser.parse_args()
    omega_conf = OmegaConf.load(args.config)
    save_path = omega_conf.save_path
    save_origin_video_path = os.path.join(save_path, "origin_video")
    save_caption_video_path = os.path.join(save_path.rsplit('/', 1)[0], "caption_video")
    save_audio_path = os.path.join(save_path.rsplit('/', 1)[0], "audio")
    save_audio_caption_video_path = os.path.join(save_path.rsplit('/', 1)[0], "audio_caption_video")
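    # When sample_num is set, generate several variants with per-sample output
    # directories and offset seeds; otherwise run a single sample.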
    if omega_conf.sample_num is not None:
        for i in range(omega_conf.sample_num):
            omega_conf.save_origin_video_path = save_origin_video_path + f'-{i}'
            omega_conf.save_caption_video_path = save_caption_video_path + f'-{i}'
            omega_conf.save_audio_path = save_audio_path + f'-{i}'
            omega_conf.save_audio_caption_video_path = save_audio_caption_video_path + f'-{i}'
            omega_conf.seed += i
            main(omega_conf)
    else:
        omega_conf.save_origin_video_path = save_origin_video_path
        omega_conf.save_caption_video_path = save_caption_video_path
        omega_conf.save_audio_path = save_audio_path
        omega_conf.save_audio_caption_video_path = save_audio_caption_video_path
        main(omega_conf)