import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
import os
from huggingface_hub import hf_hub_download
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download

# Constants
# Largest signed 32-bit integer (2147483647); upper bound for random seeds and the seed slider.
MAX_SEED: int = np.iinfo(np.int32).max
# Upper limit of the width/height sliders in the UI.
MAX_IMAGE_SIZE: int = 1024
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"


# FLUX.1-dev loaded through the "pipeline_flux_rf_inversion" custom pipeline,
# which provides the `invert(...)` method and the RF-inversion arguments
# (inverted_latents, eta, ...) used by invert_and_edit. Weights are kept in bfloat16.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev",
                                         custom_pipeline="pipeline_flux_rf_inversion",
                                          torch_dtype=torch.bfloat16)

# NOTE(review): leftover commented-out call; the 8-step LoRA is loaded on demand
# by check_hyper_flux_lora instead.
#pipe.enable_lora()
pipe.to(DEVICE)

def get_examples():
    """Build the rows displayed by the ``gr.Examples`` gallery.

    Each row is, in order: input image, edit prompt, eta, gamma, start
    timestep, stop timestep, inversion steps, inference steps, seed,
    randomize-seed flag, eta-decay flag, eta decay power, 8-step-LoRA flag,
    followed by one extra image (e.g. ``dragon.png``). That trailing image has
    no matching component in the gr.Examples inputs; presumably it is a
    reference result.
    """
    # (input file, [prompt and settings...], trailing reference file)
    presets = (
        ("metal.png", ["a dragon, in 3d melting gold metal", 0.9, 0.5, 0, 5, 28, 28, 0, False, False, 2, False], "dragon.png"),
        ("doll.png", ["anime illustration", 0.9, 0.5, 0, 6, 28, 28, 0, False, False, 2, False], "anime.png"),
        ("doll.png", ["raccoon, made of yarn", 0.9, 0.5, 0, 4, 28, 28, 0, False, False, 2, False], "raccoon.png"),
        ("cat.jpg", ["a parrot", 0.9, 0.5, 2, 8, 28, 28, 0, False, False, 1, False], "parrot.png"),
        ("cat.jpg", ["a tiger", 0.9, 0.5, 0, 4, 8, 8, 789385745, False, False, 1, True], "tiger.png"),
        ("metal.png", ["a dragon, in 3d melting gold metal", 0.9, 0.5, 0, 4, 8, 8, 789385745, False, True, 2, True], "dragon.png"),
    )
    # The input image is opened before the reference image of each row, matching
    # the evaluation order of a literal nested list.
    return [
        [Image.open(source_file), *settings, Image.open(reference_file)]
        for source_file, settings, reference_file in presets
    ]
def reset_do_inversion():
    """Return True, the value that requests a fresh inversion on the next edit."""
    needs_inversion = True
    return needs_inversion
    
def resize_img(image, max_size=1024):
    """Scale ``image`` to fit inside a ``max_size`` x ``max_size`` box, keeping its aspect ratio.

    The image is scaled up as well as down, so its longer side becomes
    ``max_size`` (up to floating-point truncation).

    Args:
        image: A PIL-style image exposing ``.size`` (width, height) and ``.resize``.
        max_size: Side length, in pixels, of the bounding box.

    Returns:
        The resized image, resampled with LANCZOS.
    """
    width, height = image.size
    scaling_factor = min(max_size / width, max_size / height)
    # Clamp to at least 1 px: for very elongated images, truncating the short
    # side's scaled length would otherwise produce a zero-sized dimension.
    new_width = max(1, int(width * scaling_factor))
    new_height = max(1, int(height * scaling_factor))
    return image.resize((new_width, new_height), Image.LANCZOS)

def check_style(stylezation, enable_hyper_flux):
    """Return an 8-value preset of UI settings for the chosen editing mode.

    The first six values appear to line up with eta, gamma, start timestep,
    stop timestep, inversion steps and inference steps (the same order as the
    rows built by get_examples). The final two are always False; which
    components they drive is decided by the event wiring, not here.

    Stylization takes precedence over the 8-step LoRA preset.
    """
    if stylezation:
        start, stop, inversion_steps, inference_steps = 0, 6, 28, 28
    elif enable_hyper_flux:
        start, stop, inversion_steps, inference_steps = 0, 4, 8, 8
    else:
        start, stop, inversion_steps, inference_steps = 2, 7, 28, 28
    eta_value, gamma_value = 0.9, 0.5
    return eta_value, gamma_value, start, stop, inversion_steps, inference_steps, False, False

