# Authors: Hui Ren (rhfeiyang.github.io)
import spaces
import os
import gradio as gr
from diffusers import DiffusionPipeline
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
print(f"Using {device} device, dtype={dtype}")
pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
                                         torch_dtype=dtype).to(device)

from inference import get_lora_network, inference, get_validation_dataloader
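# Map the adapter display names shown in the UI to their checkpoint subdirectories under data/Art_adapters/.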
lora_map = {
    "None": "None",
    "Andre Derain (fauvism)": "andre-derain_subset1",
    "Vincent van Gogh (post impressionism)": "van_gogh_subset1",
    "Andy Warhol (pop art)": "andy_subset1",
    "Walter Battiss": "walter-battiss_subset2",
    "Camille Corot (realism)": "camille-corot_subset1",
    "Claude Monet (impressionism)": "monet_subset2",
    "Pablo Picasso (cubism)": "picasso_subset1",
    "Jackson Pollock": "jackson-pollock_subset1",
    "Gerhard Richter (abstract expressionism)": "gerhard-richter_subset1",
    "M.C. Escher": "m.c.-escher_subset1",
    "Albert Gleizes": "albert-gleizes_subset1",
    "Hokusai (ukiyo-e)": "katsushika-hokusai_subset1",
    "Wassily Kandinsky": "kandinsky_subset1",
    "Gustav Klimt (art nouveau)": "klimt_subset3",
    "Roy Lichtenstein": "roy-lichtenstein_subset1",
    "Henri Matisse (abstract expressionism)": "henri-matisse_subset1",
    "Joan Miro": "joan-miro_subset2",
}
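
# `@spaces.GPU` requests a GPU allocation for the duration of each call when the demo runs
# on Hugging Face Spaces (ZeroGPU); it is effectively a no-op in other environments.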

@spaces.GPU
def demo_inference_gen_artistic(adapter_choice: str, prompt: str, seed: int = 0, steps=50, guidance_scale=7.5, adapter_scale=1.0):
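    """Text-to-image generation with the selected art adapter.

    Resolves the adapter checkpoint from `lora_map`, loads it into the UNet via
    `get_lora_network`, and runs `inference` from pure noise (`from_scratch=True`).
    `adapter_scale` sets the adapter strength; "sks art" is passed as the style
    prompt whenever an adapter is selected.
    """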
    adapter_path = lora_map[adapter_choice]
    if adapter_path not in [None, "None"]:
        adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
        style_prompt="sks art"
    else:
        style_prompt=None
    prompts = [prompt]
    infer_loader = get_validation_dataloader(prompts,num_workers=0)
    network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype, device=device)["network"]

    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
                            height=512, width=512, scales=[adapter_scale],
                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
                            start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,
                            from_scratch=True, device=device, weight_dtype=dtype)[0][adapter_scale][0]
    return pred_images

@spaces.GPU
def demo_inference_gen_ori(prompt: str, seed: int = 0, steps=50, guidance_scale=7.5):
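    """Art-free baseline: text-to-image generation with no adapter (scale 0.0)."""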
    style_prompt=None
    prompts = [prompt]
    infer_loader = get_validation_dataloader(prompts,num_workers=0)
    network = get_lora_network(pipe.unet, "None", weight_dtype=dtype, device=device)["network"]

    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
                            height=512, width=512, scales=[0.0],
                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
                            start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,
                            from_scratch=True, device=device, weight_dtype=dtype)[0][0.0][0]
    return pred_images


@spaces.GPU
def demo_inference_stylization_ori(ref_image, prompt: str, seed: int = 0, steps=50, guidance_scale=7.5, start_noise=800):
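    """Art-free stylization baseline: re-denoise the reference image without an adapter.

    The reference is converted from a NumPy array to PIL, and denoising starts from the
    `start_noise` timestep rather than from scratch (`from_scratch=False`).
    """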
    style_prompt=None
    prompts = [prompt]
    # convert np to pil
    ref_image = [Image.fromarray(ref_image)]
    network = get_lora_network(pipe.unet, "None", weight_dtype=dtype, device=device)["network"]
    infer_loader = get_validation_dataloader(prompts, ref_image,num_workers=0)
    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
                            height=512, width=512, scales=[0.0],
                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
                            start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,
                            from_scratch=False, device=device, weight_dtype=dtype)[0][0.0][0]
    return pred_images

@spaces.GPU
def demo_inference_stylization_artistic(ref_image, adapter_choice: str, prompt: str, seed: int = 0, steps=50, guidance_scale=7.5, adapter_scale=1.0, start_noise=800):
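    """Stylize the reference image with the selected art adapter.

    Same partial-denoising setup as the art-free stylization, but with the adapter
    loaded into the UNet and "sks art" passed as the style prompt.
    """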
    adapter_path = lora_map[adapter_choice]
    if adapter_path not in [None, "None"]:
        adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
        style_prompt="sks art"
    else:
        style_prompt=None
    prompts = [prompt]
    # convert np to pil
    ref_image = [Image.fromarray(ref_image)]
    network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype, device=device)["network"]
    infer_loader = get_validation_dataloader(prompts, ref_image,num_workers=0)
    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
                            height=512, width=512, scales=[adapter_scale],
                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
                            start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,
                            from_scratch=False, device=device, weight_dtype=dtype)[0][adapter_scale][0]
    return pred_images

@spaces.GPU
def demo_inference_all(prompt: str, ref_image, seed: int = 0, adapter_choice="Andre Derain (fauvism)", steps=20, guidance_scale=7.5, adapter_scale=1.0, start_noise=800):
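    """Run all four variants (generation and stylization, with and without the adapter)
    for one prompt/reference pair; used to precompute the cached examples below."""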
    results = []
    results.append(demo_inference_gen_ori(prompt, seed, steps, guidance_scale))
    results.append(demo_inference_gen_artistic(adapter_choice, prompt, seed, steps, guidance_scale, adapter_scale))
    results.append(demo_inference_stylization_ori(ref_image, prompt, seed, steps, guidance_scale, start_noise))
    results.append(demo_inference_stylization_artistic(ref_image, adapter_choice, prompt, seed, steps, guidance_scale, adapter_scale, start_noise))
    return results


block = gr.Blocks()
# Build the Gradio UI: prompt box, Generation and Stylization tabs, shared controls, and cached examples.
with block:
    with gr.Group():
        gr.Markdown(" # Art-Free Diffusion Demo")
        gr.Markdown("(More features in development...)")
        with gr.Row():
            text = gr.Textbox(
                label="Prompt (long and detailed would be better):",
                max_lines=10,
                placeholder="Enter your prompt (long and detailed would be better)",
                container=True,
                value="A beautiful garden with a large pond. The pond is surrounded by a wooden deck, and there are several chairs placed around the area. A stone fountain is present in the middle of the pond, adding to the serene atmosphere. The garden is decorated with a variety of potted plants, creating a lush and inviting environment. The scene is captured in a vibrant and colorful style, highlighting the natural beauty of the garden.",
            )

        with gr.Tab('Generation'):
            with gr.Row():
                with gr.Column():
                    # gr.Markdown("## Art-Free Generation")
                    # gr.Markdown("Generate images from text prompts.")

                    gallery_gen_ori = gr.Image(
                        label="W/O Adapter",
                        show_label=True,
                        elem_id="gallery",
                        height="auto"
                    )


                with gr.Column():
                    # gr.Markdown("## Art-Free Generation")
                    # gr.Markdown("Generate images from text prompts.")
                    gallery_gen_art = gr.Image(
                        label="W/ Adapter",
                        show_label=True,
                        elem_id="gallery",
                        height="auto"
                    )


            with gr.Row():
                btn_gen_ori = gr.Button("Art-Free Generate", scale=1)
                btn_gen_art = gr.Button("Artistic Generate", scale=1)


        with gr.Tab('Stylization'):
            with gr.Row():

                with gr.Column():
                    # gr.Markdown("## Art-Free Generation")
                    # gr.Markdown("Generate images from text prompts.")

                    gallery_stylization_ref = gr.Image(
                        label="Ref Image",
                        show_label=True,
                        elem_id="gallery",
                        height="auto",
                        scale=1,
                        value="data/a_black_SUV_driving_down_a_highway_with_a_scenic_view_of_mountains_and_water_in_the_background._The_.jpg"
                    )
                with gr.Column(scale=2):
                    with gr.Row():
                        with gr.Column():
                            # gr.Markdown("## Art-Free Generation")
                            # gr.Markdown("Generate images from text prompts.")

                            gallery_stylization_ori = gr.Image(
                                label="W/O Adapter",
                                show_label=True,
                                elem_id="gallery",
                                height="auto",
                                scale=1,
                            )


                        with gr.Column():
                            # gr.Markdown("## Art-Free Generation")
                            # gr.Markdown("Generate images from text prompts.")
                            gallery_stylization_art = gr.Image(
                                label="W/ Adapter",
                                show_label=True,
                                elem_id="gallery",
                                height="auto",
                                scale=1,
                            )
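                    # `start_noise`: denoising for stylization starts from this diffusion timestep,
                    # so higher values generally allow the output to deviate more from the reference.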
                    start_timestep = gr.Slider(label="Start denoising from timestep:", minimum=0, maximum=1000, value=800, step=1)
            with gr.Row():
                btn_style_ori = gr.Button("Art-Free Stylize", scale=1)
                btn_style_art = gr.Button("Artistic Stylize", scale=1)


        with gr.Row():
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=20, value=7.5, step=0.1
            )
            adapter_choice = gr.Dropdown(
                label="Select Art Adapter",
                choices=[ "Andre Derain (fauvism)","Vincent van Gogh (post impressionism)","Andy Warhol (pop art)",
                          "Camille Corot (realism)", "Claude Monet (impressionism)", "Pablo Picasso (cubism)", "Gerhard Richter (abstract expressionism)",
                          "Hokusai (ukiyo-e)", "Gustav Klimt (art nouveau)", "Henri Matisse (abstract expressionism)",
                          "Walter Battiss", "Jackson Pollock",  "M.C. Escher", "Albert Gleizes",  "Wassily Kandinsky",
                          "Roy Lichtenstein", "Joan Miro"
                          ],
                value="Andre Derain (fauvism)",
                scale=1
            )

        with gr.Row():
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
            adapter_scale = gr.Slider(label="Adapter Scale", minimum=0, maximum=1.5, value=1., step=0.1, scale=1)

        with gr.Row():
            seed = gr.Slider(label="Seed",minimum=0,maximum=2147483647,step=1,randomize=True,scale=1)


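        # Wire each button to its inference function; the prompt, seed, steps, and
        # guidance-scale controls are shared across all four actions.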
        gr.on([btn_gen_ori.click], demo_inference_gen_ori, inputs=[text, seed, steps, scale], outputs=gallery_gen_ori)
        gr.on([btn_gen_art.click], demo_inference_gen_artistic, inputs=[adapter_choice, text, seed, steps, scale, adapter_scale], outputs=gallery_gen_art)

        gr.on([btn_style_ori.click], demo_inference_stylization_ori, inputs=[gallery_stylization_ref, text, seed, steps, scale, start_timestep], outputs=gallery_stylization_ori)
        gr.on([btn_style_art.click], demo_inference_stylization_artistic, inputs=[gallery_stylization_ref, adapter_choice, text, seed, steps, scale, adapter_scale, start_timestep], outputs=gallery_stylization_art)

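    # Cached examples: each row provides a prompt, a reference image, and a seed; outputs for
    # all four galleries are precomputed with `demo_inference_all`.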
    examples = gr.Examples(
        examples=[
            ["Snow-covered trees with sunlight shining through",
             "data/Snow-covered_trees_with_sunlight_shining_through.jpg",
             0,
             ],
            ["A picturesque landscape showcasing a winding river cutting through a lush green valley, surrounded by rugged mountains under a clear blue sky. The mix of red and brown tones in the rocky hills adds to the region's natural beauty and diversity.",
             "data/0011772.jpg",
             528741066,
             ],
            ["A black SUV driving down a highway with a scenic view of mountains and water in the background. The SUV is the main focus of the image, and it appears to be traveling at a moderate speed. The road is well-maintained and provides a smooth driving experience. The mountains and water create a picturesque backdrop, adding to the overall beauty of the scene. The image captures the essence of a leisurely road trip, with the SUV as the primary subject, highlighting the sense of adventure and exploration that comes with such journeys.",
             "data/a_black_SUV_driving_down_a_highway_with_a_scenic_view_of_mountains_and_water_in_the_background._The_.jpg",
             98762568,
             ],
            ["A beautiful garden with a large pond. The pond is surrounded by a wooden deck, and there are several chairs placed around the area. A stone fountain is present in the middle of the pond, adding to the serene atmosphere. The garden is decorated with a variety of potted plants, creating a lush and inviting environment. The scene is captured in a vibrant and colorful style, highlighting the natural beauty of the garden.",
             "data/a_beautiful_garden_with_a_large_pond._The_pond_is_surrounded_by_a_wooden_deck,_and_there_are_several.jpg",
             76265772,
             ],
            [
                "A blue bench situated in a park, surrounded by trees and leaves. The bench is positioned under a tree, providing shade and a peaceful atmosphere. There are several benches in the park, with one being closer to the foreground and the others further in the background. A person can be seen in the distance, possibly enjoying the park or taking a walk. The overall scene is serene and inviting, with the bench serving as a focal point in the park's landscape.",
                "data/003904765.jpg",
                3904764,
            ]

        ],
        inputs=[
            text,
            gallery_stylization_ref,
            seed,
            adapter_choice,
            steps,
            scale,
            adapter_scale,
            start_timestep,
        ],
        fn=demo_inference_all,
        outputs=[gallery_gen_ori, gallery_gen_art, gallery_stylization_ori, gallery_stylization_art],
        cache_examples=True,
    )

block.launch()
# block.launch(share=True)