import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from safetensors.numpy import save_file, load_file
from omegaconf import OmegaConf
from transformers import AutoConfig
import cv2
from PIL import Image
import numpy as np
import json
import os
#
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipelineLegacy, StableDiffusionInpaintPipeline, DDIMScheduler, AutoencoderKL
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DDIMScheduler
from diffusers import DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
#
from models.pipeline_mimicbrush import MimicBrushPipeline
from models.ReferenceNet import ReferenceNet
from models.depth_guider import DepthGuider
from mimicbrush import MimicBrush_RefNet
from data_utils import *
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download


# Download the base diffusion weights and the MimicBrush weights from the
# Hugging Face Hub into local folders. The paths in ./configs/inference.yaml
# presumably point into these folders (TODO confirm).
from huggingface_hub import snapshot_download
snapshot_download(repo_id="xichenhku/cleansd", local_dir="./cleansd")
print('=== Pretrained SD weights downloaded ===')
snapshot_download(repo_id="xichenhku/MimicBrush", local_dir="./MimicBrush")
print('=== MimicBrush weights downloaded ===')

# Alternative (disabled) download path through ModelScope.
#sd_dir = ms_snapshot_download('xichen/cleansd', cache_dir='./modelscope')
#print('=== Pretrained SD weights downloaded ===')
#model_dir = ms_snapshot_download('xichen/MimicBrush', cache_dir='./modelscope')
#print('=== MimicBrush weights downloaded ===')

# Inference configuration; model locations are read from `model_path` below.
val_configs = OmegaConf.load('./configs/inference.yaml')