def check_hyper_flux_lora(enable_hyper_flux):
    """Toggle the Hyper-SD 8-step FLUX LoRA on the shared pipeline.

    Args:
        enable_hyper_flux: True to download, load and fuse the LoRA (scale
            0.125); False to unfuse it and remove it from the pipeline.

    Returns:
        (num_inversion_steps, num_inference_steps, stop_timestep) presets for
        the UI: (8, 8, 4) with the LoRA, (28, 28, 6) without it.
    """
    if enable_hyper_flux:
        pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), lora_scale=0.125)
        pipe.fuse_lora(lora_scale=0.125)
        return 8, 8, 4
    # unfuse_lora() only restores the base weights; the LoRA adapter stays
    # loaded (and is still applied at inference) until it is unloaded, so
    # unchecking the box previously kept the LoRA effect active.
    pipe.unfuse_lora()
    pipe.unload_lora_weights()
    return 28, 28, 6

@spaces.GPU(duration=85)
def invert_and_edit(image,
                    prompt,
                    eta,
                    gamma,
                    start_timestep,
                    stop_timestep,
                    num_inversion_steps,
                    num_inference_steps,
                    seed,
                    randomize_seed,
                    eta_decay,
                    decay_power,
                    width=1024,
                    height=1024,
                    inverted_latents=None,
                    image_latents=None,
                    latent_image_ids=None,
                    do_inversion=True,
                    ):
    """Invert ``image`` with RF inversion when needed, then edit it towards ``prompt``.

    Args:
        image: Input image to edit (the UI supplies a PIL image).
        prompt: Edit prompt describing the desired output.
        eta: Passed to the pipeline as ``eta``.
        gamma: Passed to ``pipe.invert`` as ``gamma``.
        start_timestep, stop_timestep: Step indices; divided by
            ``num_inference_steps`` and passed to the pipeline as fractions.
        num_inversion_steps: Number of steps for ``pipe.invert``.
        num_inference_steps: Number of steps for the editing pass.
        seed: RNG seed; replaced by a random one when ``randomize_seed``.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        eta_decay: Passed to the pipeline as ``decay_eta``.
        decay_power: Passed to the pipeline as ``eta_decay_power``.
        width, height: Accepted for the UI wiring but currently unused.
        inverted_latents, image_latents, latent_image_ids: Cached results of a
            previous ``pipe.invert`` call, or None.
        do_inversion: True forces a fresh inversion.

    Returns:
        (edited image, inverted_latents, image_latents, latent_image_ids,
        do_inversion, seed). The three latent tensors are moved to CPU so the
        caller can cache them. do_inversion is False because the returned
        latents are valid for reuse. seed is the seed actually used.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # The seed used to be returned but never applied, so it had no effect.
    # Seeding torch's global RNG makes any random draws made by the pipeline
    # (presumably noise sampled during inversion) reproducible for a given seed.
    torch.manual_seed(seed)

    # Re-invert on request, and also whenever nothing usable is cached (e.g. the
    # very first click); otherwise `None.to(DEVICE)` below would crash.
    cache_missing = any(
        cached is None for cached in (inverted_latents, image_latents, latent_image_ids)
    )
    if do_inversion or cache_missing:
        inverted_latents, image_latents, latent_image_ids = pipe.invert(
            image, num_inversion_steps=num_inversion_steps, gamma=gamma
        )
        do_inversion = False

    output = pipe(
        prompt,
        inverted_latents=inverted_latents.to(DEVICE),
        image_latents=image_latents.to(DEVICE),
        latent_image_ids=latent_image_ids.to(DEVICE),
        # The pipeline takes start/stop as fractions of the inference schedule.
        start_timestep=start_timestep / num_inference_steps,
        stop_timestep=stop_timestep / num_inference_steps,
        num_inference_steps=num_inference_steps,
        eta=eta,
        decay_eta=eta_decay,
        eta_decay_power=decay_power,
    ).images[0]

    return output, inverted_latents.cpu(), image_latents.cpu(), latent_image_ids.cpu(), do_inversion, seed

# UI CSS
# Centers the main column (elem_id="col-container" in the layout below) and
# caps its width at 960px.
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
"""

# Create the Gradio interface.
# Inversion results are cached per session in gr.State so that repeated edits of
# the same image skip the expensive inversion; `do_inversion` requests a fresh
# one and is set to True by the reset handlers registered at the bottom.
with gr.Blocks(css=css) as demo:

    # Cached outputs of the last inversion, as returned by invert_and_edit.
    inverted_latents = gr.State()
    image_latents = gr.State()
    latent_image_ids = gr.State()
    do_inversion = gr.State(False)

    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# RF inversion 🖌️🏞️
