import gradio as gr
import os
import spaces
import sys
from copy import deepcopy
sys.path.append('./VADER-VideoCrafter/scripts/main')
sys.path.append('./VADER-VideoCrafter/scripts')
sys.path.append('./VADER-VideoCrafter')


from train_t2v_lora import main_fn, setup_model

# Preset rows for the gr.Examples widget. Each row follows the order of the
# `inputs` list handed to gr.Examples further down:
#   prompt, lora_model, lora_rank, seed, height, width,
#   unconditional_guidance_scale, ddim_steps, ddim_eta, frames, savefps
# All presets share the same resolution / sampler settings, so only the
# prompt, reward-model LoRA, LoRA rank and seed are spelled out per row.
examples = [
    [text, reward_lora, rank, seed_value, 384, 512, 12.0, 25, 1.0, 24, 10]
    for text, reward_lora, rank, seed_value in (
        ("Fairy and Magical Flowers: A fairy tends to enchanted, glowing flowers.",
         "huggingface-hps-aesthetic", 8, 901),
        ("A cat playing an electric guitar in a loft with industrial-style decor and soft, multicolored lights.",
         "huggingface-hps-aesthetic", 8, 208),
        ("A raccoon playing a guitar under a blossoming cherry tree.",
         "huggingface-hps-aesthetic", 8, 180),
        ("A raccoon playing an electric bass in a garage band setting.",
         "huggingface-hps-aesthetic", 8, 400),
        ("A talking bird with shimmering feathers and a melodious voice finds a legendary treasure, guiding through enchanted forests, ancient ruins, and mystical challenges.",
         "huggingface-pickscore", 16, 200),
        ("A snow princess stands on the balcony of her ice castle, her hair adorned with delicate snowflakes, overlooking her serene realm.",
         "huggingface-pickscore", 16, 400),
        ("A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.",
         "huggingface-pickscore", 16, 800),
    )
]

# Build the model once, at import time. gradio_main_fn only checks it for None
# and hands main_fn a deepcopy, so per-request changes made by main_fn
# (presumably attaching the selected LoRA weights — TODO confirm) never leak
# into this shared instance.
model = setup_model()

@spaces.GPU(duration=180)
def gradio_main_fn(prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
                   frames, savefps):
    """Run one text-to-video generation request coming from the Gradio UI.

    Gradio components may deliver numeric values as floats, so each numeric
    argument is coerced to the type it is forwarded to ``main_fn`` as.

    Args:
        prompt: Text prompt describing the video.
        lora_model: Name of the LoRA checkpoint selected in the dropdown
            (e.g. "huggingface-pickscore").
        lora_rank: LoRA rank that goes with ``lora_model``.
        seed: Random seed.
        height: Output height (UI slider, multiples of 16).
        width: Output width (UI slider, multiples of 16).
        unconditional_guidance_scale: Guidance scale ("Guidance Scale" slider).
        ddim_steps: Number of DDIM sampling steps.
        ddim_eta: DDIM eta parameter.
        frames: Number of frames to generate.
        savefps: Frame rate used when saving the video ("Save FPS" slider).

    Returns:
        The value returned by ``main_fn``; it is wired to a ``gr.Video``
        output, so presumably a path to the rendered video file.

    Raises:
        gr.Error: If the module-level model is not loaded.
    """
    # `model` is only read here, so no `global` declaration is required.
    if model is None:
        # Returning a message string would be fed to the gr.Video output and
        # interpreted as a file path; gr.Error surfaces it to the user instead.
        raise gr.Error("Model is not loaded. Please load the model first.")

    return main_fn(prompt=prompt,
                   lora_model=lora_model,
                   lora_rank=int(lora_rank),
                   seed=int(seed),
                   height=int(height),
                   width=int(width),
                   unconditional_guidance_scale=float(unconditional_guidance_scale),
                   ddim_steps=int(ddim_steps),
                   ddim_eta=float(ddim_eta),
                   frames=int(frames),
                   savefps=int(savefps),
                   # Give main_fn a private copy so whatever it does to the
                   # model does not persist into the shared instance used by
                   # later requests.
                   model=deepcopy(model))

def reset_fn():
    """Return the default UI values restored by the Reset button.

    The tuple order matches the Reset button's ``outputs`` list: prompt,
    seed, height, width, guidance scale, DDIM steps, DDIM eta, frames,
    LoRA rank, save FPS, LoRA model.
    """
    default_prompt = "A brown dog eagerly eats from a bowl in a kitchen."
    default_seed = 200
    default_height, default_width = 384, 512
    default_guidance = 12.0
    default_ddim_steps, default_ddim_eta = 25, 1.0
    default_frames = 24
    default_lora_rank = 16
    default_savefps = 10
    default_lora_model = "huggingface-pickscore"
    return (
        default_prompt,
        default_seed,
        default_height,
        default_width,
        default_guidance,
        default_ddim_steps,
        default_ddim_eta,
        default_frames,
        default_lora_rank,
        default_savefps,
        default_lora_model,
    )

