import gradio as gr
from optimum.intel.openvino import OVStableDiffusionPipeline
from diffusers.training_utils import set_seed
from diffusers import DDPMScheduler, StableDiffusionPipeline
import gc

import subprocess

import time


def create_pipeline(name):
    """Load the Stable Diffusion pipeline identified by ``name``.

    Three loading paths are used:

    * ``"svjack/Stable-Diffusion-Pokemon-en"``: a PyTorch
      ``StableDiffusionPipeline`` with a DDPM scheduler and the safety checker
      replaced by a pass-through.
    * ``"OpenVINO/stable-diffusion-pokemons-fp32"``: an
      ``OVStableDiffusionPipeline`` with the same DDPM scheduler.
    * Any other name: an ``OVStableDiffusionPipeline`` with whatever scheduler
      ``from_pretrained`` picks by default (presumably the one saved with the
      model — confirm per repo).

    OpenVINO pipelines are reshaped to a static batch-1, 512x512 input and
    compiled before being returned.

    Args:
        name: Model id (or path) passed to ``from_pretrained``.

    Returns:
        The loaded pipeline (compiled, for the OpenVINO variants).
    """
    torch_model = "svjack/Stable-Diffusion-Pokemon-en"  # previously "valhalla/sd-pokemon-model"
    ov_fp32_model = "OpenVINO/stable-diffusion-pokemons-fp32"  # previously "stable-diffusion-pokemons-valhalla-fp32"

    def ddpm_scheduler():
        # Shared DDPM configuration for the two models that override the scheduler.
        return DDPMScheduler(beta_start=0.00085, beta_end=0.012,
                             beta_schedule="scaled_linear", num_train_timesteps=1000)

    if name == torch_model:
        pipe = StableDiffusionPipeline.from_pretrained(name, scheduler=ddpm_scheduler())
        # Disable the NSFW filter. The replacement must report one flag per
        # image: recent diffusers versions iterate over the returned flags, so
        # a bare ``False`` raises "TypeError: 'bool' object is not iterable".
        pipe.safety_checker = lambda images, clip_input: (images, [False] * len(images))
        return pipe

    if name == ov_fp32_model:
        pipe = OVStableDiffusionPipeline.from_pretrained(name, compile=False, scheduler=ddpm_scheduler())
    else:
        pipe = OVStableDiffusionPipeline.from_pretrained(name, compile=False)
    # Fix the input shape before compiling so OpenVINO builds a static graph
    # for a single 512x512 image.
    pipe.reshape(batch_size=1, height=512, width=512, num_images_per_prompt=1)
    pipe.compile()
    return pipe

# UI label -> model id handed to create_pipeline().
# Insertion order matters: it is the order of the dropdown choices, and the
# last entry is used as the default selection.
# The comment above each entry records the model id that was used previously.
pipes = {
    # previously "valhalla/sd-pokemon-model"
    "Torch fp32": "svjack/Stable-Diffusion-Pokemon-en",
    # previously "OpenVINO/stable-diffusion-pokemons-valhalla-fp32"
    "OpenVINO fp32": "OpenVINO/stable-diffusion-pokemons-fp32",
    # previously "OpenVINO/stable-diffusion-pokemons-valhalla-quantized-agressive"
    "OpenVINO 8-bit quantized": "OpenVINO/stable-diffusion-pokemons-quantized-aggressive",
    # previously "OpenVINO/stable-diffusion-pokemons-valhalla-tome-quantized-agressive"
    "OpenVINO merged and quantized": "OpenVINO/stable-diffusion-pokemons-tome-quantized-aggressive",
}

# Build every pipeline once at start-up so later loads in generate() are
# faster (presumably because from_pretrained caches the downloaded model
# files locally — confirm). Each pipeline is dropped as soon as it is built
# and memory is reclaimed before the next one loads, so at most one model is
# resident at a time.
for v in pipes.values():
    create_pipeline(v)  # result intentionally discarded
    gc.collect()

# Log the host CPU details, which give context for the reported inference
# times. Run lscpu without a shell. lscpu is Linux-specific, so a missing or
# failing binary is reported instead of aborting the demo at start-up.
try:
    print(subprocess.check_output(["lscpu"]).strip().decode())
except (OSError, subprocess.CalledProcessError) as err:
    print(f"Could not query CPU info with lscpu: {err}")

def generate(prompt, option, seed):
    """Generate one image for ``prompt`` with the selected model variant.

    Args:
        prompt: Text prompt.
        option: Key of ``pipes`` naming the model variant (the dropdown value).
        seed: Seed as entered in the UI; converted with ``int()``.

    Returns:
        A ``(image, elapsed)`` tuple: the first generated PIL image and the
        inference time in seconds, formatted with ``"{:10.4f}"``. Pipeline
        loading and seeding are not included in the timing.
    """
    pipe = create_pipeline(pipes[option])
    set_seed(int(seed))

    call_kwargs = {"num_inference_steps": 50, "output_type": "pil"}
    if "Torch" in option:
        # The OpenVINO pipelines are reshaped to a static 512x512 in
        # create_pipeline; the Torch pipeline gets the same size explicitly.
        call_kwargs.update(height=512, width=512)

    # perf_counter is monotonic and high-resolution; time.time() can jump if
    # the system clock is adjusted mid-measurement.
    start_time = time.perf_counter()
    output = pipe(prompt, **call_kwargs)
    elapsed_time = time.perf_counter() - start_time
    return output.images[0], "{:10.4f}".format(elapsed_time)

# Sample prompts for the demo.
# NOTE(review): this list is not passed to gr.Interface (there is no
# ``examples=`` argument), so it is currently unused.
examples = [
    "cartoon bird",
    "a drawing of a green pokemon with red eyes",
    "plant pokemon in jungle",
]

# Dropdown choices, in the insertion order of ``pipes``.
model_options = list(pipes)

# Build the Gradio UI: a prompt, a model selector and a seed go in; the
# generated image and the measured inference time come out.
# NOTE(review): gr.inputs / gr.outputs are deprecated in Gradio 3.x and
# removed in 4.x; migrating to gr.Textbox(value=...) etc. needs a pinned
# Gradio version. The module-level ``examples`` list is also not wired in.
prompt_input = gr.inputs.Textbox(default="cartoon bird", label="Prompt", lines=1)
model_input = gr.inputs.Dropdown(choices=model_options, default=model_options[-1], label="Model version")
seed_input = gr.inputs.Textbox(default="42", label="Seed", lines=1)
image_output = gr.outputs.Image(type="pil", label="Generated Image")
time_output = gr.outputs.Textbox(label="Inference time")

demo = gr.Interface(
    fn=generate,
    inputs=[prompt_input, model_input, seed_input],
    outputs=[image_output, time_output],
    title="OpenVINO-optimized Stable Diffusion",
    description="This is the Optimum-based demo for NNCF-optimized Stable Diffusion pipeline trained on 'lambdalabs/pokemon-blip-captions' dataset and running with OpenVINO.\n"
                 "The pipeline is run using 8 vCPUs (4 cores) only.",
    theme="huggingface",
)
demo.launch()