# === import Depth Anything ===
# The Depth Anything sources are vendored under ./depthanything and made
# importable by extending sys.path.
import sys
sys.path.append("./depthanything")
from torchvision.transforms import Compose
from depthanything.fast_import import depth_anything_model 
from depthanything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
# Preprocessing for the depth estimator: aspect-preserving resize around 518
# with sides rounded to multiples of 14 ('lower_bound' presumably means both
# sides end up >= 518 — TODO confirm in depth_anything.util.transform), then
# normalization with the standard ImageNet mean/std, then PrepareForNet.
transform = Compose([
    Resize(
        width=518,
        height=518,
        resize_target=False,
        keep_aspect_ratio=True,
        ensure_multiple_of=14,
        resize_method='lower_bound',
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])
# NOTE(review): torch.load is called without map_location/weights_only; if the
# checkpoint was saved with CUDA tensors this needs a GPU at import time, and
# it unpickles the file — acceptable only because the file comes from our own
# downloaded repo. Confirm before pointing this at other sources.
depth_anything_model.load_state_dict(torch.load(val_configs.model_path.depth_model))



# === load the checkpoint ===
# Model locations taken from the inference config.
base_model_path = val_configs.model_path.pretrained_imitativer_path
vae_model_path = val_configs.model_path.pretrained_vae_name_or_path
image_encoder_path = val_configs.model_path.image_encoder_path
ref_model_path = val_configs.model_path.pretrained_reference_path
mimicbrush_ckpt = val_configs.model_path.mimicbrush_ckpt_path
# Device handed to MimicBrush_RefNet below.
device = "cuda"



def pad_img_to_square(original_image, is_mask=False):
    """Pad a PIL image onto a square canvas, centered along its shorter side.

    Args:
        original_image: PIL image of any mode.
        is_mask: when True the padding is black (value 0, i.e. outside the
            edit region once thresholded); otherwise it is white.

    Returns:
        A new RGB PIL image with side max(width, height). When the size
        difference is odd, the extra padding pixel goes to the right/bottom.
        Square inputs are converted to RGB as well, so the output mode never
        depends on the input's aspect ratio.
    """
    width, height = original_image.size
    if width == height:
        # Same mode as the padded path below; convert() also returns a copy.
        return original_image.convert("RGB")

    side = max(width, height)
    fill = "black" if is_mask else "white"
    canvas = Image.new("RGB", (side, side), fill)
    # Exactly one of these offsets is non-zero: the shorter axis is centered.
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(original_image, offset)
    return canvas


def collage_region(low, high, mask):
    """Black out the masked region of `high`.

    Args:
        low: image (PIL or array); only its shape is used — its content is
            deliberately replaced by zeros, so masked pixels become black.
        high: image (PIL or array) that supplies all unmasked pixels.
        mask: PIL image or array; pixels > 128 are treated as masked. Either
            the same shape as the images or a single-channel (H, W) mask,
            which is broadcast across the image channels.

    Returns:
        PIL image equal to `high` outside the mask and black inside it.
    """
    mask = (np.array(mask) > 128).astype(np.uint8)
    high = np.array(high).astype(np.uint8)
    low = np.zeros_like(np.array(low), dtype=np.uint8)

    # A 2-D mask cannot broadcast against (H, W, C) images; add a channel axis.
    if mask.ndim == 2 and high.ndim == 3:
        mask = mask[:, :, None]

    collage = low * mask + high * (1 - mask)
    return Image.fromarray(collage)


def resize_image_keep_aspect_ratio(image, target_size = 512):
    """Resize an image array so its longer side equals `target_size`.

    Args:
        image: array whose first two axes are (height, width).
        target_size: length of the longer side after resizing.

    Returns:
        The resized array (cv2 default interpolation). The shorter side is
        scaled proportionally and truncated, but never below 1 pixel —
        cv2.resize rejects a zero-sized target.
    """
    height, width = image.shape[:2]
    if height > width:
        new_height = target_size
        new_width = max(1, int(width * (target_size / height)))
    else:
        new_width = target_size
        new_height = max(1, int(height * (target_size / width)))
    # cv2 expects dsize as (width, height).
    return cv2.resize(image, (new_width, new_height))


def crop_padding_and_resize(ori_image, square_image):
    """Map a square, center-padded result back to the original image size.

    Inverse of pad_img_to_square's layout: `square_image` is resized so its
    side equals the longer side of `ori_image`, and then the centered padding
    along the shorter axis is cropped away.

    Args:
        ori_image: array whose first two axes give the target (height, width).
        square_image: square array (2-D or with trailing channel axis)
            covering the padded square of `ori_image`.

    Returns:
        Array with exactly ori_image's height and width.
    """
    ori_height, ori_width = ori_image.shape[:2]
    side = max(ori_height, ori_width)
    # Resize straight to the target side. Deriving the size from a float
    # scale and int() can truncate to side - 1 (e.g. 1/49*49 < 1), which would
    # leave the crop one pixel short of the original.
    resized_square_image = cv2.resize(square_image, (side, side))
    # Same centering rule as pad_img_to_square: padding // 2 on the top/left.
    top = (side - ori_height) // 2
    left = (side - ori_width) // 2
    return resized_square_image[top:top + ori_height, left:left + ori_width, ...]


def vis_mask(image, mask):
    """Render a preview of the edit region on top of the source image.

    Args:
        image: H x W x 3 image array.
        mask: H x W x C uint8 mask with values 0/255; only channel 0 is read.

    Returns:
        A float array shaped like `image`. Pixels where the mask is > 128 are
        a 50/50 blend of the image and white; the rest are the original pixels.
        The external contours of the mask are drawn on top (thickness 5,
        anti-aliased).
    """
    channel = mask[:, :, 0]
    contours, _ = cv2.findContours(channel, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    inside = np.stack([channel] * 3, -1) > 128
    blend_weight = 0.5
    whitened = np.ones_like(image) * 255 * blend_weight + image * (1 - blend_weight)
    highlighted = np.where(inside, whitened, image)

    # White outline. NOTE(review): the trailing 0.5 was meant as opacity, but
    # cv2 drawing colors are per-channel values, so on a 3-channel image it
    # presumably has no effect.
    outline_color = np.concatenate([[255, 255, 255], [0.5]])
    cv2.polylines(highlighted, contours, True, outline_color, 5, cv2.LINE_AA)
    return highlighted



# DDIM sampler for the pipeline: 1000 training timesteps with a scaled_linear
# beta schedule from 0.00085 to 0.012.
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)


# VAE and denoising UNet, both cast to float16. The UNet is built with
# in_channels=13 and ignore_mismatched_sizes=True, so checkpoint tensors whose
# shapes differ (presumably the input convolution) are skipped when loading from
# base_model_path — TODO confirm they are restored later, e.g. from mimicbrush_ckpt.
vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet", in_channels=13, low_cpu_mem_usage=False, ignore_mismatched_sizes=True).to(dtype=torch.float16)

# MimicBrush pipeline assembled with the scheduler, VAE and UNet above; the
# feature extractor and safety checker are disabled.
pipe = MimicBrushPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    vae=vae,
    unet=unet,
    feature_extractor=None,
    safety_checker=None,
)

