#!/usr/bin/env python
from __future__ import annotations

import os
import random
import time

import gradio as gr
import numpy as np
import PIL.Image

from huggingface_hub import snapshot_download
from diffusers import DiffusionPipeline

from optimum.intel.openvino.modeling_diffusion import OVModelVaeDecoder, OVBaseModel, OVStableDiffusionPipeline

import os
from tqdm import tqdm
import gradio_user_history as gr_user_history

from concurrent.futures import ThreadPoolExecutor
import uuid

# Markdown shown at the top of the demo page (rendered through gr.Markdown).
DESCRIPTION = '''# Latent Consistency Model OpenVINO CPU TAESD
Based on [Latent Consistency Model OpenVINO CPU](https://huggingface.co/spaces/deinferno/Latent_Consistency_Model_OpenVino_CPU) HF space 

Converted from [SoteMix](https://huggingface.co/Disty0/SoteMix) to [LCM_SoteMix](https://huggingface.co/Disty0/LCM_SoteMix) and then to OpenVINO

This model is for Anime art style.

Slower but higher quality version with Full VAE: [LCM_SoteMix_OpenVINO_CPU_Space](https://huggingface.co/spaces/Disty0/LCM_SoteMix_OpenVINO_CPU_Space)

[LCM Project page](https://latent-consistency-models.github.io)

<p>Running on a Dual Core CPU with OpenVINO Acceleration</p>
'''

# Largest int32 value; used as the upper bound for the seed slider and for
# randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# True only when the CACHE_EXAMPLES environment variable is exactly "1".
CACHE_EXAMPLES = os.environ.get("CACHE_EXAMPLES") == "1"

# Hugging Face model repository loaded by the pipeline.
model_id = "Disty0/LCM_SoteMix"
# Passed straight to pipe.reshape(); presumably -1 keeps the batch dimension
# dynamic -- confirm against the optimum-intel documentation.
batch_size = -1
# Output geometry and image count, each overridable through the environment.
width, height, num_images = (
    int(os.environ.get(variable, fallback))
    for variable, fallback in (
        ("IMAGE_WIDTH", "512"),
        ("IMAGE_HEIGHT", "512"),
        ("NUM_IMAGES", "1"),
    )
)
# Guidance scale forwarded to every pipeline call.
guidance_scale = float(os.environ.get("GUIDANCE_SCALE", "1.0"))

class CustomOVModelVaeDecoder(OVModelVaeDecoder):
    """VAE decoder subclass used below to replace the pipeline's decoder with TAESD.

    Compared with a plain positional call into the base initializer, this
    constructor additionally accepts ``model_dir`` and forwards it, together
    with the fixed component name ``"vae_decoder"``.

    NOTE(review): the annotations reference ``openvino``, ``Optional`` and
    ``Dict``, none of which are imported in this file. That only works because
    ``from __future__ import annotations`` keeps annotations unevaluated.
    """

    def __init__(
        self, model: openvino.runtime.Model, parent_model: OVBaseModel, ov_config: Optional[Dict[str, str]] = None, model_dir: str = None,
    ):
        # ``super(OVModelVaeDecoder, self)`` starts the MRO lookup *after*
        # OVModelVaeDecoder, so OVModelVaeDecoder.__init__ itself is skipped and
        # the next initializer in the MRO receives these five positional
        # arguments. Presumably this is done so ``model_dir`` can be passed
        # through -- confirm against the installed optimum-intel version.
        super(OVModelVaeDecoder, self).__init__(model, parent_model, ov_config, "vae_decoder", model_dir)

# Load the pipeline without compiling it; compile() is called explicitly at the
# end of this section, after the decoder swap and the reshape.
# NOTE(review): the empty CACHE_DIR presumably disables OpenVINO's on-disk
# model cache -- confirm.
pipe = OVStableDiffusionPipeline.from_pretrained(
    model_id,
    compile=False,
    ov_config={"CACHE_DIR": ""},
)

# Inject TAESD: download the decoder repository and replace the pipeline's
# VAE decoder with the model stored under vae_decoder/.
taesd_dir = snapshot_download(repo_id="deinferno/taesd-openvino")
pipe.vae_decoder = CustomOVModelVaeDecoder(
    model=OVBaseModel.load_model(f"{taesd_dir}/vae_decoder/openvino_model.xml"),
    parent_model=pipe,
    model_dir=taesd_dir,
)

# Apply the configured input shapes, then compile.
pipe.reshape(
    batch_size=batch_size,
    height=height,
    width=width,
    num_images_per_prompt=num_images,
)
pipe.compile()

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for a generation run.

    Args:
        seed: Seed supplied by the caller.
        randomize_seed: When truthy, ignore ``seed`` and draw a fresh one.

    Returns:
        A random integer in ``[0, MAX_SEED]`` when ``randomize_seed`` is
        truthy, otherwise ``seed`` unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed

