import spaces
import random
import gradio as gr
import numpy as np
import torch
from PIL import Image

def setup_seed(seed):
    """Seed every RNG used by the app so sampling is reproducible.

    Seeds Python's `random`, NumPy, and torch (CPU plus all CUDA devices),
    and configures cuDNN for deterministic kernel selection.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Per the PyTorch reproducibility notes, `deterministic = True` alone is
    # not enough: the cuDNN autotuner (`benchmark`) can still select
    # non-deterministic algorithms, so disable it as well.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
# Pick the compute device once at import time: the first CUDA GPU when one
# is available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

### PeRFlow-T2I
# Load the PeRFlow-accelerated SDXL (DreamShaper) weights in fp16 and swap in
# the piecewise rectified-flow scheduler so few-step (4-8) sampling works.
from diffusers import StableDiffusionXLPipeline
pipe = StableDiffusionXLPipeline.from_pretrained("hansyan/perflow-sdxl-dreamshaper", torch_dtype=torch.float16, use_safetensors=True, variant="v0-fix")
from src.scheduler_perflow import PeRFlowScheduler
pipe.scheduler = PeRFlowScheduler.from_config(pipe.scheduler.config, prediction_type="ddim_eps", num_time_windows=4)
# Fix: honor the `device` detected above instead of hard-coding "cuda:0",
# which crashed on CPU-only hosts. NOTE(review): fp16 on CPU is slow and some
# ops may be unsupported — confirm a CPU fallback is actually intended here.
pipe.to(device, torch.float16)


### gradio
@spaces.GPU
def generate(text, num_inference_steps, cfg_scale, seed):
    """Run the PeRFlow-SDXL pipeline for one prompt and return a PIL image.

    Seeds all RNGs from `seed`, prepends a fixed quality prefix to the
    prompt, samples a single 1024x1024 image with a fixed negative prompt,
    and converts the pipeline's tensor output to an RGB `PIL.Image`.

    Args:
        text: User prompt (quality prefix is prepended internally).
        num_inference_steps: Sampler step count; cast to int.
        cfg_scale: Classifier-free guidance scale; cast to float.
        seed: RNG seed; cast to int.
    """
    setup_seed(int(seed))
    steps = int(num_inference_steps)
    guidance = float(cfg_scale)

    prompt_prefix = "photorealistic, uhd, high resolution, high quality, highly detailed; "
    neg_prompt = "distorted, blur, low-quality, haze, out of focus"
    full_prompt = prompt_prefix + text
    result = pipe(
        prompt=[full_prompt],
        negative_prompt=[neg_prompt],
        height=1024,
        width=1024,
        num_inference_steps=steps,
        guidance_scale=guidance,
        output_type='pt',
    ).images
    # (1, C, H, W) tensor in [0, 1] -> (H, W, C) uint8 array.
    frame = result.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255.
    frame = frame.astype(np.uint8)
    return Image.fromarray(frame[:, :, :3])


# layout
# Shared CSS injected into the Gradio app: center all heading levels and cap
# the container width so the UI doesn't stretch across wide screens.
css = """
h1 {
    text-align: center;
    display:block;
}
h2 {
    text-align: center;
    display:block;
}
h3 {
    text-align: center;
    display:block;
}
.gradio-container {
  max-width: 768px !important;
}
"""
# Build the Gradio UI: header, prompt box, sampler controls, output image,
# clickable examples, and the event wiring into `generate`.
with gr.Blocks(title="PeRFlow-SDXL", css=css) as interface:
    # Header linking to the project repository and model weights.
    gr.Markdown(
    """
    # PeRFlow-SDXL

    GitHub: [https://github.com/magic-research/piecewise-rectified-flow](https://github.com/magic-research/piecewise-rectified-flow) <br/>
    Models: [https://huggingface.co/hansyan/perflow-sdxl-dreamshaper](https://huggingface.co/hansyan/perflow-sdxl-dreamshaper) 
    
    <br/>
    """
    )
    
    with gr.Column():
        # Free-form prompt; `generate` prepends a fixed quality prefix.
        text = gr.Textbox(
            label="Input Prompt", 
            value="masterpiece, A closeup face photo of girl, wearing a rain coat, in the street, heavy rain, bokeh"
        )
        with gr.Row():
            # Few-step sampling controls; choices restricted to the ranges
            # the PeRFlow scheduler is set up for.
            num_inference_steps = gr.Dropdown(label='Num Inference Steps',choices=[4,5,6,7,8], value=6, interactive=True)
            cfg_scale = gr.Dropdown(label='CFG scale',choices=[1.5, 2.0, 2.5], value=2.0, interactive=True)
            # NOTE(review): `value=42` is an int on a Textbox (which holds
            # strings); `generate` casts with int(seed) so either works, but
            # a string default would be cleaner — confirm before changing.
            seed = gr.Textbox(label="Random Seed", value=42)
            submit = gr.Button(scale=1, variant='primary')
    
    # with gr.Column():
    # with gr.Row():
    output_image = gr.Image(label='Generated Image')

    # Preset (prompt, steps, cfg, seed) combinations rendered as clickable
    # examples below the controls.
    example_inputs = [
        ["masterpiece, A closeup face photo of girl, wearing a rain coat, in the street, heavy rain, bokeh", 6, 2.0, 42],
        ["RAW photo, a handsome man, wearing a black coat, outside, closeup face", 5, 1.5, 35],
        ["RAW photo, a red luxury car, studio light", 7, 2.5, 30],
        ["masterpiece, A beautiful cat bask in the sun", 4, 2.0, 25]
    ]
    examples = gr.Examples(examples=example_inputs, inputs=[text, num_inference_steps, cfg_scale, seed], outputs=[output_image])

    gr.Markdown(
    """
    Here are some examples provided:
    - “masterpiece, A closeup face photo of girl, wearing a rain coat, in the street, heavy rain, bokeh”
    - “RAW photo, a handsome man, wearing a black coat, outside, closeup face”
    - “RAW photo, a red luxury car, studio light”
    - “masterpiece, A beautiful cat bask in the sun”
    """
    )
    
    # activate
    # All three triggers — Enter in the prompt box, Enter in the seed box,
    # and the button click — run the same `generate` callback.
    text.submit(
        fn=generate,
        inputs=[text, num_inference_steps, cfg_scale, seed],
        outputs=[output_image],
    )
    seed.submit(
        fn=generate,
        inputs=[text, num_inference_steps, cfg_scale, seed],
        outputs=[output_image],
    )
    submit.click(fn=generate,
        inputs=[text, num_inference_steps, cfg_scale, seed],
        outputs=[output_image],
    )



if __name__ == '__main__':
    # Enable request queuing (bounded backlog) before starting the server.
    app = interface
    app.queue(max_size=10)
    app.launch()