def update_lora_rank(lora_model):
    """Sync the LoRA-rank slider to the selected LoRA model.

    "huggingface-pickscore" maps to rank 16; every other selection
    ("huggingface-hps-aesthetic", or a "Base Model" entry) maps to rank 8.
    """
    # NOTE(review): "Base Model" is not among the dropdown's current choices,
    # so that case only matters if it is added later.
    rank = 16 if lora_model == "huggingface-pickscore" else 8
    return gr.update(value=rank)

def update_dropdown(lora_rank):
    """Sync the LoRA-model dropdown to the LoRA-rank slider.

    Rank 16 selects "huggingface-pickscore", rank 8 selects
    "huggingface-hps-aesthetic"; any other rank falls back to "Base Model".
    """
    # NOTE(review): "Base Model" is not one of the dropdown's current choices,
    # so the fallback would push an unlisted value — confirm intent.
    for rank, model_name in ((16, "huggingface-pickscore"),
                             (8, "huggingface-hps-aesthetic")):
        if lora_rank == rank:
            return gr.update(value=model_name)
    return gr.update(value="Base Model")

# Page-level CSS passed to gr.Blocks(css=...).
#   #centered         - rows created with elem_id="centered": centred, 60% wide.
#   .column-centered  - not referenced by any component visible in this file;
#                       NOTE(review): possibly dead, confirm before removing.
#   #image-upload     - the output video component (elem_id="image-upload");
#                       lets it grow to fill its row.
#   #params ...       - the main parameter column (elem_id="params"); makes its
#                       tabs/forms stretch vertically. The `.tabitem[style=...]`
#                       selector matches an inline style attribute, so it may
#                       stop matching if Gradio changes its generated markup.
custom_css = """
    #centered {
        display: flex;
        justify-content: center;
        width: 60%;
        margin: 0 auto;
    }
    .column-centered {
        display: flex;
        flex-direction: column;
        align-items: center;
        width: 60%;
    }
    #image-upload {
        flex-grow: 1;
    }
    #params .tabs {
        display: flex;
        flex-direction: column;
        flex-grow: 1;
    }
    #params .tabitem[style="display: block;"] {
        flex-grow: 1;
        display: flex !important;
    }
    #params .gap {
        flex-grow: 1;
    }
    #params .form {
        flex-grow: 1 !important;
    }
    #params .form > :last-child{
        flex-grow: 1;
    }
"""

