import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
import os
from huggingface_hub import hf_hub_download
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download

# Constants
# Upper bound for user-selectable / randomly drawn seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound of the width/height sliders in the UI.
MAX_IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# FLUX.1-dev loaded through the "pipeline_flux_rf_inversion" custom pipeline,
# which provides the `invert` method used by invert_and_edit.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev",
                                         custom_pipeline="pipeline_flux_rf_inversion",
                                          torch_dtype=torch.bfloat16)
# Hyper-SD 8-step LoRA, fused into the weights at scale 0.125 (presumably why
# the step sliders default to 8).
# NOTE(review): the effective scale is presumably the one given to fuse_lora;
# `lora_scale` on load_lora_weights may be ignored — confirm for the installed
# diffusers version.
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), lora_scale=0.125)
pipe.fuse_lora(lora_scale=0.125)
pipe.to(DEVICE)

# One example row; columns follow the gr.Examples inputs order:
# image, prompt, start_timestep, stop_timestep, gamma, eta,
# num_inversion_steps, num_inference_steps, seed, randomize_seed.
# "cat.jpg" is opened at import time and must exist in the working directory.
# NOTE(review): stop_timestep=0.7 is not on the stop slider's integer step
# grid (step=1); looks like a leftover from fractional timesteps — confirm.
examples = [[Image.open("cat.jpg"), "a tiger", 0, 0.7, 0.5, 0.9, 8, 8, 789385745, False]]
def reset_do_inversion():
    """Mark the cached inversion latents as stale.

    Returns:
        Always ``True``; used as the new value of the ``do_inversion`` state.
    """
    needs_inversion = True
    return needs_inversion
    
def resize_img(image, max_size=1024):
    """Rescale ``image`` so its longer side becomes (about) ``max_size``.

    Both sides are multiplied by the same factor, so the aspect ratio is kept;
    the resulting sizes are truncated to integers. Images smaller than
    ``max_size`` are scaled up. LANCZOS resampling is used.
    """
    w, h = image.size
    factor = min(max_size / w, max_size / h)
    target_size = (int(w * factor), int(h * factor))
    return image.resize(target_size, Image.LANCZOS)

def check_style(stylezation, eta, eta_decay, decay_power, start_timestep, stop_timestep):
    """Hand the editing parameters back unchanged.

    ``stylezation`` is currently ignored; the other five values are returned
    as a tuple in the order they were received.
    """
    passthrough = (eta, eta_decay, decay_power, start_timestep, stop_timestep)
    return passthrough