# Depth-guidance module, reference UNet (float16), and the MimicBrush_RefNet
# wrapper that receives the pipeline, image encoder path, MimicBrush checkpoint,
# depth model, depth guider, reference net and device.
depth_guider = DepthGuider()
referencenet = ReferenceNet.from_pretrained(ref_model_path, subfolder="unet").to(dtype=torch.float16)
mimicbrush_model = MimicBrush_RefNet(pipe, image_encoder_path, mimicbrush_ckpt,  depth_anything_model, depth_guider, referencenet, device)
# Mask preprocessing for the pipeline: grayscale conversion and binarization,
# without [-1, 1] normalization.
mask_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False, do_binarize=True, do_convert_grayscale=True)

@spaces.GPU
def infer_single(ref_image, target_image, target_mask, seed = -1, num_inference_steps=50, guidance_scale = 5, enable_shape_control = False):
    """Run MimicBrush on one reference / target / mask triple.

    All inputs are padded to squares with pad_img_to_square, so both outputs
    are in that padded square space.

    Args:
        ref_image: reference image, RGB np.array.
        target_image: source image to edit, RGB np.array.
        target_mask: 0/1 single-channel np.array; 1 marks the region to edit.
        seed: forwarded unchanged to mimicbrush_model.generate.
        num_inference_steps: number of denoising steps.
        guidance_scale: guidance scale forwarded to the generator.
        enable_shape_control: when False the depth condition is zeroed out
            (exposed in the UI as "Keep the original shape").

    Returns:
        (pred, depth_vis): pred is the first generated sample as a uint8
        array; depth_vis is a 512x512x3 uint8 RGB rendering (INFERNO colormap)
        of the min-max normalized predicted depth.
    """
    ref_image = pad_img_to_square(Image.fromarray(ref_image.astype(np.uint8)))

    target_image = pad_img_to_square(Image.fromarray(target_image.astype(np.uint8)))
    target_image_low = target_image

    # Expand the 0/1 mask to a 3-channel 0/255 image so it can be padded like
    # the RGB images (black padding stays outside the edit region).
    target_mask = target_mask.astype(np.uint8)
    target_mask = np.stack([target_mask, target_mask, target_mask], -1).astype(np.uint8) * 255
    target_mask = pad_img_to_square(Image.fromarray(target_mask), True)

    # Keep an untouched copy for depth estimation; the generator receives the
    # target with the masked region blacked out by collage_region.
    target_image_ori = target_image.copy()
    target_image = collage_region(target_image_low, target_image, target_mask)

    # NOTE(review): the transform receives a 0-255 image and the result is
    # divided by 255 afterwards, so NormalizeImage's mean/std presumably act
    # on the 0-255 scale. Left as-is because the depth guider may have been
    # trained with exactly this preprocessing — confirm before changing.
    depth_image = np.array(target_image_ori)
    depth_image = transform({'image': depth_image})['image']
    depth_image = torch.from_numpy(depth_image).unsqueeze(0) / 255

    if not enable_shape_control:
        depth_image = depth_image * 0

    mask_pt = mask_processor.preprocess(target_mask, height=512, width=512)

    pred, depth_pred = mimicbrush_model.generate(pil_image=ref_image, depth_image = depth_image, num_samples=1, num_inference_steps=num_inference_steps,
                            seed=seed, image=target_image, mask_image=mask_pt, strength=1.0, guidance_scale=guidance_scale)

    depth_pred = F.interpolate(depth_pred, size=(512,512), mode = 'bilinear', align_corners=True)[0][0]
    # Min-max normalize to 0..255. A constant map (max == min) would divide by
    # zero and yield NaNs, whose uint8 cast is undefined; render it as zeros.
    depth_min = depth_pred.min()
    depth_range = depth_pred.max() - depth_min
    if depth_range > 0:
        depth_pred = (depth_pred - depth_min) / depth_range * 255.0
    else:
        depth_pred = torch.zeros_like(depth_pred)
    depth_pred = depth_pred.detach().cpu().numpy().astype(np.uint8)
    # applyColorMap returns BGR; reverse the channel axis to get RGB.
    depth_pred = cv2.applyColorMap(depth_pred, cv2.COLORMAP_INFERNO)[:,:,::-1]

    pred = np.array(pred[0]).astype(np.uint8)
    return pred, depth_pred.astype(np.uint8)