# Build the demo UI: a header (title, authors, project links), the main
# controls next to the output video, the sampling parameters, and the event
# wiring (run, reset, dropdown/slider sync, cached examples).
with gr.Blocks(css=custom_css) as demo:
    # ---- Header: title, author list and project links ----
    with gr.Row():
        with gr.Column():
            gr.HTML(
                """
                <h1 style='text-align: center; font-size: 3.2em; margin-bottom: 0.5em; font-family: Arial, sans-serif; margin: 20px;'>
                    Video Diffusion Alignment via Reward Gradient
                </h1>
                """
            )
            gr.HTML(
                """
                <style>
                    body {
                        font-family: Arial, sans-serif;
                        text-align: center;
                        margin: 50px;
                    }
                    a {
                        text-decoration: none !important;
                        color: black !important;
                    }
                </style>
                <body>
                <div style="font-size: 1.4em; margin-bottom: 0.5em; ">
                    <a href="https://mihirp1998.github.io">Mihir Prabhudesai</a><sup>*</sup>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
                    <a href="https://russellmendonca.github.io/">Russell Mendonca</a><sup>*</sup>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
                    <a href="mailto:zheyangqin.qzy@gmail.com">Zheyang Qin</a><sup>*</sup>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
                    <a href="https://www.cs.cmu.edu/~katef/">Katerina Fragkiadaki</a><sup></sup>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
                    <a href="https://www.cs.cmu.edu/~dpathak/">Deepak Pathak</a><sup></sup>


                </div>
                <div style="font-size: 1.3em; font-style: italic;">
                    Carnegie Mellon University
                </div>
                </body>
                """
            )
            gr.HTML(
                """
                <head>
                <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">

                <style>
                .button-container {
                    display: flex;
                    justify-content: center;
                    gap: 10px;
                    margin-top: 10px;
                }

                .button-container a {
                    display: inline-flex;
                    align-items: center;
                    padding: 10px 20px;
                    border-radius: 30px;
                    border: 1px solid #ccc;
                    text-decoration: none;
                    color: #333 !important;
                    font-size: 16px;
                    text-decoration: none !important;
                }

                .button-container a i {
                    margin-right: 8px;
                }
                </style>
                </head>

                <div class="button-container">
                <a href="https://arxiv.org/abs/2407.08737" class="btn btn-outline-primary">
                    <i class="fa-solid fa-file-pdf"></i> Paper
                </a>
                <a href="https://vader-vid.github.io/" class="btn btn-outline-danger">
                    <i class="fa-solid fa-video"></i> Website
                </a>
                <a href="https://github.com/mihirp1998/VADER" class="btn btn-outline-secondary">
                    <i class="fa-brands fa-github"></i> Code
                </a>
                </div>
                """
            )

    # ---- Main controls (left) and generated video (right) ----
    # NOTE(review): elem_id="centered" is used on two rows; HTML ids should be
    # unique, but custom_css targets `#centered`, so both rows keep it.
    with gr.Row(elem_id="centered"):
        with gr.Column(elem_id="params"):
            lora_model = gr.Dropdown(
                label="VADER Model",
                choices=["huggingface-pickscore", "huggingface-hps-aesthetic"],
                value="huggingface-pickscore"
            )
            lora_rank = gr.Slider(minimum=8, maximum=16, label="LoRA Rank", step=8, value=16)
            prompt = gr.Textbox(placeholder="Enter prompt text here", lines=4, label="Text Prompt",
                                value="A brown dog eagerly eats from a bowl in a kitchen.")
            run_btn = gr.Button("Run Inference")

        with gr.Column():
            output_video = gr.Video(elem_id="image-upload")

    # ---- Sampling parameters, reset button, event wiring and examples ----
    with gr.Row(elem_id="centered"):
        with gr.Column():
            seed = gr.Slider(minimum=0, maximum=65536, label="Seed", step=1, value=200)

            # Lower bounds exclude 0: a zero-sized frame, zero frames, zero
            # FPS or zero sampling steps cannot produce a video.
            with gr.Row():
                height = gr.Slider(minimum=16, maximum=512, label="Height", step=16, value=384)
                width = gr.Slider(minimum=16, maximum=512, label="Width", step=16, value=512)

            with gr.Row():
                frames = gr.Slider(minimum=1, maximum=50, label="Frames", step=1, value=24)
                savefps = gr.Slider(minimum=1, maximum=30, label="Save FPS", step=1, value=10)

            with gr.Row():
                DDIM_Steps = gr.Slider(minimum=1, maximum=50, label="DDIM Steps", step=1, value=25)
                unconditional_guidance_scale = gr.Slider(minimum=0, maximum=50, label="Guidance Scale", step=0.1, value=12.0)
                DDIM_Eta = gr.Slider(minimum=0, maximum=1, label="DDIM Eta", step=0.01, value=1.0)

            reset_btn = gr.Button("Reset")

            # Output order must match the tuple returned by reset_fn.
            reset_btn.click(fn=reset_fn, outputs=[prompt, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, lora_rank, savefps, lora_model])

            # Input order must match gradio_main_fn's positional parameters.
            run_btn.click(fn=gradio_main_fn,
                          inputs=[prompt, lora_model, lora_rank,
                                  seed, height, width, unconditional_guidance_scale,
                                  DDIM_Steps, DDIM_Eta, frames, savefps],
                          outputs=output_video
                          )

            # Keep the LoRA-model dropdown and the LoRA-rank slider in sync.
            lora_model.change(fn=update_lora_rank, inputs=lora_model, outputs=lora_rank)
            lora_rank.change(fn=update_dropdown, inputs=lora_rank, outputs=lora_model)

            # Rows of `examples` follow the same order as these inputs; outputs
            # are generated on first use and cached ("lazy").
            gr.Examples(examples=examples,
                        inputs=[prompt, lora_model, lora_rank, seed,
                                height, width, unconditional_guidance_scale,
                                DDIM_Steps, DDIM_Eta, frames, savefps],
                        outputs=output_video,
                        fn=gradio_main_fn,
                        run_on_click=False,
                        cache_examples="lazy",
                        )

# Start the server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not launch it.
# NOTE(review): share=True is not supported on Hugging Face Spaces (Gradio
# warns and ignores it there); it only matters for local runs.
if __name__ == "__main__":
    demo.launch(share=True)