import functools
import os

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
import yaml
from diffusers import (
    AutoPipelineForInpainting,
)
from munch import munchify
from torchvision.transforms import functional as F

from ddim_with_prob import DDIMSchedulerCustom
from generate_dataset import outpainting_generator_rectangle, merge_images_horizontally

# Shared preprocessing for both the color image and the mask: convert a
# channels-first tensor/array into a PIL image, then resize it to a fixed
# 512x512 using Lanczos resampling.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(
            (512, 512),
            interpolation=F.InterpolationMode.LANCZOS,
        ),
    ]
)

_CONFIG_PATH = "./configs/paintreward_train_configs.yaml"
_OURS_CKPT = "./model_ckpt"
_RUNWAY_CKPT = "runwayml/stable-diffusion-inpainting"


@functools.lru_cache(maxsize=1)
def _load_models():
    """Load the config and both inpainting pipelines once per process.

    Loading two diffusion pipelines is far more expensive than a single
    inference call, so the result is cached and shared by every request.

    Returns:
        ``(config, pipe_ours, pipe_runway)``: the munchified YAML config and
        the two pipelines already moved to the selected device.
    """
    with open(_CONFIG_PATH, encoding="utf-8") as file:
        config = munchify(yaml.safe_load(file))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Half precision is only dependable on CUDA; on CPU many ops lack float16
    # kernels or run very slowly, so load the fp16 weights as float32 there.
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    print('Loading pipeline')
    pipe_ours = AutoPipelineForInpainting.from_pretrained(
        _OURS_CKPT, torch_dtype=dtype, variant='fp16')
    pipe_ours.scheduler = DDIMSchedulerCustom.from_config(pipe_ours.scheduler.config)

    pipe_runway = AutoPipelineForInpainting.from_pretrained(
        _RUNWAY_CKPT, torch_dtype=dtype, variant='fp16')

    return config, pipe_ours.to(device), pipe_runway.to(device)


def _prepare_inputs(image, box_width_ratio, mask_random_start):
    """Build the resized color image and mask fed to both pipelines.

    ``box_width_ratio`` is a percentage and is converted to a fraction before
    being handed to ``outpainting_generator_rectangle``.
    """
    color, mask = outpainting_generator_rectangle(image, box_width_ratio / 100, mask_random_start)

    # NOTE(review): the UI requests RGBA input; confirm the generator returns a
    # 3-channel color image, since a 4-channel one would reach the pipelines.
    # HWC -> CHW so ToPILImage reads the array as channels-first.
    color = torch.from_numpy(np.array(color).transpose(2, 0, 1))
    # (H, W) -> (1, H, W): a single-channel image for ToPILImage.
    mask = torch.from_numpy(np.array(mask.convert('L'))[None, ...])

    # NOTE(review): the raw 0-255 grayscale mask is passed on (the original
    # computed a 0/1 copy but never used it); confirm the pipelines' own mask
    # processing thresholds it as intended.
    return transform(color), transform(mask)


def pref_inpainting(image,
                    box_width_ratio,
                    mask_random_start,
                    steps,
                    ):
    """Outpaint ``image`` with the preference-tuned and the Runway pipelines.

    Args:
        image: Input image from the Gradio ``gr.Image(type="pil")`` widget.
        box_width_ratio: Box width as a percentage (divided by 100 before use).
        mask_random_start: Forwarded unchanged to
            ``outpainting_generator_rectangle``.
        steps: Number of denoising steps used by both pipelines.

    Returns:
        ``(res_ours, res_runway)``: each result merged side by side with the
        masked input via ``merge_images_horizontally``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")

    config, pipe_ours, pipe_runway = _load_models()
    color, mask = _prepare_inputs(image, box_width_ratio, mask_random_start)
    # Slider values may arrive as floats; schedulers expect an int step count.
    num_steps = int(steps)

    print('Running inference ours')
    res_ours = pipe_ours(prompt='', image=color, mask_image=mask,
                         num_inference_steps=num_steps, eta=config.eta).images[0]
    print('Running inference runway')
    res_runway = pipe_runway(prompt="", image=color, mask_image=mask,
                             num_inference_steps=num_steps).images[0]

    res_ours = merge_images_horizontally(color, res_ours)
    res_runway = merge_images_horizontally(color, res_runway)

    return res_ours, res_runway


# Default control values, shared by the sliders and the bundled examples so
# the two cannot drift apart.
DEFAULT_BOX_WIDTH_RATIO = 35
DEFAULT_MASK_RANDOM_START = 125
DEFAULT_STEPS = 50

inputs = [
    gr.Image(type="pil", image_mode="RGBA", label='Input Image'),  # resized to 512x512 before inference
    gr.Slider(30, 45, value=DEFAULT_BOX_WIDTH_RATIO, step=1, label="box_width_ratio"),
    gr.Slider(0, 256, value=DEFAULT_MASK_RANDOM_START, step=1, label="mask_random_start"),
    gr.Slider(30, 100, value=DEFAULT_STEPS, step=5, label="steps"),
]

outputs = [
    gr.Image(type="pil", image_mode="RGBA", label='PrefPaint', container=True, width="100%"),
    gr.Image(type="pil", image_mode="RGBA", label='RunwayPaint', container=True, width="100%"),
]

_ASSETS_DIR = "./assets"
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")

# Sorted because os.listdir order is arbitrary; filtered so stray non-image
# files (e.g. .DS_Store) don't become broken examples. A missing assets
# directory simply yields no examples instead of crashing at startup.
if os.path.isdir(_ASSETS_DIR):
    files = sorted(
        name for name in os.listdir(_ASSETS_DIR)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    )
else:
    files = []

examples = [
    [
        os.path.join(_ASSETS_DIR, file_name),
        DEFAULT_BOX_WIDTH_RATIO,
        DEFAULT_MASK_RANDOM_START,
        DEFAULT_STEPS,
    ]
    for file_name in files
]


# Render the comparison Interface inside a Blocks container; `demo` is the
# app that gets launched below.
with gr.Blocks() as demo:
    iface = gr.Interface(
        fn=pref_inpainting,
        inputs=inputs,
        outputs=outputs,
        examples=examples,
        title="Inpainting with Human Preference (Utilizing Free CPU Resources)",
        description="Upload an image and start your inpainting (currently only supporting outpainting masks; other mask types coming soon).",
        theme="default",
    )

demo.launch()