def inference_single_image(ref_image, 
                           tar_image, 
                           tar_mask, 
                           ddim_steps, 
                           scale, 
                           seed,
                           enable_shape_control,
                           ):
    """Resolve UI settings and forward them to infer_single.

    Args:
        ref_image: reference RGB array, passed as infer_single's ref_image.
        tar_image: source RGB array, passed as target_image.
        tar_mask: 0/1 mask array, passed as target_mask.
        ddim_steps: number of denoising steps (coerced to int).
        scale: guidance scale.
        seed: -1 (or None) draws a random seed in [0, 10000); any other value
            is coerced to int.
        enable_shape_control: forwarded unchanged.

    Returns:
        (pred, depth_pred) exactly as returned by infer_single.
    """
    # Slider values may arrive as floats (e.g. 42.0), but step counts and RNG
    # seeds are integer quantities.
    if seed is None or int(seed) == -1:
        seed = np.random.randint(10000)
    else:
        seed = int(seed)
    pred, depth_pred = infer_single(ref_image, tar_image, tar_mask, seed,
                                    num_inference_steps=int(ddim_steps),
                                    guidance_scale=scale,
                                    enable_shape_control=enable_shape_control)
    return pred, depth_pred



def run_local(base,
              ref,
              *args):
    """Gradio callback for the Run button.

    Args:
        base: ImageEditor value; its "background" is the source PIL image and
            "layers"[0] holds the drawn mask (alpha channel if present).
        ref: reference PIL image.
        *args: ddim_steps, scale, seed, enable_shape_control — forwarded to
            inference_single_image in that order.

    Returns:
        Four uint8 arrays for the gallery: the result blended into the source
        with a softened mask edge, the depth visualization, the source with
        the mask highlighted, and the 3-channel 0/255 mask.

    Raises:
        gr.Error: if the source image, the mask, or the reference is missing.
    """
    if base is None or base.get("background") is None:
        raise gr.Error('No source image.')
    layers = base.get("layers") or []
    if not layers or layers[0] is None:
        raise gr.Error('No mask for the background image.')
    if ref is None:
        raise gr.Error('No reference image.')

    image = np.asarray(base["background"].convert("RGB"))

    # The brush strokes live in the layer's last channel (alpha for RGBA);
    # accept a single-channel layer as-is.
    layer = np.asarray(layers[0])
    mask = layer[:, :, -1] if layer.ndim == 3 else layer
    mask = np.where(mask > 128, 1, 0).astype(np.uint8)
    if mask.sum() == 0:
        raise gr.Error('No mask for the background image.')

    ref_image = np.asarray(ref.convert("RGB"))

    mask_3 = np.stack([mask, mask, mask], -1).astype(np.uint8) * 255

    # Repeated small blurs soften the mask edge for the final blend.
    mask_alpha = mask_3.copy()
    for _ in range(10):
        mask_alpha = cv2.GaussianBlur(mask_alpha, (3, 3), 0)

    synthesis, depth_pred = inference_single_image(ref_image.copy(), image.copy(), mask.copy(), *args)

    # Outputs are in the padded square space; map them back to the source size.
    synthesis = crop_padding_and_resize(image, synthesis)
    depth_pred = crop_padding_and_resize(image, depth_pred)

    mask_3_bin = mask_alpha / 255
    synthesis = synthesis * mask_3_bin + image * (1 - mask_3_bin)

    vis_source = vis_mask(image, mask_3).astype(np.uint8)
    return [synthesis.astype(np.uint8), depth_pred.astype(np.uint8), vis_source, mask_3]