def save_image(img, profile: gr.OAuthProfile | None, metadata: dict):
    """Save one generated image to disk and record it in the user history.

    Args:
        img: Image object exposing ``save(path)`` (presumably a PIL image).
        profile: OAuth profile of the requesting user, or ``None``.
        metadata: Generation settings; its ``"prompt"`` entry is used as the
            history label.

    Returns:
        The file name the image was written to: a random UUID with a ``.png``
        suffix, given as a relative path.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    gr_user_history.save_image(
        label=metadata["prompt"],
        image=img,
        profile=profile,
        metadata=metadata,
    )
    return filename

def save_images(image_array, profile: gr.OAuthProfile | None, metadata: dict):
    """Save every image in ``image_array`` using a thread pool.

    Args:
        image_array: Images to save; each one is passed to ``save_image``.
        profile: OAuth profile forwarded unchanged to every ``save_image`` call.
        metadata: Metadata dict forwarded unchanged to every ``save_image`` call.

    Returns:
        List of saved file names, in the same order as ``image_array``.
    """
    def _store(image):
        # Every image shares the same profile and metadata.
        return save_image(image, profile, metadata)

    with ThreadPoolExecutor() as pool:
        return list(pool.map(_store, image_array))

def generate(
    prompt: str,
    negative_prompt: str,
    seed: int = 0,
    num_inference_steps: int = 4,
    randomize_seed: bool = False,
    progress = gr.Progress(track_tqdm=True),
    profile: gr.OAuthProfile | None = None,
) -> tuple[list[str], int]:
    """Run the pipeline for one prompt and save the resulting images.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        seed: Seed to use when ``randomize_seed`` is false.
        num_inference_steps: Number of inference steps passed to the pipeline.
        randomize_seed: When true, ``seed`` is replaced by a random one.
        progress: Gradio progress tracker. It is not used in the body; it is
            kept as a default argument, which is the Gradio idiom for enabling
            progress tracking on a handler.
        profile: OAuth profile of the requesting user, or ``None``; forwarded
            to the image-saving helpers.

    Returns:
        ``(paths, seed)``: the file paths returned by ``save_images`` and the
        seed that was actually used.
    """
    # Width, height, image count and guidance scale come from the module-level
    # configuration. They are only read here, so no ``global`` statement is
    # needed.
    # int() guards against a float-valued seed, which np.random.seed rejects.
    seed = int(randomize_seed_fn(seed, randomize_seed))
    # Seed NumPy's global RNG before invoking the pipeline.
    # NOTE(review): presumably the pipeline draws its initial noise from it -- confirm.
    np.random.seed(seed)
    start_time = time.time()
    images = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
        output_type="pil",
    ).images
    metadata = {
        "prompt": prompt,
        "seed": seed,
        "width": width,
        "height": height,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
    }
    paths = save_images(images, profile, metadata=metadata)
    # Elapsed wall-clock seconds for generation plus saving.
    print(time.time() - start_time)
    return paths, seed

# Example prompts offered below the UI through gr.Examples; each entry fills
# the prompt text box.
examples = [
    "masterpiece, best quality, highres, 1girl, solo,",
    "masterpiece, best quality, highres, 1girl, solo, pov, scenery, wind, petals, rim lighting, shrine, lens flare, light scatter, depth of field, lens refraction,",
    "masterpiece, best quality, highres, 1girl, solo, scenery, wind, petals, rim lighting, shrine, lens flare, light scatter, depth of field, lens refraction, dark red hair, long hair, blue eyes, straight hair, cat ears, medium breasts, mature female, white sweater,",
    "masterpiece, best quality, highres, 1girl, solo, supernova, abstract, abstract background, bloom, swirling lights, light particles, fire, galaxy, floating, romanticized, blush, looking at viewer, dark red hair, long hair, blue eyes, straight hair, cat ears, medium breasts, mature female, white sweater,",
]

# Build the Gradio UI: prompt box + run button, result gallery, advanced
# options, per-user history, example prompts and the event wiring.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                value="masterpiece, best quality, highres, 1girl, solo,",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        )
    # Default negative prompt; shared by the text box and the example runner.
    _DEFAULT_NEGATIVE_PROMPT = "worst quality, low quality, lowres, monochrome, realistic,"
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            negative_prompt = gr.Text(
                label="Negative Prompt",
                max_lines=1,
                value=_DEFAULT_NEGATIVE_PROMPT,
            )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
            randomize=True
        )
        randomize_seed = gr.Checkbox(label="Randomize seed across runs", value=True)
        with gr.Row():
            num_inference_steps = gr.Slider(
                label="Number of inference steps for base",
                minimum=1,
                maximum=8,
                step=1,
                value=4,
            )

    with gr.Accordion("Past generations", open=False):
        gr_user_history.render()

    def _generate_example(example_prompt: str):
        """Run ``generate`` for an example prompt with the default negative prompt.

        The examples only provide the prompt, while ``generate`` also requires
        a negative prompt and returns ``(paths, seed)``. This adapter supplies
        the missing argument so example caching calls the handler correctly.
        """
        return generate(example_prompt, _DEFAULT_NEGATIVE_PROMPT)

    gr.Examples(
        examples=examples,
        inputs=prompt,
        # generate returns (paths, seed), so both components must be listed.
        outputs=[result, seed],
        fn=_generate_example,
        cache_examples=CACHE_EXAMPLES,
    )

    # Pressing Enter in the prompt box and clicking "Run" trigger the same handler.
    gr.on(
        triggers=[
            prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            num_inference_steps,
            randomize_seed
        ],
        outputs=[result, seed],
        api_name="run",
    )

if __name__ == "__main__":
    # Configure the request queue (api_open=False), then start the server.
    demo.queue(api_open=False).launch()