@spaces.GPU(duration=85)
def invert_and_edit(image,
                    prompt,
                    eta,
                    gamma,
                    start_timestep,
                    stop_timestep,
                    num_inversion_steps,
                    num_inference_steps,
                    width,
                    height,
                    inverted_latents,
                    image_latents,
                    latent_image_ids,
                    do_inversion,
                    seed,
                    randomize_seed,
                    eta_decay,
                    decay_power,
                   ):
    """Invert ``image`` (when needed) and regenerate it guided by ``prompt``.

    Inversion results are returned so the caller can cache them and skip
    ``pipe.invert`` on later edits of the same image.

    Args:
        image: input image handed to ``pipe.invert``.
        prompt: text describing the edited output.
        eta: forwarded to the pipeline as ``eta``.
        gamma: forwarded to ``pipe.invert``.
        start_timestep, stop_timestep: values in step units; divided by
            ``num_inference_steps`` before being passed to the pipeline.
        num_inversion_steps: forwarded to ``pipe.invert``.
        num_inference_steps: forwarded to the pipeline.
        width, height: accepted for UI compatibility; not forwarded.
        inverted_latents, image_latents, latent_image_ids: results of a
            previous inversion, or ``None`` when nothing is cached.
        do_inversion: when truthy, the image is inverted again.
        seed: seed for the pipeline's generator.
        randomize_seed: when truthy, ``seed`` is replaced by a random one.
        eta_decay: forwarded as ``decay_eta``.
        decay_power: forwarded as ``eta_decay_power``.

    Returns:
        Tuple ``(edited image, inverted_latents, image_latents,
        latent_image_ids, do_inversion, seed)``. The latents are moved to CPU
        and ``seed`` is the value actually used.

    Raises:
        gr.Error: if no input image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Invert on request, and also whenever nothing is cached yet; otherwise
    # the `.to(DEVICE)` calls below would fail on None.
    cache_missing = (
        inverted_latents is None
        or image_latents is None
        or latent_image_ids is None
    )
    if do_inversion or cache_missing:
        inverted_latents, image_latents, latent_image_ids = pipe.invert(
            image, num_inversion_steps=num_inversion_steps, gamma=gamma
        )
        do_inversion = False

    # Seed a generator so the reported seed actually controls any sampling
    # randomness inside the pipeline (previously the seed was never used).
    generator = torch.Generator(device="cpu").manual_seed(int(seed))

    output = pipe(
        prompt,
        inverted_latents=inverted_latents.to(DEVICE),
        image_latents=image_latents.to(DEVICE),
        latent_image_ids=latent_image_ids.to(DEVICE),
        # UI values are in step units; pass them as fractions of the run.
        start_timestep=start_timestep / num_inference_steps,
        stop_timestep=stop_timestep / num_inference_steps,
        num_inference_steps=num_inference_steps,
        eta=eta,
        decay_eta=eta_decay,
        eta_decay_power=decay_power,
        generator=generator,
    ).images[0]

    # Move latents to CPU before handing them back for caching.
    return (
        output,
        inverted_latents.cpu(),
        image_latents.cpu(),
        latent_image_ids.cpu(),
        do_inversion,
        seed,
    )

# UI CSS: centre the main column and cap its width.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 960px;\n"
    "}\n"
)

# Create the Gradio interface
with gr.Blocks(css=css) as demo:

    # Per-session cache of the most recent inversion, so repeated edits of the
    # same image can skip pipe.invert.
    inverted_latents = gr.State()
    image_latents = gr.State()
    latent_image_ids = gr.State()
    # True means "the cached latents are stale, invert again". It starts True
    # because nothing is cached when a session begins.
    do_inversion = gr.State(True)

    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# RF inversion 🖌️🏞️
### Edit real images with FLUX.1 [dev]
following the algorithm proposed in [*Semantic Image Inversion and Editing using
Stochastic Rectified Differential Equations* by Rout et al.](https://rf-inversion.github.io/data/rf-inversion.pdf)

based on the implementations of [@raven38](https://github.com/raven38) & [@DarkMnDragon](https://github.com/DarkMnDragon) 🙌🏻

[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[project page](https://rf-inversion.github.io/)] [[arxiv](https://arxiv.org/pdf/2410.10792)]
        """)

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="pil"
                )
                prompt = gr.Text(
                    label="Edit Prompt",
                    max_lines=1,
                    placeholder="describe the edited output",
                )
                stylezation = gr.Checkbox(label="stylization", value=True)
                with gr.Row():
                    start_timestep = gr.Slider(
                        label="start timestep",
                        info="decrease to enhance fidelity to original image",
                        minimum=0,
                        maximum=28,
                        step=1,
                        value=0,
                    )
                    stop_timestep = gr.Slider(
                        label="stop timestep",
                        info="increase to enhance fidelity to original image",
                        minimum=0,
                        maximum=28,
                        step=1,
                        value=4,
                    )
                    eta = gr.Slider(
                        label="eta",
                        info="lower eta to enhance the edits",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.9,
                    )

                run_button = gr.Button("Edit", variant="primary")

            with gr.Column():
                result = gr.Image(label="Result")

        with gr.Accordion("Advanced Settings", open=False):

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="num inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=8,
                )
                eta_decay = gr.Checkbox(label="eta decay", value=False)
                decay_power = gr.Slider(
                    label="eta decay power",
                    minimum=0,
                    maximum=5,
                    step=1,
                    value=1,
                )

            with gr.Row():
                gamma = gr.Slider(
                    label="gamma",
                    info="increase gamma to enhance realism",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.5,
                )
                num_inversion_steps = gr.Slider(
                    label="num inversion steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=8,
                )

            # NOTE: invert_and_edit receives width/height but does not forward
            # them to the pipeline.
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

    run_button.click(
        fn=invert_and_edit,
        inputs=[
            input_image,
            prompt,
            eta,
            gamma,
            start_timestep,
            stop_timestep,
            num_inversion_steps,
            num_inference_steps,
            width,
            height,
            inverted_latents,
            image_latents,
            latent_image_ids,
            do_inversion,
            seed,
            randomize_seed,
            eta_decay,
            decay_power,
        ],
        outputs=[result, inverted_latents, image_latents, latent_image_ids, do_inversion, seed],
    )

    # Clicking an example only fills the inputs. Caching is disabled
    # explicitly because these 10 inputs do not match invert_and_edit's
    # 18-parameter signature, so running fn on them would fail.
    gr.Examples(
        examples=examples,
        inputs=[input_image, prompt, start_timestep, stop_timestep, gamma, eta, num_inversion_steps, num_inference_steps, seed, randomize_seed],
        outputs=[result, inverted_latents, image_latents, latent_image_ids, do_inversion, seed],
        fn=invert_and_edit,
        cache_examples=False,
    )

    # Every input of pipe.invert (image, inversion steps, gamma) invalidates
    # the cached latents; previously a gamma change silently reused them.
    for inversion_input in (input_image, num_inversion_steps, gamma):
        inversion_input.change(
            fn=reset_do_inversion,
            outputs=[do_inversion],
        )

    stylezation.change(
        fn=check_style,
        inputs=[stylezation, eta, eta_decay, decay_power, start_timestep, stop_timestep],
        outputs=[eta, eta_decay, decay_power, start_timestep, stop_timestep]
    )


if __name__ == "__main__":
    demo.launch()