import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from PIL import Image
from io import BytesIO
import base64
import re
import os
import requests


# Which Stable Diffusion checkpoint to load and where to run it.
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

# Let cuDNN autotune convolution algorithms. This only sets a global flag,
# so it is independent of when the pipeline below is loaded.
torch.backends.cudnn.benchmark = True

# Load the half-precision ("fp16" revision) weights once at import time and
# move the pipeline to `device`; all requests share this single instance.
# use_auth_token=True presumably picks up a locally cached Hugging Face
# token -- NOTE(review): confirm against the installed diffusers version.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
).to(device)

# Global busy flag read and written by image_generation so that a request
# arriving while another one is generating is turned away.
is_gpu_busy = False


@torch.no_grad()
def image_generation(prompt, samples=4, steps=25, scale=7.5):
    """Generate ``samples`` images for ``prompt`` using the shared pipeline.

    Acts as a simple single-request guard: if another call is already running
    (module-level ``is_gpu_busy`` is set), an empty list is returned at once
    instead of starting more work on the GPU.

    Args:
        prompt: Text prompt, repeated once per requested sample.
        samples: Number of images to generate. Coerced to ``int`` in case the
            UI slider delivers a float (``list * float`` raises TypeError).
        steps: Number of inference steps passed to the pipeline, coerced to
            ``int`` for the same reason.
        scale: Guidance scale passed to the pipeline as ``guidance_scale``.

    Returns:
        The ``images`` attribute of the pipeline output, or ``[]`` when the
        GPU was already busy.

    Raises:
        Whatever the pipeline raises; the busy flag is cleared in every case.
    """
    global is_gpu_busy

    if is_gpu_busy:
        return []

    # NOTE(review): this check-then-set is not atomic, so two truly concurrent
    # requests could both pass the guard; a threading.Lock would close that gap.
    is_gpu_busy = True
    try:
        with autocast("cuda"):
            result = pipe(
                [prompt] * int(samples),
                num_inference_steps=int(steps),
                guidance_scale=scale,
            )
        return result.images
    finally:
        # Release the flag even if the pipeline raised (e.g. CUDA OOM);
        # otherwise every subsequent request would be rejected forever.
        is_gpu_busy = False


with gr.Blocks() as demo:
    gr.Markdown("# Stable Diffusion demo\nType something and generate images!")
    textbox = gr.Textbox(placeholder="Something cool...", interactive=True)

    # Generation controls, grouped in one column.
    with gr.Column(scale=1):
        samples = gr.Slider(
            minimum=1, maximum=8, value=4, step=1, label="samples"
        )
        steps = gr.Slider(
            minimum=10, maximum=50, value=25, step=1, label="steps"
        )
        scale = gr.Slider(
            minimum=5, maximum=15, value=7.5, step=0.1, label="scale"
        )

    submit = gr.Button("Submit", variant="primary")
    gr.Markdown("Images will appear below")
    with gr.Row():
        gallery = gr.Gallery()

    # Submitting the textbox and clicking the button run the same handler
    # with the same inputs; register them in that order.
    for trigger in (textbox.submit, submit.click):
        trigger(
            image_generation,
            inputs=[textbox, samples, steps, scale],
            outputs=[gallery],
        )

demo.launch()