import gradio as gr
import numpy as np
import random

import spaces
from pipeline_flux import FluxPipeline
from transformer_flux import FluxTransformer2DModel
import torch

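# Select the FLUX.1 variant. URAE ships 4K adapters for both "schnell" and "dev"
# (see the LoRA filenames loaded below), so "dev" should work here as well.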
flux_model = "schnell"
bfl_repo = f"black-forest-labs/FLUX.1-{flux_model}"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

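# Load the customized URAE transformer separately, then attach it to a pipeline
# created with transformer=None so the stock transformer weights are never loaded.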
transformer = FluxTransformer2DModel.from_pretrained(
    bfl_repo, subfolder="transformer", torch_dtype=dtype
)
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=dtype)
pipe.transformer = transformer
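# Disable resolution-dependent dynamic shifting and pin a large constant time shift;
# URAE samples far above the training resolution, where a fixed shift is used instead.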
pipe.scheduler.config.use_dynamic_shifting = False
pipe.scheduler.config.time_shift = 10
# pipe.enable_model_cpu_offload()  # uncomment to trade speed for lower VRAM
pipe = pipe.to(device)

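# Register the URAE LoRA adapters: one 2K adapter shared by both variants, plus
# 4K adapters converted separately for the dev and schnell checkpoints.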
pipe.load_lora_weights(
    "Huage001/URAE",
    weight_name="urae_2k_adapter.safetensors",
    adapter_name="2k",
)
pipe.load_lora_weights(
    "Huage001/URAE",
    weight_name="urae_4k_adapter_lora_conversion_dev.safetensors",
    adapter_name="4k_dev",
)
pipe.load_lora_weights(
    "Huage001/URAE",
    weight_name="urae_4k_adapter_lora_conversion_schnell.safetensors",
    adapter_name="4k_schnell",
)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 4096
USE_ZERO_GPU = True  # set False when running outside a ZeroGPU Space


# Note: spaces.GPU is applied below via the USE_ZERO_GPU flag rather than as a decorator.
def infer(
    prompt,
    seed,
    randomize_seed,
    width,
    height,
    num_inference_steps,
    model="2k",  # only the default is used while the model radio below is commented out
    progress=gr.Progress(track_tqdm=True),
):
    print("Using model:", model)
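    # Both branches enable VAE tiling so 2K/4K latents can be decoded without
    # exhausting GPU memory, then activate the matching LoRA adapter.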
    if model == "2k":
        pipe.vae.enable_tiling(True)
        pipe.set_adapters("2k")
    elif model == "4k":
        pipe.vae.enable_tiling(True)
        pipe.set_adapters(f"4k_{flux_model}")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

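    # A CPU generator keeps seeds reproducible regardless of the execution device.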
    generator = torch.Generator().manual_seed(seed)

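    # guidance_scale=0 and max_sequence_length=256 match FLUX.1 schnell's distilled,
    # CFG-free setup; ntk_factor and proportional_attention are URAE inference-time
    # knobs exposed by the bundled pipeline_flux for NTK-style RoPE extension and
    # sequence-length-aware attention scaling at ultra-high resolutions.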
    image = pipe(
        prompt=prompt,
        guidance_scale=0,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        max_sequence_length=256,
        ntk_factor=10,
        proportional_attention=True,
        generator=generator,
    ).images[0]

    return image, seed


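# On Hugging Face ZeroGPU Spaces, wrap infer so each call is allocated a GPU;
# duration=360 requests up to 360 s per generation (4K sampling is slow).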
if USE_ZERO_GPU:
    infer = spaces.GPU(infer, duration=360)

examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

css = """
#maincontainer {
    display: flex;
}

#col1 {
    margin: 0 auto;
    max-width: 50%;
}
#col2 {
    margin: 0 auto;
    /* max-width: 40px; */
}
"""

head = """> ***U*ltra-*R*esolution *A*daptation with *E*ase**

<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2503.16322"><img src="https://img.shields.io/badge/arXiv-2503.16322-A42C25.svg" alt="arXiv"></a> 
<a href="https://huggingface.co/Huage001/URAE"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://huggingface.co/spaces/Yuanshi/URAE"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://huggingface.co/spaces/Yuanshi/URAE_dev"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg" alt="HuggingFace"></a>
</div>
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# URAE (FLUX.1 schnell) \n" + head)
    with gr.Row(elem_id="maincontainer"):
        with gr.Column(elem_id="col1"):
            gr.Markdown("### Prompt:")
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            gr.Examples(examples=examples, inputs=[prompt])
            run_button = gr.Button("Generate", scale=1, variant="primary")

            gr.Markdown("### Setting:")

            # model = gr.Radio(
            #     label="Model",
            #     choices=[
            #         ("2K model", "2k"),
            #         ("4K model (beta)", "4k"),
            #     ],
            #     value="2k",
            # )

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=2048,  # 2K default; the 2k adapter targets 2048x2048
                )

                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=2048,  # 2K default; the 2k adapter targets 2048x2048
                )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=4,  # FLUX.1 schnell is distilled for ~4 inference steps
            )

        with gr.Column(elem_id="col2"):
            result = gr.Image(label="Result", show_label=False)

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            # model,
            seed,
            randomize_seed,
            width,
            height,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()