### Edit real images with FLUX.1 [dev]
following the algorithm proposed in [*Semantic Image Inversion and Editing using
Stochastic Rectified Differential Equations* by Rout et al.](https://rf-inversion.github.io/data/rf-inversion.pdf)

based on the implementations of [@raven38](https://github.com/raven38) & [@DarkMnDragon](https://github.com/DarkMnDragon) 🙌🏻

[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[project page](https://rf-inversion.github.io/)] [[arxiv](https://arxiv.org/pdf/2410.10792)]
        """)

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="pil"
                )
                prompt = gr.Text(
                    label="Edit Prompt",
                    max_lines=1,
                    placeholder="describe the edited output",
                )
                with gr.Row():
                    enable_hyper_flux = gr.Checkbox(label="8-step LoRA", value=False, info="")
                    stylezation = gr.Checkbox(label="stylization")
                with gr.Row():
                    start_timestep = gr.Slider(
                        label="start timestep",
                        info="increase to enhance fidelity, decrease to enhance realism",
                        minimum=0,
                        maximum=28,
                        step=1,
                        value=0,
                    )
                    stop_timestep = gr.Slider(
                        label="stop timestep",
                        info="increase to enhance fidelity to original image",
                        minimum=0,
                        maximum=28,
                        step=1,
                        value=6,
                    )
                    eta = gr.Slider(
                        label="eta",
                        info="lower eta to enhance the edits",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.9,
                    )

                run_button = gr.Button("Edit", variant="primary")

            with gr.Column():
                result = gr.Image(label="Result")

        with gr.Accordion("Advanced Settings", open=False):

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="num inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
                eta_decay = gr.Checkbox(label="eta decay", value=False)
                decay_power = gr.Slider(
                    label="eta decay power",
                    minimum=0,
                    maximum=5,
                    step=1,
                    value=1,
                )

            with gr.Row():
                gamma = gr.Slider(
                    label="gamma",
                    info="increase gamma to enhance realism",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.5,
                )
                num_inversion_steps = gr.Slider(
                    label="num inversion steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

    run_button.click(
        fn=invert_and_edit,
        inputs=[
            input_image,
            prompt,
            eta,
            gamma,
            start_timestep,
            stop_timestep,
            num_inversion_steps,
            num_inference_steps,
            seed,
            randomize_seed,
            eta_decay,
            decay_power,
            width,
            height,
            inverted_latents,
            image_latents,
            latent_image_ids,
            do_inversion,
        ],
        outputs=[result, inverted_latents, image_latents, latent_image_ids, do_inversion, seed],
    )

    # gr.Examples needs the rows themselves (a list), not the builder function.
    # Row items beyond the listed inputs are ignored by gr.Examples.
    gr.Examples(
        examples=get_examples(),
        inputs=[input_image, prompt, eta, gamma, start_timestep, stop_timestep, num_inversion_steps,
                num_inference_steps, seed, randomize_seed, eta_decay, decay_power, enable_hyper_flux],
        outputs=[result],
    )

    # Anything that feeds the inversion invalidates the cached latents:
    # the image, the inversion step count, gamma (passed to pipe.invert), and the seed.
    input_image.change(
        fn=reset_do_inversion,
        outputs=[do_inversion]
    )

    num_inversion_steps.change(
        fn=reset_do_inversion,
        outputs=[do_inversion]
    )

    gamma.change(
        fn=reset_do_inversion,
        outputs=[do_inversion]
    )

    seed.change(
        fn=reset_do_inversion,
        outputs=[do_inversion]
    )

    # check_style takes (stylezation, enable_hyper_flux) and returns 8 values.
    # The first six follow the eta/gamma/timestep/steps order used by the example rows.
    # NOTE(review): the targets of the two trailing False values are inferred
    # from the example column order (randomize_seed, eta_decay); confirm intent.
    stylezation.change(
        fn=check_style,
        inputs=[stylezation, enable_hyper_flux],
        outputs=[eta, gamma, start_timestep, stop_timestep, num_inversion_steps,
                 num_inference_steps, randomize_seed, eta_decay],
    )

    enable_hyper_flux.change(
        fn=check_hyper_flux_lora,
        inputs=[enable_hyper_flux],
        outputs=[num_inversion_steps, num_inference_steps, stop_timestep]
    )
    

    

if __name__ == "__main__":
    # Start the Gradio app when this file is executed as a script.
    demo.launch()