import os
import torch
from dataclasses import dataclass
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import cv2
import mediapipe as mp
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import vqvae
import vit
from typing import Literal
from diffusion import create_diffusion
from utils import scale_keypoint, keypoint_heatmap, check_keypoints_validity
from segment_hoi import init_sam
from io import BytesIO
from PIL import Image
import random
from copy import deepcopy
from typing import Optional
import requests
from huggingface_hub import hf_hub_download
import spaces

MAX_N = 6
FIX_MAX_N = 6

placeholder = cv2.cvtColor(cv2.imread("placeholder.png"), cv2.COLOR_BGR2RGB)
NEW_MODEL = True
MODEL_EPOCH = 6
REF_POSE_MASK = True

def set_seed(seed):
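    """Seed Python, NumPy, and torch (CPU and CUDA) RNGs for reproducible sampling."""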
    seed = int(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)

# CUDA is assumed; GPU execution on Spaces is handled via the @spaces.GPU decorators below.
device = "cuda"

def remove_prefix(text, prefix):
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text


def unnormalize(x):
    return (((x + 1) / 2) * 255).astype(np.uint8)


def visualize_hand(all_joints, img, side=("right", "left"), n_avail_joints=21):
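    """Draw hand keypoints and finger connections on top of `img` and return the overlay as a numpy array."""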
    # Define the connections between joints for drawing lines and their corresponding colors
    connections = [
        ((0, 1), "red"),
        ((1, 2), "green"),
        ((2, 3), "blue"),
        ((3, 4), "purple"),
        ((0, 5), "orange"),
        ((5, 6), "pink"),
        ((6, 7), "brown"),
        ((7, 8), "cyan"),
        ((0, 9), "yellow"),
        ((9, 10), "magenta"),
        ((10, 11), "lime"),
        ((11, 12), "indigo"),
        ((0, 13), "olive"),
        ((13, 14), "teal"),
        ((14, 15), "navy"),
        ((15, 16), "gray"),
        ((0, 17), "lavender"),
        ((17, 18), "silver"),
        ((18, 19), "maroon"),
        ((19, 20), "fuchsia"),
    ]
    H, W, C = img.shape

    # Create a figure and axis
    plt.figure()
    ax = plt.gca()
    # Plot joints as points
    ax.imshow(img)
    start_is = []
    if "right" in side:
        start_is.append(0)
    if "left" in side:
        start_is.append(21)
    for start_i in start_is:
        joints = all_joints[start_i : start_i + n_avail_joints]
        if len(joints) == 1:
            ax.scatter(joints[0][0], joints[0][1], color="red", s=10)
        else:
            for connection, color in connections[: len(joints) - 1]:
                joint1 = joints[connection[0]]
                joint2 = joints[connection[1]]
                ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)

    ax.set_xlim([0, W])
    ax.set_ylim([0, H])
    ax.grid(False)
    ax.set_axis_off()
    ax.invert_yaxis()
    # plt.subplots_adjust(wspace=0.01)
    # plt.show()
    buf = BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    plt.close()

    # Convert BytesIO object to numpy array
    buf.seek(0)
    img_pil = Image.open(buf)
    img_pil = img_pil.resize((W, H))  # PIL resize expects (width, height)
    numpy_img = np.array(img_pil)

    return numpy_img


def mask_image(image, mask, color=[0, 0, 0], alpha=0.6, transparent=True):
    """Overlay mask on image for visualization purposes.
    Args:
        image (H, W, 3) or (H, W): input image
        mask (H, W): mask to be overlaid
        color: the color of the overlaid mask
        alpha: the transparency of the mask
        transparent: if True, alpha-blend the colored mask; if False, paint it opaquely
    """
    out = deepcopy(image)
    img = deepcopy(image)
    img[mask == 1] = color
    if transparent:
        out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out)
    else:
        out = img
    return out


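# Note: this local definition shadows the scale_keypoint imported from utils above.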
def scale_keypoint(keypoint, original_size, target_size):
    """Rescale keypoint coordinates from original_size to target_size."""
    keypoint_copy = keypoint.copy()
    keypoint_copy[:, 0] *= target_size[0] / original_size[0]
    keypoint_copy[:, 1] *= target_size[1] / original_size[1]
    return keypoint_copy


print("Configure...")


@dataclass
class HandDiffOpts:
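    """Configuration for the hand diffusion model; the paths below come from the original training setup."""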
    run_name: str = "ViT_256_handmask_heatmap_nvs_b25_lr1e-5"
    sd_path: str = "/users/kchen157/scratch/weights/SD/sd-v1-4.ckpt"
    log_dir: str = "/users/kchen157/scratch/log"
    data_root: str = "/users/kchen157/data/users/kchen157/dataset/handdiff"
    image_size: tuple = (256, 256)
    latent_size: tuple = (32, 32)
    latent_dim: int = 4
    mask_bg: bool = False
    kpts_form: str = "heatmap"
    n_keypoints: int = 42
    n_mask: int = 1
    noise_steps: int = 1000
    test_sampling_steps: int = 250
    ddim_steps: int = 100
    ddim_discretize: str = "uniform"
    ddim_eta: float = 0.0
    beta_start: float = 8.5e-4
    beta_end: float = 0.012
    latent_scaling_factor: float = 0.18215
    cfg_pose: float = 5.0
    cfg_appearance: float = 3.5
    batch_size: int = 25
    lr: float = 1e-5
    max_epochs: int = 500
    log_every_n_steps: int = 100
    limit_val_batches: int = 1
    n_gpu: int = 8
    num_nodes: int = 1
    precision: str = "16-mixed"
    profiler: str = "simple"
    swa_epoch_start: int = 10
    swa_lrs: float = 1e-3
    num_workers: int = 10
    n_val_samples: int = 4

# load models
token = os.getenv("HF_TOKEN")
if NEW_MODEL:
    opts = HandDiffOpts()
    if MODEL_EPOCH == 7:
        model_path = './DINO_EMA_11M_b50_lr1e-5_epoch7_step380k.ckpt'
    elif MODEL_EPOCH == 6:
        # model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt"
        model_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt", token=token)
    elif MODEL_EPOCH == 4:
        model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch4_step210k.ckpt"
    elif MODEL_EPOCH == 10:
        model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch10_step550k.ckpt"
    else:
        raise ValueError(f"MODEL_EPOCH should be one of 4, 6, 7, or 10, got {MODEL_EPOCH}")
    # vae_path = './vae-ft-mse-840000-ema-pruned.ckpt'
    vae_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="vae-ft-mse-840000-ema-pruned.ckpt", token=token)
    # sd_path = './sd-v1-4.ckpt'
    print('Load diffusion model...')
    diffusion = create_diffusion(str(opts.test_sampling_steps))
    model = vit.DiT_XL_2(
        input_size=opts.latent_size[0],
        latent_dim=opts.latent_dim,
        in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
        learn_sigma=True,
    ).to(device)
    # ckpt_state_dict = torch.load(model_path)['model_state_dict']
    ckpt_state_dict = torch.load(model_path, map_location='cpu')['ema_state_dict']
    missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
    model = model.to(device)
    model.eval()
    print(missing_keys, extra_keys)
    assert len(missing_keys) == 0
    vae_state_dict = torch.load(vae_path, map_location='cpu')['state_dict']
    print(f"vae_state_dict encoder dtype: {vae_state_dict['encoder.conv_in.weight'].dtype}")
    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False)
    print(f"autoencoder encoder dtype: {next(autoencoder.encoder.parameters()).dtype}")
    print(f"encoder before load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
    print(f"encoder before load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
    missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
    print(f"encoder after load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
    print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
    autoencoder = autoencoder.to(device)
    autoencoder.eval()
    print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
    print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
    print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
    assert len(missing_keys) == 0
# else:
#     opts = HandDiffOpts()
#     model_path = './finetune_epoch=5-step=130000.ckpt'
#     sd_path = './sd-v1-4.ckpt'
#     print('Load diffusion model...')
#     diffusion = create_diffusion(str(opts.test_sampling_steps))
#     model = vit.DiT_XL_2(
#         input_size=opts.latent_size[0],
#         latent_dim=opts.latent_dim,
#         in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
#         learn_sigma=True,
#     ).to(device)
#     ckpt_state_dict = torch.load(model_path)['state_dict']
#     dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
#     vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
#     missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
#     model.eval()
#     assert len(missing_keys) == 0 and len(extra_keys) == 0
#     autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
#     missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
#     autoencoder.eval()
#     assert len(missing_keys) == 0 and len(extra_keys) == 0
sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token)
sam_predictor = init_sam(ckpt_path=sam_path, device='cpu')


print("Mediapipe hand detector and SAM ready...")
mp_hands = mp.solutions.hands
hands = mp_hands.Hands(
    static_image_mode=True,  # Use False if image is part of a video stream
    max_num_hands=2,  # Maximum number of hands to detect
    min_detection_confidence=0.1,
)

def prepare_ref_anno(ref):
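    """Resize the cropped reference image and run MediaPipe hand detection.

    Returns the resized image and a (42, 2) keypoint array (right hand in rows
    0-20, left hand in rows 21-41), or (img, None) if no hands are detected.
    """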
    if ref is None:
        return None, None
    missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)

    img = ref["composite"][..., :3]
    img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
    keypts = np.zeros((42, 2))
    mp_pose = hands.process(img)
    if mp_pose.multi_hand_landmarks:
        # handedness is flipped assuming the input image is mirrored in MediaPipe
        for hand_landmarks, handedness in zip(
            mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
        ):
            # actually right hand
            if handedness.classification[0].label == "Left":
                start_idx = 0
            # actually left hand
            elif handedness.classification[0].label == "Right":
                start_idx = 21
            for i, landmark in enumerate(hand_landmarks.landmark):
                keypts[start_idx + i] = [
                    landmark.x * opts.image_size[1],
                    landmark.y * opts.image_size[0],
                ]

        print(f"keypts.max(): {keypts.max()}, keypts.min(): {keypts.min()}")
        return img, keypts
    else:
        return img, None

def get_ref_anno(img, keypts):
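    """Build the reference conditioning from an image and its 42 keypoints.

    Segments the hand with SAM, renders the pose overlay, and concatenates the
    VAE latent, keypoint heatmaps, and hand mask into the conditioning tensor.
    """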
    if keypts is None:
        no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
        return None, no_hands, None
    if isinstance(keypts, list):
        if len(keypts[0]) == 0:
            keypts[0] = np.zeros((21, 2))
        elif len(keypts[0]) == 21:
            keypts[0] = np.array(keypts[0], dtype=np.float32)
        else:
            gr.Info("Number of right hand keypoints should be either 0 or 21.")
            return None, None, None

        if len(keypts[1]) == 0:
            keypts[1] = np.zeros((21, 2))
        elif len(keypts[1]) == 21:
            keypts[1] = np.array(keypts[1], dtype=np.float32)
        else:
            gr.Info("Number of left hand keypoints should be either 0 or 21.")
            return None, None, None

        keypts = np.concatenate(keypts, axis=0)
    if REF_POSE_MASK:
        sam_predictor.set_image(img)
        if keypts[0].sum() != 0 and keypts[21].sum() != 0:
            input_point = np.array([keypts[0], keypts[21]])
            input_label = np.array([1, 1])
        elif keypts[0].sum() != 0:
            input_point = np.array(keypts[:1])
            input_label = np.array([1])
        elif keypts[21].sum() != 0:
            input_point = np.array(keypts[21:22])
            input_label = np.array([1])
        else:
            raise gr.Error("At least one hand needs keypoints before segmentation.")
        masks, _, _ = sam_predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )
        hand_mask = masks[0]
        masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
        ref_pose = visualize_hand(keypts, masked_img)
    else:
        hand_mask = np.zeros_like(img[:, :, 0])
        ref_pose = np.zeros_like(img)
    def make_ref_cond(
        img,
        keypts,
        hand_mask,
        device="cuda",
        target_size=(256, 256),
        latent_size=(32, 32),
    ):
        image_transform = Compose(
            [
                ToTensor(),
                Resize(target_size),
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        image = image_transform(img) # .to(device)
        kpts_valid = check_keypoints_validity(keypts, target_size)
        heatmaps = torch.tensor(
            keypoint_heatmap(
                scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
            )
            * kpts_valid[:, None, None],
            dtype=torch.float,
            # device=device
        )[None, ...]
        mask = torch.tensor(
            cv2.resize(
                hand_mask.astype(int),
                dsize=latent_size,
                interpolation=cv2.INTER_NEAREST,
            ),
            dtype=torch.float,
            # device=device,
        ).unsqueeze(0)[None, ...]
        return image[None, ...], heatmaps, mask

    print(f"img.max(): {img.max()}, img.min(): {img.min()}")
    image, heatmaps, mask = make_ref_cond(
        img,
        keypts,
        hand_mask,
        device="cuda",
        target_size=opts.image_size,
        latent_size=opts.latent_size,
    )
    print(f"image.max(): {image.max()}, image.min(): {image.min()}")
    print(f"opts.latent_scaling_factor: {opts.latent_scaling_factor}")
    print(f"autoencoder encoder before operating min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
    print(f"autoencoder encoder before operating max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
    print(f"autoencoder encoder before operating dtype: {next(autoencoder.encoder.parameters()).dtype}")
    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
    print(f"latent.max(): {latent.max()}, latent.min(): {latent.min()}")
    if not REF_POSE_MASK:
        heatmaps = torch.zeros_like(heatmaps)
        mask = torch.zeros_like(mask)
    print(f"heatmaps.max(): {heatmaps.max()}, heatmaps.min(): {heatmaps.min()}")
    print(f"mask.max(): {mask.max()}, mask.min(): {mask.min()}")
    ref_cond = torch.cat([latent, heatmaps, mask], 1)
    print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}")

    return img, ref_pose, ref_cond

def get_target_anno(target):
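    """Detect hand keypoints in the target image and build the target pose conditioning.

    Returns the resized image, its keypoint overlay, the heatmap conditioning tensor,
    and the (42, 2) keypoint array; raises gr.Error if no hands are detected.
    """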
    if target is None:
        return None, None, None, None
    pose_img = target["composite"][..., :3]
    pose_img = cv2.resize(pose_img, opts.image_size, interpolation=cv2.INTER_AREA)
    # detect keypoints
    mp_pose = hands.process(pose_img)
    target_keypts = np.zeros((42, 2))
    detected = np.array([0, 0])
    start_idx = 0
    if mp_pose.multi_hand_landmarks:
        # handedness is flipped assuming the input image is mirrored in MediaPipe
        for hand_landmarks, handedness in zip(
            mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
        ):
            # actually right hand
            if handedness.classification[0].label == "Left":
                start_idx = 0
                detected[0] = 1
            # actually left hand
            elif handedness.classification[0].label == "Right":
                start_idx = 21
                detected[1] = 1
            for i, landmark in enumerate(hand_landmarks.landmark):
                target_keypts[start_idx + i] = [
                    landmark.x * opts.image_size[1],
                    landmark.y * opts.image_size[0],
                ]

        target_pose = visualize_hand(target_keypts, pose_img)
        kpts_valid = check_keypoints_validity(target_keypts, opts.image_size)
        target_heatmaps = torch.tensor(
            keypoint_heatmap(
                scale_keypoint(target_keypts, opts.image_size, opts.latent_size),
                opts.latent_size,
                var=1.0,
            )
            * kpts_valid[:, None, None],
            dtype=torch.float,
            # device=device,
        )[None, ...]
        target_cond = torch.cat(
            [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
        )
    else:
        raise gr.Error("No hands detected in the target image.")

    return pose_img, target_pose, target_cond, target_keypts


def get_mask_inpaint(ref):
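    """Convert the brushed layer's alpha channel into a binary inpaint mask at opts.image_size."""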
    inpaint_mask = np.array(ref["layers"][0])[..., -1]
    inpaint_mask = cv2.resize(
        inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA
    )
    inpaint_mask = (inpaint_mask >= 128).astype(np.uint8)
    return inpaint_mask


def visualize_ref(crop, brush):
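    """Dim everything outside the brushed region so the area to be fixed stands out."""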
    if crop is None or brush is None:
        return None
    inpainted = brush["layers"][0][..., -1]
    img = crop["background"][..., :3]
    img = cv2.resize(img, inpainted.shape[::-1], interpolation=cv2.INTER_AREA)
    mask = inpainted < 128
    # img = img.astype(np.int32)
    # img[mask, :] = img[mask, :] - 50
    # img[np.any(img<0, axis=-1)]=0
    # img = img.astype(np.uint8)
    img = mask_image(img, mask)
    return img


def get_kps(img, keypoints, side: Literal["right", "left"], evt: gr.SelectData):
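    """Append the clicked point to the keypoint list of the given hand and redraw the pose overlay."""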
    if keypoints is None:
        keypoints = [[], []]
    kps = np.zeros((42, 2))
    if side == "right":
        if len(keypoints[0]) == 21:
            gr.Info("21 keypoints for right hand already selected. Try reset if something looks wrong.")
        else:
            keypoints[0].append(list(evt.index))
        len_kps = len(keypoints[0])
        kps[:len_kps] = np.array(keypoints[0])
    elif side == "left":
        if len(keypoints[1]) == 21:
            gr.Info("21 keypoints for left hand already selected. Try reset if something looks wrong.")
        else:
            keypoints[1].append(list(evt.index))
        len_kps = len(keypoints[1])
        kps[21 : 21 + len_kps] = np.array(keypoints[1])
    vis_hand = visualize_hand(kps, img, side, len_kps)
    return vis_hand, keypoints


def undo_kps(img, keypoints, side: Literal["right", "left"]):
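    """Remove the most recently clicked keypoint of the given hand and redraw the pose overlay."""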
    if keypoints is None:
        return img, None
    kps = np.zeros((42, 2))
    if side == "right":
        if len(keypoints[0]) == 0:
            return img, keypoints
        keypoints[0].pop()
        len_kps = len(keypoints[0])
        kps[:len_kps] = np.array(keypoints[0])
    elif side == "left":
        if len(keypoints[1]) == 0:
            return img, keypoints
        keypoints[1].pop()
        len_kps = len(keypoints[1])
        kps[21 : 21 + len_kps] = np.array(keypoints[1])
    vis_hand = visualize_hand(kps, img, side, len_kps)
    return vis_hand, keypoints


def reset_kps(img, keypoints, side: Literal["right", "left"]):
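    """Clear all clicked keypoints of the given hand."""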
    if keypoints is None:
        return img, None
    if side == "right":
        keypoints[0] = []
    elif side == "left":
        keypoints[1] = []
    return img, keypoints

@spaces.GPU(duration=60)
def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
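    """Generate num_gen edited hand images with classifier-free guidance.

    The latent batch is duplicated so conditional and unconditional branches are
    denoised together, then split, decoded by the VAE, and padded to MAX_N results.
    """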
    set_seed(seed)
    z = torch.randn(
        (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
        device=device,
    )
    print(f"z.device: {z.device}")
    target_cond = target_cond.repeat(num_gen, 1, 1, 1).to(z.device)
    ref_cond = ref_cond.repeat(num_gen, 1, 1, 1).to(z.device)
    print(f"target_cond.max(): {target_cond.max()}, target_cond.min(): {target_cond.min()}")
    print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}")
    # novel view synthesis mode = off
    nvs = torch.zeros(num_gen, dtype=torch.int, device=device)
    z = torch.cat([z, z], 0)
    model_kwargs = dict(
        target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]),
        ref_cond=torch.cat([ref_cond, torch.zeros_like(ref_cond)]),
        nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
        cfg_scale=cfg,
    )

    samples, _ = diffusion.p_sample_loop(
        model.forward_with_cfg,
        z.shape,
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device=device,
    ).chunk(2)
    sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
    sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
    sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())

    results = []
    results_pose = []
    for i in range(MAX_N):
        if i < num_gen:
            results.append(sampled_images[i])
            results_pose.append(visualize_hand(target_keypts, sampled_images[i]))
        else:
            results.append(placeholder)
            results_pose.append(placeholder)
    print(f"results[0].max(): {results[0].max()}")
    return results, results_pose

@spaces.GPU(duration=120)
def ready_sample(img_ori, inpaint_mask, keypts):
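    """Pre-process inputs for fixing a malformed hand.

    Segments the hand with SAM around the clicked keypoints, encodes the image into a
    latent, and builds the inpainting mask and conditioning tensors for sampling.
    """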
    img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
    sam_predictor.set_image(img)
    if len(keypts[0]) == 0:
        keypts[0] = np.zeros((21, 2))
    elif len(keypts[0]) == 21:
        keypts[0] = np.array(keypts[0], dtype=np.float32)
    else:
        gr.Info("Number of right hand keypoints should be either 0 or 21.")
        return None, None

    if len(keypts[1]) == 0:
        keypts[1] = np.zeros((21, 2))
    elif len(keypts[1]) == 21:
        keypts[1] = np.array(keypts[1], dtype=np.float32)
    else:
        gr.Info("Number of left hand keypoints should be either 0 or 21.")
        return None, None

    keypts = np.concatenate(keypts, axis=0)
    keypts = scale_keypoint(keypts, (LENGTH, LENGTH), opts.image_size)

    box_shift_ratio = 0.5
    box_size_factor = 1.2

    if keypts[0].sum() != 0 and keypts[21].sum() != 0:
        input_point = np.array(keypts)
        input_box = np.stack([keypts.min(axis=0), keypts.max(axis=0)])
    elif keypts[0].sum() != 0:
        input_point = np.array(keypts[:21])
        input_box = np.stack([keypts[:21].min(axis=0), keypts[:21].max(axis=0)])
    elif keypts[21].sum() != 0:
        input_point = np.array(keypts[21:])
        input_box = np.stack([keypts[21:].min(axis=0), keypts[21:].max(axis=0)])
    else:
        raise ValueError(
            "No hand keypoints were provided; this case should have been caught earlier."
        )

    input_label = np.ones_like(input_point[:, 0]).astype(np.int32)
    box_trans = input_box[0] * box_shift_ratio + input_box[1] * (1 - box_shift_ratio)
    input_box = ((input_box - box_trans) * box_size_factor + box_trans).reshape(-1)

    masks, _, _ = sam_predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box[None, :],
        multimask_output=False,
    )
    hand_mask = masks[0]

    inpaint_latent_mask = torch.tensor(
        cv2.resize(
            inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
        ),
        dtype=torch.float,
        # device=device,
    ).unsqueeze(0)[None, ...]

    def make_ref_cond(
        img,
        keypts,
        hand_mask,
        device=device,
        target_size=(256, 256),
        latent_size=(32, 32),
    ):
        image_transform = Compose(
            [
                ToTensor(),
                Resize(target_size),
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        image = image_transform(img)
        kpts_valid = check_keypoints_validity(keypts, target_size)
        heatmaps = torch.tensor(
            keypoint_heatmap(
                scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
            )
            * kpts_valid[:, None, None],
            dtype=torch.float,
            # device=device,
        )[None, ...]
        mask = torch.tensor(
            cv2.resize(
                hand_mask.astype(int),
                dsize=latent_size,
                interpolation=cv2.INTER_NEAREST,
            ),
            dtype=torch.float,
            # device=device,
        ).unsqueeze(0)[None, ...]
        return image[None, ...], heatmaps, mask

    image, heatmaps, mask = make_ref_cond(
        img,
        keypts,
        hand_mask * (1 - inpaint_mask),
        device=device,
        target_size=opts.image_size,
        latent_size=opts.latent_size,
    )
    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
    target_cond = torch.cat([heatmaps, torch.zeros_like(mask)], 1)
    ref_cond = torch.cat([latent, heatmaps, mask], 1)
    ref_cond = torch.zeros_like(ref_cond)

    img32 = cv2.resize(img, opts.latent_size, interpolation=cv2.INTER_NEAREST)
    assert mask.max() == 1
    vis_mask32 = mask_image(
        img32, inpaint_latent_mask[0,0].cpu().numpy(), (255,255,255), transparent=False
    ).astype(np.uint8) # 1.0 - mask[0, 0].cpu().numpy()

    assert np.unique(inpaint_mask).shape[0] <= 2
    assert hand_mask.dtype == bool
    mask256 = inpaint_mask # hand_mask * (1 - inpaint_mask)
    vis_mask256 = mask_image(img, mask256, (255,255,255), transparent=False).astype(
        np.uint8
    ) # 1 - mask256

    return (
        ref_cond,
        target_cond,
        latent,
        inpaint_latent_mask,
        keypts,
        vis_mask32,
        vis_mask256,
    )


def switch_mask_size(radio):
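    """Toggle visibility between the 256x256 and latent-size (32x32) inpaint mask previews."""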
    if radio == "256x256":
        out = (gr.update(visible=False), gr.update(visible=True))
    elif radio == "latent size (32x32)":
        out = (gr.update(visible=True), gr.update(visible=False))
    return out

@spaces.GPU(duration=300)
def sample_inpaint(
    ref_cond,
    target_cond,
    latent,
    inpaint_latent_mask,
    keypts,
    num_gen,
    seed,
    cfg,
    quality,
):
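    """Inpaint the brushed latent region; `quality` sets the number of resampling jumps (jump_n_sample)."""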
    set_seed(seed)
    N = num_gen
    jump_length = 10
    jump_n_sample = quality
    cfg_scale = cfg
    z = torch.randn(
        (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device=device
    )
    target_cond_N = target_cond.repeat(N, 1, 1, 1).to(z.device)
    ref_cond_N = ref_cond.repeat(N, 1, 1, 1).to(z.device)
    # novel view synthesis mode = off
    nvs = torch.zeros(N, dtype=torch.int, device=device)
    z = torch.cat([z, z], 0)
    model_kwargs = dict(
        target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
        ref_cond=torch.cat([ref_cond_N, torch.zeros_like(ref_cond_N)]),
        nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
        cfg_scale=cfg_scale,
    )

    samples, _ = diffusion.inpaint_p_sample_loop(
        model.forward_with_cfg,
        z.shape,
        latent.to(z.device),
        inpaint_latent_mask.to(z.device),
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device=z.device,
        jump_length=jump_length,
        jump_n_sample=jump_n_sample,
    ).chunk(2)
    sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
    sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
    sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())

    # visualize
    results = []
    results_pose = []
    for i in range(FIX_MAX_N):
        if i < num_gen:
            results.append(sampled_images[i])
            results_pose.append(visualize_hand(keypts, sampled_images[i]))
        else:
            results.append(placeholder)
            results_pose.append(placeholder)
    return results, results_pose


def flip_hand(
    img, pose_img, cond: Optional[torch.Tensor], keypts: Optional[torch.Tensor] = None, pose_manual_img = None,
    manual_kp_right=None, manual_kp_left=None
):
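    """Horizontally flip the image, pose overlay, conditioning tensor, and keypoints."""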
    if cond is None:  # clear clicked
        return None, None, None, None, None, None, None
    img["composite"] = img["composite"][:, ::-1, :]
    img["background"] = img["background"][:, ::-1, :]
    img["layers"] = [layer[:, ::-1, :] for layer in img["layers"]]
    pose_img = pose_img[:, ::-1, :]
    cond = cond.flip(-1)
    if keypts is not None:  # cond is target_cond
        if keypts[:21, :].sum() != 0:
            keypts[:21, 0] = opts.image_size[1] - keypts[:21, 0]
            # keypts[:21, 1] = opts.image_size[0] - keypts[:21, 1]
        if keypts[21:, :].sum() != 0:
            keypts[21:, 0] = opts.image_size[1] - keypts[21:, 0]
            # keypts[21:, 1] = opts.image_size[0] - keypts[21:, 1]
    if pose_manual_img is not None:
        pose_manual_img = pose_manual_img[:, ::-1, :]
        manual_kp_right = manual_kp_right[:, ::-1, :]
        manual_kp_left = manual_kp_left[:, ::-1, :]
    return img, pose_img, cond, keypts, pose_manual_img, manual_kp_right, manual_kp_left


def resize_to_full(img):
    img["background"] = cv2.resize(img["background"], (LENGTH, LENGTH))
    img["composite"] = cv2.resize(img["composite"], (LENGTH, LENGTH))
    img["layers"] = [cv2.resize(layer, (LENGTH, LENGTH)) for layer in img["layers"]]
    return img


def clear_all():
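    """Reset the Edit tab; the return order matches the outputs wired to the Clear button."""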
    return (
        None,
        None,
        None,
        None,
        None,
        False,
        None,
        None,
        False,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        1,
        42,
        3.0,
        gr.update(interactive=False),
        []
    )


def fix_clear_all():
    return (
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        1,
        # (0,0),
        42,
        3.0,
        10,
    )


def enable_component(image1, image2):
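    """Enable the wired component only when both inputs contain non-empty images."""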
    if image1 is None or image2 is None:
        return gr.update(interactive=False)
    if "background" in image1 and "layers" in image1 and "composite" in image1:
        if (
            image1["background"].sum() == 0
            and (sum([im.sum() for im in image1["layers"]]) == 0)
            and image1["composite"].sum() == 0
        ):
            return gr.update(interactive=False)
    if "background" in image2 and "layers" in image2 and "composite" in image2:
        if (
            image2["background"].sum() == 0
            and (sum([im.sum() for im in image2["layers"]]) == 0)
            and image2["composite"].sum() == 0
        ):
            return gr.update(interactive=False)
    return gr.update(interactive=True)


def set_visible(checkbox, kpts, img_clean, img_pose_right, img_pose_left, done=None, done_info=None):
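    """Show or hide the manual keypoint widgets for each hand based on the checkbox selection."""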
    if kpts is None:
        kpts = [[], []]
    if "Right hand" not in checkbox:
        kpts[0] = []
        vis_right = img_clean
        update_right = gr.update(visible=False)
        update_r_info = gr.update(visible=False)
    else:
        vis_right = img_pose_right
        update_right = gr.update(visible=True)
        update_r_info = gr.update(visible=True)

    if "Left hand" not in checkbox:
        kpts[1] = []
        vis_left = img_clean
        update_left = gr.update(visible=False)
        update_l_info = gr.update(visible=False)
    else:
        vis_left = img_pose_left
        update_left = gr.update(visible=True)
        update_l_info = gr.update(visible=True)

    ret = [
        kpts,
        vis_right,
        vis_left,
        update_right,
        update_right,
        update_right,
        update_left,
        update_left,
        update_left,
        update_r_info,
        update_l_info,
    ]
    if done is not None:
        if not checkbox:
            ret.append(gr.update(visible=False))
            ret.append(gr.update(visible=False))
        else:
            ret.append(gr.update(visible=True))
            ret.append(gr.update(visible=True))
    return tuple(ret)

def set_unvisible():
    return (
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False)
    )

def set_no_hands(decider, component):
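    """Return the 'no hands' placeholder image when decider is None; otherwise pass component through."""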
    if decider is None:
        no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
        return no_hands
    else:
        return component

def visible_component(decider, component):
    if decider is not None:
        update_component = gr.update(visible=True)
    else:
        update_component = gr.update(visible=False)
    return update_component

def unvisible_component(decider, component):
    if decider is not None:
        update_component = gr.update(visible=False)
    else:
        update_component = gr.update(visible=True)
    return update_component

# def make_change(decider, state):
#     '''
#     if decider is not None, change the state's value. True/False does not matter.
#     '''
#     if decider is not None:
#         if state:
#             state = False
#         else:
#             state = True
#         return state
#     else:
#         return state

LENGTH = 480

example_ref_imgs = [
    [
        "sample_images/sample1.jpg",
    ],
    [
        "sample_images/sample2.jpg",
    ],
    [
        "sample_images/sample3.jpg",
    ],
    [
        "sample_images/sample4.jpg",
    ],
    # [
    #     "sample_images/sample5.jpg",
    # ],
    [
        "sample_images/sample6.jpg",
    ],
    # [
    #     "sample_images/sample7.jpg",
    # ],
    # [
    #     "sample_images/sample8.jpg",
    # ],
    # [
    #     "sample_images/sample9.jpg",
    # ],
    # [
    #     "sample_images/sample10.jpg",
    # ],
    # [
    #     "sample_images/sample11.jpg",
    # ],
    # ["pose_images/pose1.jpg"],
    # ["pose_images/pose2.jpg"],
    # ["pose_images/pose3.jpg"],
    # ["pose_images/pose4.jpg"],
    # ["pose_images/pose5.jpg"],
    # ["pose_images/pose6.jpg"],
    # ["pose_images/pose7.jpg"],
    # ["pose_images/pose8.jpg"],
]
example_target_imgs = [
    # [
    #     "sample_images/sample1.jpg",
    # ],
    # [
    #     "sample_images/sample2.jpg",
    # ],
    # [
    #     "sample_images/sample3.jpg",
    # ],
    # [
    #     "sample_images/sample4.jpg",
    # ],
    [
        "sample_images/sample5.jpg",
    ],
    # [
    #     "sample_images/sample6.jpg",
    # ],
    # [
    #     "sample_images/sample7.jpg",
    # ],
    # [
    #     "sample_images/sample8.jpg",
    # ],
    [
        "sample_images/sample9.jpg",
    ],
    [
        "sample_images/sample10.jpg",
    ],
    [
        "sample_images/sample11.jpg",
    ],
    ["pose_images/pose1.jpg"],
    # ["pose_images/pose2.jpg"],
    # ["pose_images/pose3.jpg"],
    # ["pose_images/pose4.jpg"],
    # ["pose_images/pose5.jpg"],
    # ["pose_images/pose6.jpg"],
    # ["pose_images/pose7.jpg"],
    # ["pose_images/pose8.jpg"],
]
fix_example_imgs = [
    ["bad_hands/1.jpg"],  # "bad_hands/1_mask.jpg"],
    # ["bad_hands/2.jpg"],  # "bad_hands/2_mask.jpg"],
    ["bad_hands/3.jpg"],  # "bad_hands/3_mask.jpg"],
    # ["bad_hands/4.jpg"],  # "bad_hands/4_mask.jpg"],
    ["bad_hands/5.jpg"],  # "bad_hands/5_mask.jpg"],
    ["bad_hands/6.jpg"],  # "bad_hands/6_mask.jpg"],
    ["bad_hands/7.jpg"],  # "bad_hands/7_mask.jpg"],
    # ["bad_hands/8.jpg"],  # "bad_hands/8_mask.jpg"],
    # ["bad_hands/9.jpg"],  # "bad_hands/9_mask.jpg"],
    # ["bad_hands/10.jpg"],  # "bad_hands/10_mask.jpg"],
    # ["bad_hands/11.jpg"],  # "bad_hands/11_mask.jpg"],
    # ["bad_hands/12.jpg"],  # "bad_hands/12_mask.jpg"],
    # ["bad_hands/13.jpg"],  # "bad_hands/13_mask.jpg"],
    ["bad_hands/14.jpg"],
    ["bad_hands/15.jpg"],
]
custom_css = """
.gradio-container .examples img {
    width: 240px !important;
    height: 240px !important;
}
"""

_HEADER_ = '''
<div style="text-align: center;">
    <h1><b>FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation</b></h1>
    <h2 style="color: #777777;">CVPR 2025</h2>
    <style>
        .link-spacing {
            margin-right: 20px;
        }
    </style>
    <p style="font-size: 15px;">
        <span style="display: inline-block; margin-right: 30px;">Brown University</span>
        <span style="display: inline-block;">Meta Reality Labs</span>
    </p>
    <h3>
        <a href='https://arxiv.org/abs/2412.02690' target='_blank' class="link-spacing">Paper</a>
        <a href='https://ivl.cs.brown.edu/research/foundhand.html' target='_blank' class="link-spacing">Project Page</a>
        <a href='' target='_blank' class="link-spacing">Code</a>
        <a href='' target='_blank'>Model Weights</a>
    </h3>
    <p>Below are two key abilities of our model. First, we can <b>edit hand poses</b> given two hand images: one is the image to edit, and the other provides the target hand pose. Second, we can automatically <b>fix malformed hand images</b>, following a user-provided target hand pose and the area to fix.</p>
</div>
'''

_CITE_ = r"""
```
@article{chen2024foundhand,
  title={FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation},
  author={Chen, Kefan and Min, Chaerin and Zhang, Linguang and Hampali, Shreyas and Keskin, Cem and Sridhar, Srinath},
  journal={arXiv preprint arXiv:2412.02690},
  year={2024}
}
```
"""

with gr.Blocks(css=custom_css, theme="soft") as demo:
    gr.Markdown(_HEADER_)
    with gr.Tab("Edit Hand Poses"):
        ref_img = gr.State(value=None)
        ref_im_raw = gr.State(value=None)
        ref_kp_raw = gr.State(value=0)
        ref_kp_got = gr.State(value=None)
        dump = gr.State(value=None)
        ref_cond = gr.State(value=None)
        ref_manual_cond = gr.State(value=None)
        ref_auto_cond = gr.State(value=None)
        keypts = gr.State(value=None)
        target_img = gr.State(value=None)
        target_cond = gr.State(value=None)
        target_keypts = gr.State(value=None)
        dump = gr.State(value=None)
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 20px; font-weight: bold;">1. Upload a hand image to edit 📥</p>"""
                )
                gr.Markdown(
                    """<p style="text-align: center;">&#9312; Optionally crop the image</p>"""
                )
                ref = gr.ImageEditor(
                    type="numpy",
                    label="Reference",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    brush=False,
                    layers=False,
                    crop_size="1:1",
                )
                gr.Examples(example_ref_imgs, [ref], examples_per_page=20)
                gr.Markdown(
                    """<p style="text-align: center;">&#9313; Hit the &quot;Finish Cropping&quot; button to get hand pose</p>"""
                )
                ref_finish_crop = gr.Button(value="Finish Cropping", interactive=False)
                with gr.Tab("Automatic hand keypoints"):
                    ref_pose = gr.Image(
                        type="numpy",
                        label="Reference Pose",
                        show_label=True,
                        height=LENGTH,
                        width=LENGTH,
                        interactive=False,
                    )
                    ref_use_auto = gr.Button(value="Click here to use automatic, not manual", interactive=False, visible=True)
                with gr.Tab("Manual hand keypoints"):
                    ref_manual_checkbox_info = gr.Markdown(
                        """<p style="text-align: center;"><b>Step 1.</b> Tell us if this is right, left, or both hands.</p>""",
                        visible=True,
                    )
                    ref_manual_checkbox = gr.CheckboxGroup(
                        ["Right hand", "Left hand"],
                        show_label=False,
                        visible=True,
                        interactive=True,
                    )
                    ref_manual_kp_r_info = gr.Markdown(
                        """<p style="text-align: center;"><b>Step 2.</b> Click on the image to provide hand keypoints for the <b>right</b> hand. See \"OpenPose Keypoints Convention\" for guidance.</p>""",
                        visible=False,
                    )
                    ref_manual_kp_right = gr.Image(
                        type="numpy",
                        label="Keypoint Selection (right hand)",
                        show_label=True,
                        height=LENGTH,
                        width=LENGTH,
                        interactive=False,
                        visible=False,
                        sources=[],
                    )
                    with gr.Row():
                        ref_manual_undo_right = gr.Button(
                            value="Undo", interactive=True, visible=False
                        )
                        ref_manual_reset_right = gr.Button(
                            value="Reset", interactive=True, visible=False
                        )
                    ref_manual_kp_l_info = gr.Markdown(
                        """<p style="text-align: center;"><b>Step 2.</b> Click on the image to provide hand keypoints for the <b>left</b> hand. See \"OpenPose Keypoints Convention\" for guidance.</p>""",
                        visible=False
                    )
                    ref_manual_kp_left = gr.Image(
                        type="numpy",
                        label="Keypoint Selection (left hand)",
                        show_label=True,
                        height=LENGTH,
                        width=LENGTH,
                        interactive=False,
                        visible=False,
                        sources=[],
                    )
                    with gr.Row():
                        ref_manual_undo_left = gr.Button(
                            value="Undo", interactive=True, visible=False
                        )
                        ref_manual_reset_left = gr.Button(
                            value="Reset", interactive=True, visible=False
                        )
                    ref_manual_done_info = gr.Markdown(
                        """<p style="text-align: center;"><b>Step 3.</b> Hit \"Done\" button to confirm.</p>""",
                        visible=False,
                    )
                    ref_manual_done = gr.Button(value="Done", interactive=True, visible=False)
                    ref_manual_pose = gr.Image(
                        type="numpy",
                        label="Reference Pose",
                        show_label=True,
                        height=LENGTH,
                        width=LENGTH,
                        interactive=False,
                        visible=False
                    )
                    ref_use_manual = gr.Button(value="Click here to use manual, not automatic", interactive=True, visible=False)
                    ref_manual_instruct = gr.Markdown(
                        value="""<p style="text-align: left; font-weight: bold; ">OpenPose Keypoints Convention</p>""",
                        visible=True
                    )
                    ref_manual_openpose = gr.Image(
                        value="openpose.png",
                        type="numpy",
                        show_label=False,
                        height=LENGTH // 2,
                        width=LENGTH // 2,
                        interactive=False,
                        visible=True
                    )
                gr.Markdown(
                    """<p style="text-align: center;">&#9314; Optionally flip the hand</p>"""
                )
                ref_flip = gr.Checkbox(
                    value=False, label="Flip Handedness (Reference)", interactive=False
                )
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 20px; font-weight: bold;">2. Upload a hand image for target hand pose 📥</p>"""
                )
                gr.Markdown(
                    """<p style="text-align: center;">&#9312; Optionally crop the image</p>"""
                )
                target = gr.ImageEditor(
                    type="numpy",
                    label="Target",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    brush=False,
                    layers=False,
                    crop_size="1:1",
                )
                gr.Examples(example_target_imgs, [target], examples_per_page=20)
                gr.Markdown(
                    """<p style="text-align: center;">&#9313; Hit the &quot;Finish Cropping&quot; button to get hand pose</p>"""
                )
                target_finish_crop = gr.Button(
                    value="Finish Cropping", interactive=False
                )
                target_pose = gr.Image(
                    type="numpy",
                    label="Target Pose",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    interactive=False,
                )
                gr.Markdown(
                    """<p style="text-align: center;">&#9314; Optionally flip the hand</p>"""
                )
                target_flip = gr.Checkbox(
                    value=False, label="Flip Handedness (Target)", interactive=False
                )
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 20px; font-weight: bold;">3. Press &quot;Run&quot; to get the edited results 🎯</p>"""
                )
                run = gr.Button(value="Run", interactive=False)
                gr.Markdown(
                    """<p style="text-align: center;">⚠️ ~20s per generation with RTX3090, ~50s with A100. <br>(For example, if you set Number of generations to 2, it would take around 40s.)</p>"""
                )
                results = gr.Gallery(
                    type="numpy",
                    label="Results",
                    show_label=True,
                    height=LENGTH,
                    min_width=LENGTH,
                    columns=MAX_N,
                    interactive=False,
                    preview=True,
                )
                results_pose = gr.Gallery(
                    type="numpy",
                    label="Results Pose",
                    show_label=True,
                    height=LENGTH,
                    min_width=LENGTH,
                    columns=MAX_N,
                    interactive=False,
                    preview=True,
                )
                gr.Markdown(
                    """<p style="text-align: center;">✨ Hit &quot;Clear&quot; to restart from the beginning</p>"""
                )
                clear = gr.ClearButton()

        with gr.Tab("More options"):
            with gr.Row():
                n_generation = gr.Slider(
                    label="Number of generations",
                    value=1,
                    minimum=1,
                    maximum=MAX_N,
                    step=1,
                    randomize=False,
                    interactive=True,
                )
                seed = gr.Slider(
                    label="Seed",
                    value=42,
                    minimum=0,
                    maximum=10000,
                    step=1,
                    randomize=False,
                    interactive=True,
                )
                cfg = gr.Slider(
                    label="Classifier free guidance scale",
                    value=2.5,
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    randomize=False,
                    interactive=True,
                )

        ref.change(enable_component, [ref, ref], ref_finish_crop)
        ref_finish_crop.click(prepare_ref_anno, [ref], [ref_im_raw, ref_kp_raw])
        ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_right)
        ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_left)
        ref_manual_checkbox.select(
            set_visible,
            [ref_manual_checkbox, ref_kp_got, ref_im_raw, ref_manual_kp_right, ref_manual_kp_left, ref_manual_done],
            [
                ref_kp_got,
                ref_manual_kp_right,
                ref_manual_kp_left,
                ref_manual_kp_right,
                ref_manual_undo_right,
                ref_manual_reset_right,
                ref_manual_kp_left,
                ref_manual_undo_left,
                ref_manual_reset_left,
                ref_manual_kp_r_info,
                ref_manual_kp_l_info,
                ref_manual_done,
                ref_manual_done_info
            ]
        )
        ref_manual_kp_right.select(
            get_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got]
        )
        ref_manual_undo_right.click(
            undo_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got]
        )
        ref_manual_reset_right.click(
            reset_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got]
        )
        ref_manual_kp_left.select(
            get_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
        )
        ref_manual_undo_left.click(
            undo_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
        )
        ref_manual_reset_left.click(
            reset_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
        )
        ref_manual_done.click(get_ref_anno, [ref_im_raw, ref_kp_got], [ref_img, ref_manual_pose, ref_manual_cond])
        ref_manual_cond.change(lambda x: x, ref_manual_cond, ref_cond)
        ref_use_manual.click(lambda x: x, ref_manual_cond, ref_cond)
        # ref_use_manual.click(lambda x: gr.Info("Manual hand keypoints will be used for 'Reference'", duration=3))
        ref_manual_done.click(visible_component, [ref_manual_pose, ref_manual_pose], ref_manual_pose)
        ref_manual_done.click(visible_component, [ref_use_manual, ref_use_manual], ref_use_manual)
        ref_manual_pose.change(enable_component, [ref_manual_pose, ref_manual_pose], ref_manual_done)
        ref_kp_raw.change(get_ref_anno, [ref_im_raw, ref_kp_raw], [ref_img, ref_pose, ref_auto_cond])
        ref_auto_cond.change(lambda x: x, ref_auto_cond, ref_cond)
        ref_use_auto.click(lambda x: x, ref_auto_cond, ref_cond)
        # ref_use_auto.click(lambda x: gr.Info("Automatic hand keypoints will be used for 'Reference'", duration=3))
        ref_pose.change(enable_component, [ref_kp_raw, ref_pose], ref_use_auto)
        ref_pose.change(enable_component, [ref_img, ref_pose], ref_flip)
        ref_manual_pose.change(enable_component, [ref_img, ref_manual_pose], ref_flip)
        ref_flip.select(
            flip_hand, [ref, ref_pose, ref_cond, gr.State(value=None), ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left], [ref, ref_pose, ref_cond, dump, ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left]
        )
        target.change(enable_component, [target, target], target_finish_crop)
        target_finish_crop.click(
            get_target_anno,
            [target],
            [target_img, target_pose, target_cond, target_keypts],
        )
        target_pose.change(enable_component, [target_img, target_pose], target_flip)
        target_flip.select(
            flip_hand,
            [target, target_pose, target_cond, target_keypts],
            [target, target_pose, target_cond, target_keypts],
        )
        ref_pose.change(enable_component, [ref_pose, target_pose], run)
        ref_manual_pose.change(enable_component, [ref_manual_pose, target_pose], run)
        target_pose.change(enable_component, [ref_pose, target_pose], run)
        run.click(
            sample_diff,
            [ref_cond, target_cond, target_keypts, n_generation, seed, cfg],
            [results, results_pose],
        )
        clear.click(
            clear_all,
            [],
            [
                ref,
                ref_manual_kp_right,
                ref_manual_kp_left,
                ref_pose,
                ref_manual_pose,
                ref_flip,
                target,
                target_pose,
                target_flip,
                results,
                results_pose,
                ref_img,
                ref_cond,
                target_img,
                target_cond,
                target_keypts,
                n_generation,
                seed,
                cfg,
                ref_kp_raw,
                ref_manual_checkbox
            ],
        )
        clear.click(
            set_unvisible,
            [],
            [
                ref_manual_kp_r_info,
                ref_manual_kp_l_info,
                ref_manual_undo_left,
                ref_manual_undo_right,
                ref_manual_reset_left,
                ref_manual_reset_right,
                ref_manual_done,
                ref_manual_done_info,
                ref_manual_pose,
                ref_use_manual,
                ref_manual_kp_right,
                ref_manual_kp_left
            ]
        )

    with gr.Tab("Fix Hands"):
        fix_inpaint_mask = gr.State(value=None)
        fix_original = gr.State(value=None)
        fix_img = gr.State(value=None)
        fix_kpts = gr.State(value=None)
        fix_kpts_np = gr.State(value=None)
        fix_ref_cond = gr.State(value=None)
        fix_target_cond = gr.State(value=None)
        fix_latent = gr.State(value=None)
        fix_inpaint_latent = gr.State(value=None)
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 20px; font-weight: bold;">1. Upload a malformed hand image to fix 📥</p>"""
                )
                gr.Markdown(
                    """<p style="text-align: center;">&#9312; Optionally crop the image around the hand</p>"""
                )
                fix_crop = gr.ImageEditor(
                    type="numpy",
                    sources=["upload", "webcam", "clipboard"],
                    label="Image crop",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    layers=False,
                    crop_size="1:1",
                    brush=False,
                    image_mode="RGBA",
                    container=False,
                )
                fix_example = gr.Examples(
                    fix_example_imgs,
                    inputs=[fix_crop],
                    examples_per_page=20,
                )
                gr.Markdown(
                    """<p style="text-align: center;">&#9313; Brush the area (e.g., a wrong finger) that needs to be fixed. This will serve as the inpaint mask</p>"""
                )
                fix_ref = gr.ImageEditor(
                    type="numpy",
                    label="Image brush",
                    sources=(),
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    layers=False,
                    transforms=("brush"),
                    brush=gr.Brush(
                        colors=["rgb(255, 255, 255)"], default_size=20
                    ),  # 204, 50, 50
                    image_mode="RGBA",
                    container=False,
                    interactive=False,
                )
                fix_finish_crop = gr.Button(
                    value="Finish Cropping & Brushing", interactive=False
                )
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 20px; font-weight: bold;">2. Click on hand to get target hand pose</p>"""
                )
                gr.Markdown(
                    """<p style="text-align: center;">&#9312; Tell us if this is right, left, or both hands</p>"""
                )
                fix_checkbox = gr.CheckboxGroup(
                    ["Right hand", "Left hand"],
                    show_label=False,
                    interactive=False,
                )
                gr.Markdown(
                    """<p style="text-align: center;">&#9313; On the image, click 21 hand keypoints. This will serve as target hand poses. See the \"OpenPose keypoints convention\" for guidance.</p>"""
                )
                fix_kp_r_info = gr.Markdown(
                    """<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select right only</p>""",
                    visible=False,
                )
                fix_kp_right = gr.Image(
                    type="numpy",
                    label="Keypoint Selection (right hand)",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    interactive=False,
                    visible=False,
                    sources=[],
                )
                with gr.Row():
                    fix_undo_right = gr.Button(
                        value="Undo", interactive=False, visible=False
                    )
                    fix_reset_right = gr.Button(
                        value="Reset", interactive=False, visible=False
                    )
                fix_kp_l_info = gr.Markdown(
                    """<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select left only</p>""",
                    visible=False
                )
                fix_kp_left = gr.Image(
                    type="numpy",
                    label="Keypoint Selection (left hand)",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    interactive=False,
                    visible=False,
                    sources=[],
                )
                with gr.Row():
                    fix_undo_left = gr.Button(
                        value="Undo", interactive=False, visible=False
                    )
                    fix_reset_left = gr.Button(
                        value="Reset", interactive=False, visible=False
                    )
                gr.Markdown(
                    """<p style="text-align: left; font-weight: bold; ">OpenPose keypoints convention</p>"""
                )
                fix_openpose = gr.Image(
                    value="openpose.png",
                    type="numpy",
                    show_label=False,
                    height=LENGTH // 2,
                    width=LENGTH // 2,
                    interactive=False,
                )
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 20px; font-weight: bold;">3. Press &quot;Ready&quot; to start pre-processing</p>"""
                )
                fix_ready = gr.Button(value="Ready", interactive=False)
                gr.Markdown(
                    """<p style="text-align: center; font-weight: bold; ">Visualized (256, 256) Inpaint Mask</p>"""
                )
                fix_vis_mask32 = gr.Image(
                    type="numpy",
                    label=f"Visualized {opts.latent_size} Inpaint Mask",
                    show_label=True,
                    height=opts.latent_size,
                    width=opts.latent_size,
                    interactive=False,
                    visible=False,
                )
                fix_vis_mask256 = gr.Image(
                    type="numpy",
                    visible=True,
                    show_label=False,
                    height=opts.image_size,
                    width=opts.image_size,
                    interactive=False,
                )
                gr.Markdown(
                    """<p style="text-align: center;">[NOTE] Above should be inpaint mask that you brushed, NOT the segmentation mask of the entire hand. </p>"""
                )
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 20px; font-weight: bold;">4. Press &quot;Run&quot; to get the fixed hand image 🎯</p>"""
                )
                fix_run = gr.Button(value="Run", interactive=False)
                gr.Markdown(
                    """<p style="text-align: center;">⚠️  >3min and ~24GB per generation</p>"""
                )
                fix_result = gr.Gallery(
                    type="numpy",
                    label="Results",
                    show_label=True,
                    height=LENGTH,
                    min_width=LENGTH,
                    columns=FIX_MAX_N,
                    interactive=False,
                    preview=True,
                )
                fix_result_pose = gr.Gallery(
                    type="numpy",
                    label="Results Pose",
                    show_label=True,
                    height=LENGTH,
                    min_width=LENGTH,
                    columns=FIX_MAX_N,
                    interactive=False,
                    preview=True,
                )
                gr.Markdown(
                    """<p style="text-align: center;">✨ Hit &quot;Clear&quot; to restart from the beginning</p>"""
                )
                fix_clear = gr.ClearButton()

        gr.Markdown(
            """<p style="text-align: left; font-size: 25px;"><b>More options</b></p>"""
        )
        gr.Markdown(
            "⚠️ Currently, Number of generation > 1 could lead to out-of-memory"
        )
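        # Sampling options. All four sliders are passed straight into sample_inpaint
        # below: batch size (number of generations), RNG seed, classifier-free
        # guidance scale, and a quality setting (presumably trading speed for fidelity).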
        with gr.Row():
            fix_n_generation = gr.Slider(
                label="Number of generations",
                value=1,
                minimum=1,
                maximum=FIX_MAX_N,
                step=1,
                randomize=False,
                interactive=True,
            )
            fix_seed = gr.Slider(
                label="Seed",
                value=42,
                minimum=0,
                maximum=10000,
                step=1,
                randomize=False,
                interactive=True,
            )
            fix_cfg = gr.Slider(
                label="Classifier free guidance scale",
                value=3.0,
                minimum=0.0,
                maximum=10.0,
                step=0.1,
                randomize=False,
                interactive=True,
            )
            fix_quality = gr.Slider(
                label="Quality",
                value=10,
                minimum=1,
                maximum=10,
                step=1,
                randomize=False,
                interactive=True,
            )
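        # --- Event wiring for the Fix Hands tab ---
        # Cropping: once an image is cropped, enable the brush editor and copy the
        # crop into it at full resolution so the user can paint the inpaint mask.
        # (enable_component is defined earlier in this file; conceptually it returns
        # something like gr.update(interactive=True) once its trigger inputs are set.)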
        fix_crop.change(enable_component, [fix_crop, fix_crop], fix_ref)
        fix_crop.change(resize_to_full, fix_crop, fix_ref)
        fix_ref.change(enable_component, [fix_ref, fix_ref], fix_finish_crop)
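        # Finishing the crop/brush step extracts the inpaint mask from the brushed
        # strokes, stores the unbrushed background as the original image, and renders
        # the masked crop that the keypoint canvases will display.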
        fix_finish_crop.click(get_mask_inpaint, [fix_ref], [fix_inpaint_mask])
        fix_finish_crop.click(lambda x: x["background"], [fix_crop], [fix_original])
        fix_finish_crop.click(visualize_ref, [fix_crop, fix_ref], [fix_img])
        fix_img.change(lambda x: x, [fix_img], [fix_kp_right])
        fix_img.change(lambda x: x, [fix_img], [fix_kp_left])
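        # A valid inpaint mask unlocks the rest of the tab: the hand-side checkbox,
        # both keypoint canvases with their Undo/Reset buttons, and the Ready button.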
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_checkbox
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_right
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_right
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_right
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_left
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_left
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_left
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_ready
        )
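        # Show only the keypoint canvases for the hand side(s) the user ticked.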
        fix_checkbox.select(
            set_visible,
            [fix_checkbox, fix_kpts, fix_img, fix_kp_right, fix_kp_left],
            [
                fix_kpts,
                fix_kp_right,
                fix_kp_left,
                fix_kp_right,
                fix_undo_right,
                fix_reset_right,
                fix_kp_left,
                fix_undo_left,
                fix_reset_left,
                fix_kp_r_info,
                fix_kp_l_info,
            ],
        )
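        # Keypoint annotation: clicking a canvas appends a keypoint for that hand,
        # Undo removes the last click, and Reset clears all clicks for that hand.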
        fix_kp_right.select(
            get_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
        )
        fix_undo_right.click(
            undo_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
        )
        fix_reset_right.click(
            reset_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
        )
        fix_kp_left.select(
            get_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
        )
        fix_undo_left.click(
            undo_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
        )
        fix_reset_left.click(
            reset_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
        )
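        # Run stays disabled until pre-processing has produced the visualized masks.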
        fix_vis_mask32.change(
            enable_component, [fix_vis_mask32, fix_vis_mask256], fix_run
        )
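        # Ready: pre-process the crop, mask, and keypoints into the reference/target
        # conditions, the image latent, and the latent-space inpaint mask used by the
        # sampler, and visualize the mask at latent and image resolution.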
        fix_ready.click(
            ready_sample,
            [fix_original, fix_inpaint_mask, fix_kpts],
            [
                fix_ref_cond,
                fix_target_cond,
                fix_latent,
                fix_inpaint_latent,
                fix_kpts_np,
                fix_vis_mask32,
                fix_vis_mask256,
            ],
        )
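        # Run: diffusion inpainting conditioned on the clicked keypoints; results and
        # their pose overlays are shown in the two galleries.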
        fix_run.click(
            sample_inpaint,
            [
                fix_ref_cond,
                fix_target_cond,
                fix_latent,
                fix_inpaint_latent,
                fix_kpts_np,
                fix_n_generation,
                fix_seed,
                fix_cfg,
                fix_quality,
            ],
            [fix_result, fix_result_pose],
        )
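        # Clear: reset every component and state of this tab to its initial value.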
        fix_clear.click(
            fix_clear_all,
            [],
            [
                fix_crop,
                fix_ref,
                fix_kp_right,
                fix_kp_left,
                fix_result,
                fix_result_pose,
                fix_inpaint_mask,
                fix_original,
                fix_img,
                fix_vis_mask32,
                fix_vis_mask256,
                fix_kpts,
                fix_kpts_np,
                fix_ref_cond,
                fix_target_cond,
                fix_latent,
                fix_inpaint_latent,
                fix_n_generation,
                fix_seed,
                fix_cfg,
                fix_quality,
            ],
        )

    gr.Markdown("<h1>Citation</h1>")
    gr.Markdown(
        """<p style="text-align: left;">If this was useful, please cite us! ❤️</p>"""
    )
    gr.Markdown(_CITE_)

# print("Ready to launch..")
# _, _, shared_url = demo.queue().launch(
    # share=True, server_name="0.0.0.0", server_port=7739
# )
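# share=True serves the demo locally and additionally exposes it through a
# temporary public Gradio share link.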
demo.launch(share=True)