# Gradio UI: output gallery and advanced options on top, source editor and
# reference upload below, followed by the examples and the Run wiring.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("#  MimicBrush: Zero-shot Image Editing with Reference Imitation ")
        with gr.Row():
            baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=1, height=768)
            with gr.Accordion("Advanced Option", open=True):
                # Not wired to any control; infer_single always requests one sample.
                num_samples = 1
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=-30.0, maximum=30.0, value=5.0, step=0.1)
                # -1 means "pick a random seed" (see inference_single_image).
                seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=-1)
                enable_shape_control = gr.Checkbox(label='Keep the original shape', value=False, interactive = True)

                gr.Markdown("### Tutorial")
                gr.Markdown("1. Upload the source image and the reference image")
                gr.Markdown("2. Select the \"draw button\" to mask the to-edit region on the source image  ")
                gr.Markdown("3. Click \"Run\" ")
                gr.Markdown("#### You should click \"keep the original shape\" to conduct texture transfer  ")

        gr.Markdown("# Upload the source image and reference image")
        gr.Markdown("### Tips: you could adjust the brush size")

        with gr.Row():
            # White fixed-color brush; run_local reads the strokes from the
            # first layer as the edit mask.
            base = gr.ImageEditor(  label="Source",
                                    type="pil",
                                    brush=gr.Brush(colors=["#FFFFFF"],default_size = 30,color_mode = "fixed"),
                                    layers = False,
                                    interactive=True
                                )
            ref = gr.Image(label="Reference", sources="upload", type="pil", height=512)
        run_local_button = gr.Button(value="Run")

    with gr.Row():
        # Each example: [source image, reference image, keep-original-shape flag].
        gr.Examples(
            examples=[
                ['./demo_example/005_source.png', './demo_example/005_reference.png', 0],
                ['./demo_example/004_source.png', './demo_example/004_reference.png', 0],
                ['./demo_example/000_source.png', './demo_example/000_reference.png', 0],
                ['./demo_example/003_source.png', './demo_example/003_reference.png', 0],
                ['./demo_example/006_source.png', './demo_example/006_reference.png', 0],
                ['./demo_example/001_source.png', './demo_example/001_reference.png', 1],
                ['./demo_example/002_source.png', './demo_example/002_reference.png', 1],
                ['./demo_example/007_source.png', './demo_example/007_reference.png', 1],
            ],
            inputs=[
                base,
                ref,
                enable_shape_control,
            ],
            cache_examples=False,
            examples_per_page=100)

    run_local_button.click(fn=run_local,
                           inputs=[base,
                                   ref,
                                   ddim_steps,
                                   scale,
                                   seed,
                                   enable_shape_control,
                                   ],
                           outputs=[baseline_gallery]
                           )

demo.launch(server_name="0.0.0.0")