FoundHand / app.py
chaerinmin's picture
improve simplicity, allow robot hand seg
ac9fa25
raw
history blame
89.1 kB
import os
import torch
from dataclasses import dataclass
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import cv2
import mediapipe as mp
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import vqvae
import vit
from typing import Literal
from diffusion import create_diffusion
from utils import scale_keypoint, keypoint_heatmap, check_keypoints_validity
from segment_hoi import init_sam
from io import BytesIO
from PIL import Image
import random
from copy import deepcopy
from typing import Optional
import requests
from huggingface_hub import hf_hub_download
try:
    import spaces  # Hugging Face Spaces GPU helper; only used when HF is True
except ImportError:
    # Narrowed from a bare `except:` so that unrelated failures (or a
    # KeyboardInterrupt) while importing are no longer silently swallowed.
    pass
# Number of output gallery slots for pose editing (sample_diff) and hand fixing
# (sample_inpaint); unused slots are filled with `placeholder`.
MAX_N = 6
FIX_MAX_N = 6
# Side length in pixels of the "no hands" images and of resize_to_full.
LENGTH = 480
placeholder = cv2.cvtColor(cv2.imread("placeholder.png"), cv2.COLOR_BGR2RGB)
NEW_MODEL = True
MODEL_EPOCH = 6  # selects which diffusion checkpoint file is loaded
REF_POSE_MASK = True  # segment the reference hand with SAM and keep heatmaps/mask in ref_cond
HF = False  # True when running on Hugging Face Spaces
# Device for SAM and for the conditioning tensors built outside the GPU-decorated calls.
pre_device = "cpu" if HF else "cuda"
# GPU-time decorators: the number in each name is the requested duration in
# seconds (previously all three asked for 60). Identity wrappers when not on HF.
spaces_60_fn = spaces.GPU(duration=60) if HF else (lambda f: f)
spaces_120_fn = spaces.GPU(duration=120) if HF else (lambda f: f)
spaces_300_fn = spaces.GPU(duration=300) if HF else (lambda f: f)
def set_seed(seed):
    """Seed every random generator the app draws from.

    Covers torch (CPU and all CUDA devices), NumPy's global RNG and Python's
    `random` module. `seed` is passed through int(), so numeric strings or
    floats coming from UI widgets are accepted.
    """
    seed_value = int(seed)
    seeders = (
        torch.manual_seed,
        np.random.seed,
        torch.cuda.manual_seed_all,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed_value)
# Device that holds the diffusion model, the autoencoder and the sampling noise.
# NOTE(review): hard-coded to "cuda". The availability check is commented out,
# so the app will not start on a CPU-only machine. This is presumably deliberate
# for the GPU deployment; confirm before re-enabling the fallback below.
# if torch.cuda.is_available():
device = "cuda"
# else:
#     device = "cpu"
def remove_prefix(text, prefix):
    """Return `text` with a leading `prefix` stripped; otherwise return it unchanged."""
    return text[len(prefix):] if text.startswith(prefix) else text
def unnormalize(x):
    """Map an array with values in [-1, 1] to uint8 pixel values in [0, 255].

    The conversion truncates (astype), it does not round.
    """
    zero_to_one = (x + 1) / 2
    return (zero_to_one * 255).astype(np.uint8)
def visualize_hand(all_joints, img, side=("right", "left"), n_avail_joints=21):
    """Draw hand skeleton(s) on top of an image and return the rendering as an array.

    Args:
        all_joints: (42, 2) array-like of (x, y) pixel coordinates; rows 0-20
            belong to the right hand, rows 21-41 to the left hand.
        img: (H, W, 3) or (H, W) image to draw on.
        side: which hand(s) to draw. Membership is tested with ``in``, so a
            single string such as "right" or a collection such as
            ("right", "left") both work. A tuple replaces the former mutable
            list default.
        n_avail_joints: number of leading joints of each hand that are
            available. With a partial set, only the bones whose endpoints are
            both available are drawn. With zero joints, that hand is skipped.

    Returns:
        numpy array of the rendered figure, resized back to (H, W). The channel
        count is whatever PIL decodes from the PNG that matplotlib saves.
    """
    # Connections between joints, each with its own line color. They are
    # ordered so that the k-th connection only uses joints 0..k+1, which lets
    # `connections[: len(joints) - 1]` draw a partially-annotated hand.
    connections = [
        ((0, 1), "red"),
        ((1, 2), "green"),
        ((2, 3), "blue"),
        ((3, 4), "purple"),
        ((0, 5), "orange"),
        ((5, 6), "pink"),
        ((6, 7), "brown"),
        ((7, 8), "cyan"),
        ((0, 9), "yellow"),
        ((9, 10), "magenta"),
        ((10, 11), "lime"),
        ((11, 12), "indigo"),
        ((0, 13), "olive"),
        ((13, 14), "teal"),
        ((14, 15), "navy"),
        ((15, 16), "gray"),
        ((0, 17), "lavender"),
        ((17, 18), "silver"),
        ((18, 19), "maroon"),
        ((19, 20), "fuchsia"),
    ]
    H, W = img.shape[:2]
    fig = plt.figure()
    try:
        ax = fig.gca()
        ax.imshow(img)
        start_is = []
        if "right" in side:
            start_is.append(0)
        if "left" in side:
            start_is.append(21)
        for start_i in start_is:
            joints = all_joints[start_i : start_i + n_avail_joints]
            if len(joints) == 0:
                # Nothing annotated for this hand (e.g. every click undone).
                # Without this guard, connections[:-1] would index an empty slice.
                continue
            if len(joints) == 1:
                ax.scatter(joints[0][0], joints[0][1], color="red", s=10)
            else:
                for connection, color in connections[: len(joints) - 1]:
                    joint1 = joints[connection[0]]
                    joint2 = joints[connection[1]]
                    ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)
        ax.set_xlim([0, W])
        ax.set_ylim([0, H])
        ax.grid(False)
        ax.set_axis_off()
        ax.invert_yaxis()  # image coordinates: y grows downwards
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Close this specific figure, even on error, so pyplot does not
        # accumulate open figures across requests.
        plt.close(fig)
    buf.seek(0)
    img_pil = Image.open(buf)
    img_pil = img_pil.resize((W, H))
    return np.array(img_pil)
def mask_image(image, mask, color=(0, 0, 0), alpha=0.6, transparent=True):
    """Overlay mask on image for visualization purpose.

    Args:
        image (H, W, 3) or (H, W): input image. It is not modified.
        mask (H, W): pixels where ``mask == 1`` (or True) are painted.
        color: the color of the overlaid mask: a per-channel sequence, or a
            scalar for grayscale. A tuple default replaces the former mutable list.
        alpha: weight of the painted copy when blending (only if ``transparent``).
        transparent: if True, blend the painted copy with the original using
            ``cv2.addWeighted``. If False, return the painted copy directly.

    Returns:
        A new array with the same shape and dtype as ``image``.
    """
    painted = np.copy(image)
    painted[mask == 1] = color
    if not transparent:
        return painted
    # A second copy is needed only for blending; addWeighted writes into `out`.
    out = np.copy(image)
    return cv2.addWeighted(painted, alpha, out, 1 - alpha, 0, out)
def scale_keypoint(keypoint, original_size, target_size):
    """Scale keypoints from one image size to another.

    NOTE: this definition shadows ``utils.scale_keypoint`` imported at the top
    of the file. Every call made after this definition runs uses this version.

    Args:
        keypoint: (N, 2) array-like of (x, y) coordinates. It is not modified.
        original_size: (width, height) of the image the keypoints refer to.
        target_size: (width, height) of the resized image.

    Returns:
        A new array. Float input keeps its dtype. Integer input is promoted to
        float64, because scaling an integer array in place by a float factor
        would raise a casting error.
    """
    keypoint_copy = np.array(keypoint, copy=True)
    if not np.issubdtype(keypoint_copy.dtype, np.floating):
        keypoint_copy = keypoint_copy.astype(np.float64)
    keypoint_copy[:, 0] *= target_size[0] / original_size[0]
    keypoint_copy[:, 1] *= target_size[1] / original_size[1]
    return keypoint_copy
print("Configure...")
# Hyper-parameters and paths of the FoundHand diffusion setup.
# Fields read in the visible part of this file: image_size, latent_size,
# latent_dim, n_keypoints, n_mask, test_sampling_steps, latent_scaling_factor.
# The remaining fields presumably mirror the training configuration; confirm
# before relying on them.
@dataclass
class HandDiffOpts:
    # NOTE(review): the path fields below point at a training cluster; they are
    # not read in the visible part of this file.
    run_name: str = "ViT_256_handmask_heatmap_nvs_b25_lr1e-5"
    sd_path: str = "/users/kchen157/scratch/weights/SD/sd-v1-4.ckpt"
    log_dir: str = "/users/kchen157/scratch/log"
    data_root: str = "/users/kchen157/data/users/kchen157/dataset/handdiff"
    # Pixel size that inputs are resized to (used as cv2 dsize and torchvision Resize target).
    image_size: tuple = (256, 256)
    # Spatial size of the latents, keypoint heatmaps and masks.
    latent_size: tuple = (32, 32)
    # Channels of the autoencoder latent.
    latent_dim: int = 4
    mask_bg: bool = False
    kpts_form: str = "heatmap"
    # 2 hands x 21 joints; one heatmap channel per keypoint in the model input.
    n_keypoints: int = 42
    # Number of hand-mask channels in the model input.
    n_mask: int = 1
    noise_steps: int = 1000
    # Passed (as a string) to create_diffusion for sampling.
    test_sampling_steps: int = 250
    ddim_steps: int = 100
    ddim_discretize: str = "uniform"
    ddim_eta: float = 0.0
    beta_start: float = 8.5e-4
    beta_end: float = 0.012
    # Encoded latents are multiplied by this factor; samples are divided by it before decoding.
    latent_scaling_factor: float = 0.18215
    cfg_pose: float = 5.0
    cfg_appearance: float = 3.5
    batch_size: int = 25
    lr: float = 1e-5
    max_epochs: int = 500
    log_every_n_steps: int = 100
    limit_val_batches: int = 1
    n_gpu: int = 8
    num_nodes: int = 1
    precision: str = "16-mixed"
    profiler: str = "simple"
    swa_epoch_start: int = 10
    swa_lrs: float = 1e-3
    num_workers: int = 10
    n_val_samples: int = 4
# load models
token = os.getenv("HF_TOKEN")  # may be None; forwarded to hf_hub_download
if NEW_MODEL:
    opts = HandDiffOpts()
    # Select the diffusion checkpoint for MODEL_EPOCH. Only epoch 6 falls back
    # to a Hugging Face Hub download; the other epochs must exist locally.
    if MODEL_EPOCH == 7:
        model_path = './DINO_EMA_11M_b50_lr1e-5_epoch7_step380k.ckpt'
    elif MODEL_EPOCH == 6:
        model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt"
        if not os.path.exists(model_path):
            model_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt", token=token)
    elif MODEL_EPOCH == 4:
        model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch4_step210k.ckpt"
    elif MODEL_EPOCH == 10:
        model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch10_step550k.ckpt"
    else:
        # The message lists every epoch handled above (it used to say "6 or 7").
        raise ValueError(f"new model epoch should be one of 4, 6, 7 or 10, got {MODEL_EPOCH}")
    vae_path = './vae-ft-mse-840000-ema-pruned.ckpt'
    if not os.path.exists(vae_path):
        vae_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="vae-ft-mse-840000-ema-pruned.ckpt", token=token)
    # sd_path = './sd-v1-4.ckpt'
    print('Load diffusion model...')
    diffusion = create_diffusion(str(opts.test_sampling_steps))
    model = vit.DiT_XL_2(
        input_size=opts.latent_size[0],
        latent_dim=opts.latent_dim,
        # latent channels + one heatmap channel per keypoint + hand-mask channel(s)
        in_channels=opts.latent_dim + opts.n_keypoints + opts.n_mask,
        learn_sigma=True,
    ).to(device)
    # SECURITY NOTE: torch.load unpickles the checkpoint, and unpickling can
    # execute arbitrary code. The files come from local disk or the project's
    # own Hub repo; never point model_path/vae_path at untrusted files.
    # weights_only=True would be safer if the checkpoints hold only tensors,
    # which has not been verified here.
    # ckpt_state_dict = torch.load(model_path)['model_state_dict']
    ckpt_state_dict = torch.load(model_path, map_location='cpu')['ema_state_dict']
    missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
    model = model.to(device)
    model.eval()
    print(missing_keys, extra_keys)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(missing_keys) != 0:
        raise RuntimeError(f"Diffusion checkpoint is missing keys: {missing_keys}")
    # `vae_state_dict` stays a module global: get_ref_anno reloads it on each call.
    vae_state_dict = torch.load(vae_path, map_location='cpu')['state_dict']
    print(f"vae_state_dict encoder dtype: {vae_state_dict['encoder.conv_in.weight'].dtype}")
    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False)
    print(f"autoencoder encoder dtype: {next(autoencoder.encoder.parameters()).dtype}")
    print(f"encoder before load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
    print(f"encoder before load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
    missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
    print(f"encoder after load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
    print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
    autoencoder = autoencoder.to(device)
    autoencoder.eval()
    print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
    print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
    print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
    if len(missing_keys) != 0:
        raise RuntimeError(f"VAE checkpoint is missing keys: {missing_keys}")
# Segment Anything model, used to turn hand keypoints into a hand mask.
sam_path = "sam_vit_h_4b8939.pth"
if not os.path.exists(sam_path):
    sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token)
sam_predictor = init_sam(ckpt_path=sam_path, device=pre_device)
print("Mediapipe hand detector and SAM ready...")
mp_hands = mp.solutions.hands
hands = mp_hands.Hands(
    static_image_mode=True,  # Use False if image is part of a video stream
    max_num_hands=2,  # Maximum number of hands to detect
    min_detection_confidence=0.1,
)
# Shown in the image component when no hand could be detected.
no_hands_open = cv2.resize(np.array(Image.open("no_hands_open.jpeg"))[..., :3], (LENGTH, LENGTH))
def prepare_anno(ref, ref_is_user):
    """Resize the user's reference image and detect hand keypoints with MediaPipe.

    Args:
        ref: ImageEditor value (dict with "background" and "composite" arrays),
            or None.
        ref_is_user: False when the displayed image is the bundled
            "no hands" placeholder rather than a user upload.

    Returns:
        (img, keypts): `img` is the composite resized to `opts.image_size`.
        `keypts` is a (42, 2) array of pixel coordinates (rows 0-20 right hand,
        rows 21-41 left hand; an undetected hand stays all-zero), or None when
        no hand was found. When there is nothing to process, both entries are
        `gr.update(value=None)`.
    """
    if not ref_is_user:  # no_hand_open.jpeg
        return gr.update(value=None), gr.update(value=None)
    if ref is None or ref["background"] is None or ref["background"].sum() == 0:  # clear_all
        return gr.update(value=None), gr.update(value=None)
    img = ref["composite"][..., :3]
    img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
    keypts = np.zeros((42, 2))
    mp_pose = hands.process(img)
    if not mp_pose.multi_hand_landmarks:
        return img, None
    # Handedness is flipped because MediaPipe assumes a mirrored input image:
    # its "Left" is actually the right hand (rows 0-20) and vice versa.
    start_by_label = {"Left": 0, "Right": 21}
    for hand_landmarks, handedness in zip(
        mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
    ):
        start_idx = start_by_label.get(handedness.classification[0].label)
        if start_idx is None:
            # Unknown label: skip it instead of reusing a previous hand's slot.
            continue
        for i, landmark in enumerate(hand_landmarks.landmark):
            # MediaPipe landmarks are normalized; convert them to pixels.
            keypts[start_idx + i] = [
                landmark.x * opts.image_size[1],
                landmark.y * opts.image_size[0],
            ]
    print(f"keypts.max(): {keypts.max()}, keypts.min(): {keypts.min()}")
    return img, keypts
def get_ref_anno(img, keypts):
    """Build the reference conditioning tensor (latent + keypoint heatmaps + hand mask).

    Args:
        img: reference image at `opts.image_size` (as returned by `prepare_anno`),
            or None / an all-zero image after the UI has been cleared.
        keypts: one of:
            - a (42, 2) keypoint array (rows 0-20 right hand, 21-41 left hand);
            - a ``[right_points, left_points]`` list of manually clicked points,
              where each list is empty or holds 21 points;
            - None, if no hand was detected.

    Returns:
        5-tuple ``(img, ref_pose, ref_cond, image_update, flag)``:
        - the image;
        - a pose visualization;
        - the conditioning tensor: latent, heatmaps and mask concatenated on
          dim 1, with a leading batch dimension of 1;
        - an update (or the "no hands" placeholder) for an image component;
        - a bool that is False only when no hand was detected. This is
          presumably the ``ref_is_user`` flag consumed by `prepare_anno`;
          confirm against the event wiring.
        On invalid manual annotations, Nones and no-op updates are returned.
    """
    if img is None or img.sum() == 0:  # clear_all
        return None, gr.update(), None, gr.update(), True
    elif keypts is None:  # hand not detected
        no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
        return None, no_hands, None, no_hands_open, False
    # NOTE(review): reloads the VAE weights on every call. Together with the
    # encoder min/max debug prints, this looks like a workaround for weights
    # that were observed to change between calls; confirm before removing.
    autoencoder.load_state_dict(vae_state_dict, strict=False)
    if isinstance(keypts, list):
        # Manual annotation. Build a new array instead of overwriting the
        # caller's lists in place: that state is later appended to / popped
        # from by get_kps / undo_kps, which would fail on numpy arrays.
        hands_kp = []
        for pts, name in ((keypts[0], "right"), (keypts[1], "left")):
            if len(pts) == 0:
                hands_kp.append(np.zeros((21, 2)))
            elif len(pts) == 21:
                hands_kp.append(np.array(pts, dtype=np.float32))
            else:
                gr.Info(f"Number of {name} hand keypoints should be either 0 or 21.")
                return None, None, None, gr.update(), gr.update()
        keypts = np.concatenate(hands_kp, axis=0)
    if REF_POSE_MASK:
        right_found = keypts[0].sum() != 0
        left_found = keypts[21].sum() != 0
        if right_found and left_found:
            input_point = np.array(keypts)
        elif right_found:
            input_point = np.array(keypts[:21])
        elif left_found:
            input_point = np.array(keypts[21:])
        else:
            # Previously a NameError on `input_point`; reachable when both
            # manual keypoint lists are empty.
            gr.Info("Please annotate at least one hand with 21 keypoints.")
            return None, None, None, gr.update(), gr.update()
        sam_predictor.set_image(img)
        input_box = np.stack([input_point.min(axis=0), input_point.max(axis=0)])
        input_label = np.ones_like(input_point[:, 0]).astype(np.int32)
        # Enlarge the keypoint bounding box by 20% around its center, so SAM
        # sees the whole hand rather than just the joint extents.
        box_shift_ratio = 0.5
        box_size_factor = 1.2
        box_trans = input_box[0] * box_shift_ratio + input_box[1] * (1 - box_shift_ratio)
        input_box = ((input_box - box_trans) * box_size_factor + box_trans).reshape(-1)
        masks, _, _ = sam_predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            box=input_box[None, :],
            multimask_output=False,
        )
        hand_mask = masks[0]
        # Whiten everything outside the hand for the pose preview.
        masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
        ref_pose = visualize_hand(keypts, masked_img)
    else:
        hand_mask = np.zeros_like(img[:, :, 0])
        ref_pose = np.zeros_like(img)

    def make_ref_cond(
        img,
        keypts,
        hand_mask,
        device="cuda",
        target_size=(256, 256),
        latent_size=(32, 32),
    ):
        """Return (normalized image, keypoint heatmaps, latent-size mask), each with a batch dim."""
        image_transform = Compose(
            [
                ToTensor(),
                Resize(target_size),
                # Map [0, 1] to [-1, 1].
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        image = image_transform(img).to(device)
        kpts_valid = check_keypoints_validity(keypts, target_size)
        # Heatmaps of invalid keypoints are zeroed out.
        heatmaps = torch.tensor(
            keypoint_heatmap(
                scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
            )
            * kpts_valid[:, None, None],
            dtype=torch.float,
            device=device,
        )[None, ...]
        mask = torch.tensor(
            cv2.resize(
                hand_mask.astype(int),
                dsize=latent_size,
                interpolation=cv2.INTER_NEAREST,
            ),
            dtype=torch.float,
            device=device,
        ).unsqueeze(0)[None, ...]
        return image[None, ...], heatmaps, mask

    print(f"img.max(): {img.max()}, img.min(): {img.min()}")
    image, heatmaps, mask = make_ref_cond(
        img,
        keypts,
        hand_mask,
        device=pre_device,
        target_size=opts.image_size,
        latent_size=opts.latent_size,
    )
    print(f"image.max(): {image.max()}, image.min(): {image.min()}")
    print(f"opts.latent_scaling_factor: {opts.latent_scaling_factor}")
    # Debug labels fixed: they previously printed the minimum as "max" and vice versa.
    print(f"autoencoder encoder before operating min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
    print(f"autoencoder encoder before operating max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
    print(f"autoencoder encoder before operating dtype: {next(autoencoder.encoder.parameters()).dtype}")
    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
    print(f"latent.max(): {latent.max()}, latent.min(): {latent.min()}")
    if not REF_POSE_MASK:
        # Without the SAM mask, pose and mask conditioning are disabled.
        heatmaps = torch.zeros_like(heatmaps)
        mask = torch.zeros_like(mask)
    print(f"heatmaps.max(): {heatmaps.max()}, heatmaps.min(): {heatmaps.min()}")
    print(f"mask.max(): {mask.max()}, mask.min(): {mask.min()}")
    ref_cond = torch.cat([latent, heatmaps, mask], 1)
    print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}")
    return img, ref_pose, ref_cond, gr.update(), True
def get_target_anno(img, keypts):
    """Build the target-pose conditioning tensor from keypoints.

    Args:
        img: target image at `opts.image_size`, or None / all-zero after clearing.
        keypts: (42, 2) keypoint array (rows 0-20 right hand, 21-41 left hand),
            a ``[right_points, left_points]`` list of clicked points (each list
            empty or holding 21 points), or None if no hand was detected.

    Returns:
        6-tuple ``(img, target_pose, target_cond, keypts, image_update, flag)``.
        `target_cond` holds the keypoint heatmaps plus one all-zero mask
        channel, with a leading batch dimension of 1. Every path returns six
        values; the cleared-UI path used to return only five.
    """
    if img is None or img.sum() == 0:  # clear_all
        return None, gr.update(), None, None, gr.update(), True
    if keypts is None:  # hands not detected
        no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
        return None, no_hands, None, None, no_hands_open, False
    if isinstance(keypts, list):
        # Build a new array rather than overwriting the caller's keypoint lists.
        hands_kp = []
        for pts, name in ((keypts[0], "right"), (keypts[1], "left")):
            if len(pts) == 0:
                hands_kp.append(np.zeros((21, 2)))
            elif len(pts) == 21:
                hands_kp.append(np.array(pts, dtype=np.float32))
            else:
                gr.Info(f"Number of {name} hand keypoints should be either 0 or 21.")
                return None, None, None, gr.update(), gr.update(), gr.update()
        keypts = np.concatenate(hands_kp, axis=0)
    target_pose = visualize_hand(keypts, img)
    kpts_valid = check_keypoints_validity(keypts, opts.image_size)
    # Heatmaps of invalid keypoints are zeroed out.
    target_heatmaps = torch.tensor(
        keypoint_heatmap(
            scale_keypoint(keypts, opts.image_size, opts.latent_size),
            opts.latent_size,
            var=1.0,
        )
        * kpts_valid[:, None, None],
        dtype=torch.float,
        device=pre_device,
    )[None, ...]
    target_cond = torch.cat(
        [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
    )
    return img, target_pose, target_cond, keypts, gr.update(), True
# def get_mask_inpaint(ref):
# # inpaint_mask = np.zeros_like(img_original[:, :, 0])
# # cropped_mask = np.array(ref["layers"][0])[..., -1]
# # inpaint_mask[crop_coord[0][1]:crop_coord[1][1], crop_coord[0][0]:crop_coord[1][0]] = cropped_mask
# return inpaint_mask
def visualize_ref(ref):
    """Extract the user-painted inpainting mask and build a preview of it.

    Args:
        ref: ImageEditor value (dict with "background" and "layers"), or None.

    Returns:
        (preview, inpaint_mask):
        - preview: the full-resolution background, with every pixel the user
          did NOT paint darkened.
        - inpaint_mask: uint8 0/1 array at `opts.image_size`; 1 where the
          alpha channel of the first layer is >= 128.
        Returns (None, None) for a missing value, so the caller always gets one
        value per output (previously a single None was returned).
    """
    if ref is None:
        return None, None
    # Alpha channel of the first editor layer marks where the user painted.
    alpha = np.array(ref["layers"][0])[..., -1]
    inpaint_mask = cv2.resize(alpha, opts.image_size, interpolation=cv2.INTER_AREA)
    inpaint_mask = (inpaint_mask >= 128).astype(np.uint8)
    # Full-resolution preview: dim everything outside the painted region.
    unpainted = alpha < 128
    img = ref["background"][..., :3]
    img = mask_image(img, unpainted)
    return img, inpaint_mask
def get_kps(img, keypoints, side: Literal["right", "left"], evt: gr.SelectData):
    """Record a clicked keypoint for one hand and redraw that hand's skeleton.

    Args:
        img: clean image to draw on.
        keypoints: ``[right_points, left_points]`` lists of (x, y) clicks, or
            None to start fresh. Updated in place.
        side: which hand the click belongs to.
        evt: Gradio select event; ``evt.index`` is the clicked (x, y).

    Returns:
        (visualization, keypoints). When the hand already has 21 points, the
        click is ignored and the existing skeleton is redrawn.

    Raises:
        ValueError: if `side` is neither "right" nor "left".
    """
    if keypoints is None:
        keypoints = [[], []]
    if side == "right":
        hand_idx, offset = 0, 0
    elif side == "left":
        hand_idx, offset = 1, 21
    else:
        raise ValueError(f"side must be 'right' or 'left', got {side!r}")
    points = keypoints[hand_idx]
    if len(points) == 21:
        gr.Info(f"21 keypoints for {side} hand already selected. Try reset if something looks wrong.")
    else:
        points.append(list(evt.index))
    # Computed after the if/else so that the "already full" path also has a
    # defined length and still redraws.
    len_kps = len(points)
    kps = np.zeros((42, 2))
    kps[offset : offset + len_kps] = np.array(points)
    vis_hand = visualize_hand(kps, img, side, len_kps)
    return vis_hand, keypoints
def undo_kps(img, keypoints, side: Literal["right", "left"]):
    """Remove the most recently clicked keypoint of one hand and redraw it.

    Args:
        img: clean image to draw on.
        keypoints: ``[right_points, left_points]`` lists of clicks, or None.
            Updated in place.
        side: which hand to undo.

    Returns:
        (visualization, keypoints). `img` itself is returned when there is
        nothing to undo, or when the undo removed the hand's last point.

    Raises:
        ValueError: if `side` is neither "right" nor "left".
    """
    if keypoints is None:
        return img, None
    if side == "right":
        hand_idx, offset = 0, 0
    elif side == "left":
        hand_idx, offset = 1, 21
    else:
        raise ValueError(f"side must be 'right' or 'left', got {side!r}")
    points = keypoints[hand_idx]
    if len(points) == 0:
        return img, keypoints
    points.pop()
    len_kps = len(points)
    if len_kps == 0:
        # No joints left: the clean image is exactly what an empty skeleton
        # looks like. This also avoids rendering with zero joints, which
        # previously crashed.
        return img, keypoints
    kps = np.zeros((42, 2))
    kps[offset : offset + len_kps] = np.array(points)
    vis_hand = visualize_hand(kps, img, side, len_kps)
    return vis_hand, keypoints
def reset_kps(img, keypoints, side: Literal["right", "left"]):
    """Clear all clicked keypoints of one hand (in place) and show the clean image.

    Returns ``(img, keypoints)``; a None `keypoints` is passed through untouched.
    """
    if keypoints is not None:
        slot = {"right": 0, "left": 1}.get(side)
        if slot is not None:
            keypoints[slot] = []
    return img, keypoints
def stay_crop(img, crop_coord):
    """Reset the crop to the whole image.

    The incoming `crop_coord` is ignored. Returns ``([[0, 0], [width, height]],
    copy of img)``, or ``(None, None)`` when there is no image.
    """
    if img is None:
        return None, None
    height, width = img.shape[0], img.shape[1]
    full_box = [[0, 0], [width, height]]
    return full_box, img.copy()
def process_crop(img, crop_coord, evt: gr.SelectData):
    """Two-click cropping: the first click sets the top-left corner, the second the bottom-right.

    Args:
        img: image being cropped.
        crop_coord: clicks collected so far.
            - Two entries ([[x1, y1], [x2, y2]]), None or empty: the click
              starts a new crop.
            - One entry: the click completes the crop.
        evt: Gradio select event; ``evt.index`` is the clicked (x, y).

    Returns:
        (crop_coord, cropped image). Until a valid second click arrives, the
        cropped image is a copy of the full image.

    Raises:
        gr.Error: if `crop_coord` is in an unexpected state.
    """
    if not crop_coord or len(crop_coord) == 2:
        crop_coord = [list(evt.index)]
        cropped = img.copy()
    elif len(crop_coord) == 1:
        new_coord = list(evt.index)
        if new_coord[0] <= crop_coord[0][0] or new_coord[1] <= crop_coord[0][1]:
            gr.Warning("Second click should be more under and more right than the first click. Try second click again.", duration=3)
            cropped = img.copy()
        else:
            crop_coord.append(new_coord)
            (x1, y1), (x2, y2) = crop_coord
            cropped = img[y1:y2, x1:x2].copy()
    else:
        # gr.Error is an exception and has to be raised to reach the user;
        # `cropped` would otherwise be unbound at the return below.
        raise gr.Error("Something is wrong", duration=3)
    return crop_coord, cropped
def disable_crop(crop_coord):
    """Lock the crop control once a complete two-click box has been selected."""
    crop_complete = len(crop_coord) == 2
    return gr.update(interactive=not crop_complete)
@spaces_60_fn
def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
    """Generate `num_gen` images of the reference hand in the target pose.

    Args:
        ref_cond: reference conditioning from `get_ref_anno`; assumed to have
            batch size 1, since it is repeated `num_gen` times.
        target_cond: target conditioning from `get_target_anno`.
        target_keypts: (42, 2) target keypoints, drawn over each sample.
        num_gen: number of samples to draw; only the first MAX_N are returned.
        seed: RNG seed (see `set_seed`).
        cfg: classifier-free guidance scale, forwarded to the model.

    Returns:
        (results, results_pose): two lists of MAX_N images. Slots beyond
        `num_gen` hold `placeholder`.
    """
    # Numeric UI widgets may deliver floats; tensor shapes need an int.
    num_gen = int(num_gen)
    set_seed(seed)
    z = torch.randn(
        (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
        device=device,
    )
    print(f"z.device: {z.device}")
    target_cond = target_cond.repeat(num_gen, 1, 1, 1).to(z.device)
    ref_cond = ref_cond.repeat(num_gen, 1, 1, 1).to(z.device)
    print(f"target_cond.max(): {target_cond.max()}, target_cond.min(): {target_cond.min()}")
    print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}")
    # novel view synthesis mode = off
    nvs = torch.zeros(num_gen, dtype=torch.int, device=device)
    # Batch doubled for classifier-free guidance: the first half gets the real
    # conditions, the second half zeroed conditions with nvs=2 (presumably the
    # model's "unconditional" code; confirm in vit.forward_with_cfg).
    z = torch.cat([z, z], 0)
    model_kwargs = dict(
        target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]),
        ref_cond=torch.cat([ref_cond, torch.zeros_like(ref_cond)]),
        nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
        cfg_scale=cfg,
    )
    # Keep only the first (guided) half of the doubled batch.
    samples, _ = diffusion.p_sample_loop(
        model.forward_with_cfg,
        z.shape,
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device=device,
    ).chunk(2)
    sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
    sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
    sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())
    results = []
    results_pose = []
    for i in range(MAX_N):
        if i < num_gen:
            results.append(sampled_images[i])
            results_pose.append(visualize_hand(target_keypts, sampled_images[i]))
        else:
            results.append(placeholder)
            results_pose.append(placeholder)
    print(f"results[0].max(): {results[0].max()}")
    return results, results_pose
@spaces_120_fn
def ready_sample(img_cropped, inpaint_mask, keypts):
    """Prepare the tensors needed by `sample_inpaint` for the hand-fixing flow.

    Args:
        img_cropped: ImageEditor value of the cropped image (dict with "background").
        inpaint_mask: (H, W) 0/1 mask at `opts.image_size`, 1 where the user
            painted (presumably the mask produced by `visualize_ref`).
        keypts: ``[right_points, left_points]`` lists of clicked (x, y) points
            in the pixel space of ``img_cropped["background"]``. Each list is
            empty or holds 21 points. The caller's lists are not modified.

    Returns:
        7-tuple ``(ref_cond, target_cond, latent, inpaint_latent_mask, keypts,
        vis_mask32, vis_mask256)``. On an invalid keypoint count, seven Nones
        are returned, so every output still receives a value (previously only
        two were returned).

    Raises:
        ValueError: if neither hand has keypoints.
    """
    # img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
    background = img_cropped["background"]
    img = cv2.resize(background[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
    sam_predictor.set_image(img)
    hands_kp = []
    for pts, name in ((keypts[0], "right"), (keypts[1], "left")):
        if len(pts) == 0:
            hands_kp.append(np.zeros((21, 2)))
        elif len(pts) == 21:
            hands_kp.append(np.array(pts, dtype=np.float32))
        else:
            gr.Info(f"Number of {name} hand keypoints should be either 0 or 21.")
            # sample_inpaint treats keypts=None as "nothing to sample".
            return (None,) * 7
    keypts = np.concatenate(hands_kp, axis=0)
    # Clicks live in the cropped image's pixel space; map them to opts.image_size.
    keypts = scale_keypoint(keypts, (background.shape[1], background.shape[0]), opts.image_size)
    box_shift_ratio = 0.5
    box_size_factor = 1.2
    if keypts[0].sum() != 0 and keypts[21].sum() != 0:
        input_point = np.array(keypts)
    elif keypts[0].sum() != 0:
        input_point = np.array(keypts[:21])
    elif keypts[21].sum() != 0:
        input_point = np.array(keypts[21:])
    else:
        raise ValueError(
            "Something wrong. If no hand detected, it should not reach here."
        )
    input_box = np.stack([input_point.min(axis=0), input_point.max(axis=0)])
    input_label = np.ones_like(input_point[:, 0]).astype(np.int32)
    # Enlarge the keypoint bounding box by 20% around its center for SAM.
    box_trans = input_box[0] * box_shift_ratio + input_box[1] * (1 - box_shift_ratio)
    input_box = ((input_box - box_trans) * box_size_factor + box_trans).reshape(-1)
    masks, _, _ = sam_predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box[None, :],
        multimask_output=False,
    )
    hand_mask = masks[0]
    inpaint_latent_mask = torch.tensor(
        cv2.resize(
            inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
        ),
        dtype=torch.float,
        device=pre_device,
    ).unsqueeze(0)[None, ...]

    def make_ref_cond(
        img,
        keypts,
        hand_mask,
        device=device,
        target_size=(256, 256),
        latent_size=(32, 32),
    ):
        """Return (normalized image, keypoint heatmaps, latent-size mask), each with a batch dim."""
        image_transform = Compose(
            [
                ToTensor(),
                Resize(target_size),
                # Map [0, 1] to [-1, 1].
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        image = image_transform(img).to(device)
        kpts_valid = check_keypoints_validity(keypts, target_size)
        heatmaps = torch.tensor(
            keypoint_heatmap(
                scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
            )
            * kpts_valid[:, None, None],
            dtype=torch.float,
            device=device,
        )[None, ...]
        mask = torch.tensor(
            cv2.resize(
                hand_mask.astype(int),
                dsize=latent_size,
                interpolation=cv2.INTER_NEAREST,
            ),
            dtype=torch.float,
            device=device,
        ).unsqueeze(0)[None, ...]
        return image[None, ...], heatmaps, mask

    # Hand mask restricted to the part of the hand the user did not paint over.
    image, heatmaps, mask = make_ref_cond(
        img,
        keypts,
        hand_mask * (1 - inpaint_mask),
        device=pre_device,
        target_size=opts.image_size,
        latent_size=opts.latent_size,
    )
    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
    target_cond = torch.cat([heatmaps, torch.zeros_like(mask)], 1)
    ref_cond = torch.cat([latent, heatmaps, mask], 1)
    # NOTE(review): the reference condition is built and then zeroed, so no
    # reference conditioning reaches the model here. Presumably deliberate for
    # inpainting; confirm.
    ref_cond = torch.zeros_like(ref_cond)
    img32 = cv2.resize(img, opts.latent_size, interpolation=cv2.INTER_NEAREST)
    assert mask.max() == 1
    vis_mask32 = mask_image(
        img32, inpaint_latent_mask[0, 0].cpu().numpy(), (255, 255, 255), transparent=False
    ).astype(np.uint8)  # 1.0 - mask[0, 0].cpu().numpy()
    assert np.unique(inpaint_mask).shape[0] <= 2
    assert hand_mask.dtype == bool
    mask256 = inpaint_mask  # hand_mask * (1 - inpaint_mask)
    vis_mask256 = mask_image(img, mask256, (255, 255, 255), transparent=False).astype(
        np.uint8
    )  # 1 - mask256
    return (
        ref_cond,
        target_cond,
        latent,
        inpaint_latent_mask,
        keypts,
        vis_mask32,
        vis_mask256,
    )
def switch_mask_size(radio):
    """Toggle which inpainting-mask preview is visible.

    Returns two visibility updates: the first is shown for
    "latent size (32x32)", the second for "256x256" (presumably the 32x32 and
    256x256 previews; confirm against the UI wiring).

    Raises:
        ValueError: for any other radio value. Previously this failed with an
            UnboundLocalError.
    """
    visibility = {
        "256x256": (False, True),
        "latent size (32x32)": (True, False),
    }
    if radio not in visibility:
        raise ValueError(f"Unknown mask size option: {radio!r}")
    show_first, show_second = visibility[radio]
    return gr.update(visible=show_first), gr.update(visible=show_second)
@spaces_300_fn
def sample_inpaint(
    ref_cond,
    target_cond,
    latent,
    inpaint_latent_mask,
    keypts,
    img_original,
    crop_coord,
    num_gen,
    seed,
    cfg,
    quality,
):
    """Regenerate the painted region of the cropped image and paste it back.

    Args:
        ref_cond, target_cond, latent, inpaint_latent_mask, keypts: outputs of
            `ready_sample`. A None `keypts` means preparation failed.
        img_original: full image that was cropped.
        crop_coord: ``[[x1, y1], [x2, y2]]`` crop box within `img_original`.
        num_gen: number of samples; only the first FIX_MAX_N are returned.
        seed: RNG seed (see `set_seed`).
        cfg: classifier-free guidance scale.
        quality: forwarded as `jump_n_sample` to the inpainting sampler.

    Returns:
        (results, results_pose, results_original): three lists of FIX_MAX_N
        images: crop-sized samples, samples with the pose drawn, and samples
        pasted back into `img_original`. Unused slots hold `placeholder`.
        (None, None, None) if `keypts` is None.
    """
    if keypts is None:
        return None, None, None
    # Numeric UI widgets may deliver floats; tensor shapes need an int.
    num_gen = int(num_gen)
    set_seed(seed)
    N = num_gen
    jump_length = 10
    jump_n_sample = quality
    cfg_scale = cfg
    z = torch.randn(
        (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device=device
    )
    target_cond_N = target_cond.repeat(N, 1, 1, 1).to(z.device)
    ref_cond_N = ref_cond.repeat(N, 1, 1, 1).to(z.device)
    # novel view synthesis mode = off
    nvs = torch.zeros(N, dtype=torch.int, device=device)
    # Doubled batch for classifier-free guidance; the second half gets zeroed
    # conditions and nvs=2.
    z = torch.cat([z, z], 0)
    model_kwargs = dict(
        target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
        ref_cond=torch.cat([ref_cond_N, torch.zeros_like(ref_cond_N)]),
        nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
        cfg_scale=cfg_scale,
    )
    samples, _ = diffusion.inpaint_p_sample_loop(
        model.forward_with_cfg,
        z.shape,
        latent.to(z.device),
        inpaint_latent_mask.to(z.device),
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device=z.device,
        jump_length=jump_length,
        jump_n_sample=jump_n_sample,
    ).chunk(2)
    sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
    sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
    sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())
    # visualize
    (x1, y1), (x2, y2) = crop_coord[0], crop_coord[1]
    results = []
    results_pose = []
    results_original = []
    for i in range(FIX_MAX_N):
        if i < num_gen:
            res = sampled_images[i]
            results.append(res)
            results_pose.append(visualize_hand(keypts, res))
            # Resize the sample back to the crop size and paste it into the original.
            res = cv2.resize(res, (x2 - x1, y2 - y1))
            res_original = img_original.copy()
            res_original[y1:y2, x1:x2, :] = res
            results_original.append(res_original)
        else:
            results.append(placeholder)
            results_pose.append(placeholder)
            results_original.append(placeholder)
    return results, results_pose, results_original
def flip_hand(
    img, img_raw, pose_img, pose_manual_img,
    manual_kp_right, manual_kp_left,
    cond, auto_cond, manual_cond,
    keypts=None, auto_keypts=None, manual_keypts=None
):
    """Mirror the reference image and all derived annotations horizontally.

    `img` (an ImageEditor dict) and the keypoint arrays are modified in place.
    The images are flipped along axis 1, the conditioning tensors along their
    last dimension, and each keypoint array has x mirrored per hand (rows 0-20
    and 21-) when that hand is present (non-zero).

    Returns:
        The same 12 values in argument order. When `cond` is None (after a
        clear), the inputs come back unchanged. A bare `return` used to yield
        a single None instead of 12 values.
    """
    if cond is None:  # clear clicked
        return (
            img, img_raw, pose_img, pose_manual_img, manual_kp_right, manual_kp_left,
            cond, auto_cond, manual_cond, keypts, auto_keypts, manual_keypts,
        )

    def _flip_image(arr):
        # Horizontal flip of an (H, W, C) array; None passes through.
        return None if arr is None else arr[:, ::-1, :]

    def _flip_tensor(t):
        return None if t is None else t.flip(-1)

    def _flip_keypoints(kp):
        # Mirror x about the image width for each hand that is annotated.
        if kp is None:
            return None
        for rows in (slice(0, 21), slice(21, None)):
            if kp[rows, :].sum() != 0:
                kp[rows, 0] = opts.image_size[1] - kp[rows, 0]
        return kp

    img["composite"] = img["composite"][:, ::-1, :]
    img["background"] = img["background"][:, ::-1, :]
    img["layers"] = [layer[:, ::-1, :] for layer in img["layers"]]
    img_raw = _flip_image(img_raw)
    pose_img = _flip_image(pose_img)
    pose_manual_img = _flip_image(pose_manual_img)
    manual_kp_right = _flip_image(manual_kp_right)
    manual_kp_left = _flip_image(manual_kp_left)
    cond = _flip_tensor(cond)
    auto_cond = _flip_tensor(auto_cond)
    manual_cond = _flip_tensor(manual_cond)
    keypts = _flip_keypoints(keypts)
    auto_keypts = _flip_keypoints(auto_keypts)
    manual_keypts = _flip_keypoints(manual_keypts)
    return img, img_raw, pose_img, pose_manual_img, manual_kp_right, manual_kp_left, cond, auto_cond, manual_cond, keypts, auto_keypts, manual_keypts
def resize_to_full(img):
    """Resize every image of an ImageEditor value to LENGTH x LENGTH.

    The dict is updated in place and also returned.
    """
    full_size = (LENGTH, LENGTH)
    for key in ("background", "composite"):
        img[key] = cv2.resize(img[key], full_size)
    img["layers"] = [cv2.resize(layer, full_size) for layer in img["layers"]]
    return img
def clear_all():
    """Return the 25 reset values used when everything is cleared.

    The first 20 values are two identical 10-value groups. They are followed
    by None, the defaults 1, 42 and 3.0 (presumably num_gen, seed and cfg),
    and an update that disables a control. Fresh lists are built on every call,
    so no mutable state is shared between resets.
    """
    def _group():
        # None, an empty list, six Nones, False, None
        return [None, []] + [None] * 6 + [False, None]

    trailing = [None, 1, 42, 3.0, gr.update(interactive=False)]
    return tuple(_group() + _group() + trailing)
def fix_clear_all():
    """Return the 24 reset values used when the hand-fixing inputs are cleared.

    Layout: three Nones, a fresh empty list, sixteen Nones, then the defaults
    1, 42, 3.0 and 10. A former (0, 0) value between 1 and 42 is no longer
    emitted.
    """
    leading = (None, None, None, [])
    cleared = (None,) * 16
    defaults = (1, 42, 3.0, 10)
    return leading + cleared + defaults
def enable_component(image1, image2):
    """Enable a control only when both inputs hold actual image content.

    An input counts as empty when it is None, or when it is an ImageEditor
    dict (keys "background", "layers", "composite") whose background is None,
    or whose background, layers and composite all sum to zero. Other values
    count as non-empty.

    Returns:
        gr.update(interactive=...) reflecting that check.
    """
    def _is_blank_editor(value):
        if not (
            isinstance(value, dict)
            and "background" in value
            and "layers" in value
            and "composite" in value
        ):
            return False
        if value["background"] is None:
            return True
        return (
            value["background"].sum() == 0
            and sum(layer.sum() for layer in value["layers"]) == 0
            and value["composite"].sum() == 0
        )

    if image1 is None or image2 is None:
        return gr.update(interactive=False)
    # Each input is now checked on its own type. The second check used to test
    # isinstance(image1, dict) while inspecting image2.
    if _is_blank_editor(image1) or _is_blank_editor(image2):
        return gr.update(interactive=False)
    return gr.update(interactive=True)
def set_visible(checkbox, kpts, img_clean, img_pose_right, img_pose_left, done=None, done_info=None):
    """Show or hide the per-hand annotation widgets based on the hand checkboxes.

    The keypoints of a deselected hand are cleared in place. Returns a tuple:
    - the keypoints;
    - the right and left previews;
    - the right-hand update three times, then the left-hand update three times;
    - separate right and left info updates;
    - when `done` is given, two more updates that are visible iff any hand is
      checked.
    `done_info` is unused.
    """
    if kpts is None:
        kpts = [[], []]
    show_right = "Right hand" in checkbox
    show_left = "Left hand" in checkbox
    if not show_right:
        kpts[0] = []
    if not show_left:
        kpts[1] = []
    vis_right = img_pose_right if show_right else img_clean
    vis_left = img_pose_left if show_left else img_clean
    update_right = gr.update(visible=show_right)
    update_left = gr.update(visible=show_left)
    update_r_info = gr.update(visible=show_right)
    update_l_info = gr.update(visible=show_left)
    ret = [kpts, vis_right, vis_left]
    ret.extend([update_right] * 3)
    ret.extend([update_left] * 3)
    ret.extend([update_r_info, update_l_info])
    if done is not None:
        any_checked = bool(checkbox)
        ret.extend(gr.update(visible=any_checked) for _ in range(2))
    return tuple(ret)
def set_unvisible():
    """Hide 24 components at once; each slot gets its own gr.update(visible=False)."""
    return tuple(gr.update(visible=False) for _ in range(24))
def fix_set_unvisible():
    """Hide 8 components at once; each slot gets its own gr.update(visible=False)."""
    return tuple(gr.update(visible=False) for _ in range(8))
def visible_component(decider, component):
    """Show a component exactly when `decider` is not None (`component` is unused)."""
    return gr.update(visible=decider is not None)
def unvisible_component(decider, component):
    """Return an update hiding a component iff ``decider`` is not None.

    ``component`` is accepted for the event signature but not read.
    """
    return gr.update(visible=decider is None)
# Example rows for the "Reference" gr.Examples widget (one image per row).
# sample5, sample7-sample11 and pose_images/pose1-pose8 were previously
# listed here but are disabled.
example_ref_imgs = [
    [f"sample_images/sample{idx}.jpg"] for idx in (1, 2, 3, 4, 6)
]
# Example rows for the "Target" gr.Examples widget (one image per row).
# sample1-sample4, sample6-sample8 and pose_images/pose2-pose8 were
# previously listed here but are disabled.
example_target_imgs = [
    [f"sample_images/sample{idx}.jpg"] for idx in (5, 9, 10, 11)
] + [["pose_images/pose1.jpg"]]
# Example rows for the malformed-hand gr.Examples widget (one image per row).
# Images 2, 4 and 7-13 and the "<n>_mask.jpg" companions were previously
# listed here but are disabled.
fix_example_imgs = [[f"bad_hands/{idx}.jpg"] for idx in (1, 3, 5, 6, 14, 15)]
# Stylesheet passed to gr.Blocks(css=...): fixes the example thumbnail size to
# 240x240 and colors/enlarges the elements with ids "fix-tab-button" and
# "repose-tab-button" (presumably the tab headers Gradio derives from the
# tabs' elem_id "fix-tab" / "repose-tab" — verify against the Gradio version).
custom_css = """
.gradio-container .examples img {
width: 240px !important;
height: 240px !important;
}
#fix-tab-button {
font-size: 18px !important;
font-weight: bold !important;
background-color: #FFDAB9 !important;
}
#repose-tab-button {
font-size: 18px !important;
font-weight: bold !important;
background-color: #90EE90 !important;
}
"""
# Disabled CSS declaration kept for reference:
# color: black !important;
# HTML banner rendered at the top of the page via gr.Markdown(_HEADER_):
# title, venue, authors, affiliations, links and a short description of the
# two demos.
# NOTE(review): the "Code" and "Model Weights" anchors have empty hrefs.
_HEADER_ = '''
<div style="text-align: center;">
<h1><b>FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation</b></h1>
<h2 style="color: #777777;">CVPR 2025 <span style="color: #990000; font-style: italic;">highlight</span></h2>
<style>
.link-spacing {
margin-right: 20px;
}
</style>
<p style="font-size: 15px;">
<a href="https://arthurchen0518.github.io/" class="link-spacing">Kefan Chen<sup>1,2*</sup></a>
<a href="https://chaerinmin.github.io/" class="link-spacing">Chaerin Min<sup>1*</sup></a>
<a href="https://lg-zhang.github.io/" class="link-spacing">Linguang Zhang<sup>2</sup></a>
<a href="https://shreyashampali.github.io/" class="link-spacing">Shreyas Hampali<sup>2</sup></a>
<a href="https://scholar.google.co.uk/citations?user=9HoiYnYAAAAJ&hl=en" class="link-spacing">Cem Keskin<sup>2</sup></a>
<a href="https://cs.brown.edu/people/ssrinath/" class="link-spacing">Srinath Sridhar<sup>1</sup></a>
</p>
<p style="font-size: 15px;">
<span style="display: inline-block; margin-right: 30px;"><sup>1</sup>Brown University</span>
<span style="display: inline-block;"><sup>2</sup>Meta Reality Labs</span>
</p>
<h3>
<a href='https://arxiv.org/abs/2412.02690' target='_blank' class="link-spacing">Paper</a>
<a href='https://ivl.cs.brown.edu/research/foundhand.html' target='_blank' class="link-spacing">Project Page</a>
<a href='' target='_blank' class="link-spacing">Code</a>
<a href='' target='_blank'>Model Weights</a>
</h3>
<p>Below are two important abilities of our model. First, we can automatically <b>fix malformed hand images</b>, following the user-provided target hand pose and area to fix. Second, we can <b>repose hand</b> given two hand images - one is the image to edit, and the other one provides target hand pose.</p>
</div>
'''
# BibTeX entry rendered (as preformatted text) under the "Citation" heading
# at the bottom of the page via gr.Markdown(_CITE_).
_CITE_ = r"""
<pre style="white-space: pre-wrap; margin: 0;">
@article{chen2024foundhand,
title={FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation},
author={Chen, Kefan and Min, Chaerin and Zhang, Linguang and Hampali, Shreyas and Keskin, Cem and Sridhar, Srinath},
journal={arXiv preprint arXiv:2412.02690},
year={2024}
}
</pre>
"""
# Acknowledgement text rendered (as preformatted text) under the
# "Acknowledgement" heading via gr.Markdown(_ACK_).
_ACK_ = r"""
<pre style="white-space: pre-wrap; margin: 0;">
Part of this work was done during Kefan (Arthur) Chen’s internship at Meta Reality Lab. This work was additionally supported by NSF CAREER grant #2143576, NASA grant #80NSSC23M0075, and an Amazon Cloud Credits Award.
</pre>
"""
# Gradio UI. Two tabs, "Demo 1. Malformed Hand Correction" and
# "Demo 2. Repose Hands". Each tab declares its per-session gr.State holders,
# its layout, and the listeners that wire the callbacks defined earlier in
# this file. Acknowledgement, troubleshooting and citation sections follow.
with gr.Blocks(css=custom_css, theme="soft") as demo:
    gr.Markdown(_HEADER_)
    with gr.Tab("Demo 1. Malformed Hand Correction", elem_id="fix-tab"):
        # Per-session state of the correction pipeline.
        fix_inpaint_mask = gr.State(value=None)
        fix_original = gr.State(value=None)
        fix_crop_coord = gr.State(value=None)
        fix_img = gr.State(value=None)
        fix_kpts = gr.State(value=None)
        fix_kpts_np = gr.State(value=None)
        fix_ref_cond = gr.State(value=None)
        fix_target_cond = gr.State(value=None)
        fix_latent = gr.State(value=None)
        fix_inpaint_latent = gr.State(value=None)
        with gr.Row():
            # crop & brush
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 18px; font-weight: bold;">1. Upload a malformed hand image 📥</p>"""
                )
                gr.Markdown(
                    """<p style="text-align: center;">Optionally crop the image.<br>(Click <b>top left</b> and <b>bottom right</b> of your desired bounding box around the hand)</p>"""
                )
                fix_crop = gr.Image(
                    type="numpy",
                    sources=["upload", "webcam", "clipboard"],
                    label="Input Image",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    interactive=True,
                    visible=True,
                )
                gr.Markdown(
                    """<p style="text-align: center;">💡 If you crop, the model can focus on more details of the cropped area. Square crops might work better than rectangle crops.</p>"""
                )
                fix_example = gr.Examples(
                    fix_example_imgs,
                    inputs=[fix_crop],
                    examples_per_page=20,
                )
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 18px; font-weight: bold;">2. Brush wrong finger and its surrounding area</p>"""
                )
                gr.Markdown(
                    """<p style="text-align: center;">Don't brush the entire hand!</p>"""
                )
                fix_ref = gr.ImageEditor(
                    type="numpy",
                    label="Image Brushing",
                    sources=(),
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    layers=False,
                    # Cropping happens by clicking on `fix_crop`, so the editor
                    # gets no transforms. The previous `("brush")` was just the
                    # string "brush" (not a tuple) and not a valid transform.
                    transforms=(),
                    brush=gr.Brush(colors=["rgb(255, 255, 255)"], default_size=20),
                    image_mode="RGBA",
                    container=False,
                    interactive=False,
                )
                fix_finish_crop = gr.Button(
                    value="Finish Cropping & Brushing", interactive=False
                )
            # keypoint selection
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 18px; font-weight: bold;">3. Click on hand to get target hand pose</p>"""
                )
                gr.Markdown(
                    """<p style="text-align: center;">&#9312; Tell us if this is right, left, or both hands</p>"""
                )
                fix_checkbox = gr.CheckboxGroup(
                    ["Right hand", "Left hand"],
                    show_label=False,
                    interactive=False,
                )
                fix_kp_r_info = gr.Markdown(
                    """<p style="text-align: center;">&#9313; Click 21 keypoints on the image to provide the target hand pose of <b>right hand</b>. See the \"OpenPose keypoints convention\" for guidance.</p>""",
                    visible=False
                )
                fix_kp_right = gr.Image(
                    type="numpy",
                    label="Keypoint Selection (right hand)",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    interactive=False,
                    visible=False,
                    sources=[],
                )
                with gr.Row():
                    fix_undo_right = gr.Button(
                        value="Undo", interactive=False, visible=False
                    )
                    fix_reset_right = gr.Button(
                        value="Reset", interactive=False, visible=False
                    )
                fix_kp_l_info = gr.Markdown(
                    """<p style="text-align: center;">&#9313; Click 21 keypoints on the image to provide the target hand pose of <b>left hand</b>. See the \"OpenPose keypoints convention\" for guidance.</p>""",
                    visible=False
                )
                fix_kp_left = gr.Image(
                    type="numpy",
                    label="Keypoint Selection (left hand)",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    interactive=False,
                    visible=False,
                    sources=[],
                )
                with gr.Row():
                    fix_undo_left = gr.Button(
                        value="Undo", interactive=False, visible=False
                    )
                    fix_reset_left = gr.Button(
                        value="Reset", interactive=False, visible=False
                    )
                gr.Markdown(
                    """<p style="text-align: left; font-weight: bold; ">OpenPose keypoints convention</p>"""
                )
                fix_openpose = gr.Image(
                    value="openpose.png",
                    type="numpy",
                    show_label=False,
                    height=LENGTH // 2,
                    width=LENGTH // 2,
                    interactive=False,
                )
            # result column
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 18px; font-weight: bold;">4. Press &quot;Run&quot; to get the corrected hand image 🎯</p>"""
                )
                fix_vis_mask32 = gr.Image(
                    type="numpy",
                    label=f"Visualized {opts.latent_size} Inpaint Mask",
                    show_label=True,
                    height=opts.latent_size,
                    width=opts.latent_size,
                    interactive=False,
                    visible=False,
                )
                fix_run = gr.Button(value="Run", interactive=False)
                with gr.Accordion(label="Visualized (256, 256) resized, brushed image", open=False):
                    fix_vis_mask256 = gr.Image(
                        type="numpy",
                        show_label=False,
                        height=opts.image_size,
                        width=opts.image_size,
                        interactive=False,
                        visible=True,
                    )
                gr.Markdown(
                    """<p style="text-align: center;">⚠️ >3min and ~24GB per generation</p>"""
                )
                fix_result_original = gr.Gallery(
                    type="numpy",
                    label="Results on original input",
                    show_label=True,
                    height=LENGTH,
                    min_width=LENGTH,
                    columns=FIX_MAX_N,
                    interactive=False,
                    preview=True,
                )
                with gr.Accordion(label="Results of cropped area / Results with pose", open=False):
                    fix_result = gr.Gallery(
                        type="numpy",
                        label="Results",
                        show_label=True,
                        height=LENGTH,
                        min_width=LENGTH,
                        columns=FIX_MAX_N,
                        interactive=False,
                        preview=True,
                    )
                    fix_result_pose = gr.Gallery(
                        type="numpy",
                        label="Results Pose",
                        show_label=True,
                        height=LENGTH,
                        min_width=LENGTH,
                        columns=FIX_MAX_N,
                        interactive=False,
                        preview=True,
                    )
                gr.Markdown(
                    """<p style="text-align: center;">✨ Hit &quot;Clear&quot; to restart from the beginning</p>"""
                )
                fix_clear = gr.ClearButton()
                with gr.Accordion(label="More options", open=False):
                    gr.Markdown(
                        "⚠️ Currently, Number of generation > 1 could lead to out-of-memory"
                    )
                    with gr.Row():
                        fix_n_generation = gr.Slider(
                            label="Number of generations",
                            value=1,
                            minimum=1,
                            maximum=FIX_MAX_N,
                            step=1,
                            randomize=False,
                            interactive=True,
                        )
                        fix_seed = gr.Slider(
                            label="Seed",
                            value=42,
                            minimum=0,
                            maximum=10000,
                            step=1,
                            randomize=False,
                            interactive=True,
                        )
                        fix_cfg = gr.Slider(
                            label="Classifier free guidance scale",
                            value=3.0,
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            randomize=False,
                            interactive=True,
                        )
                        fix_quality = gr.Slider(
                            label="Quality",
                            value=10,
                            minimum=1,
                            maximum=10,
                            step=1,
                            randomize=False,
                            interactive=True,
                        )
        # listeners
        fix_crop.change(lambda x: x, fix_crop, fix_original)  # fix_original: (real_H, real_W, 3)
        fix_crop.change(stay_crop, [fix_crop, fix_crop_coord], [fix_crop_coord, fix_ref])
        fix_crop.select(process_crop, [fix_crop, fix_crop_coord], [fix_crop_coord, fix_ref])
        fix_ref.change(enable_component, [fix_crop, fix_crop], fix_ref)
        fix_ref.change(enable_component, [fix_crop, fix_crop], fix_finish_crop)
        fix_finish_crop.click(visualize_ref, [fix_ref], [fix_img, fix_inpaint_mask])
        fix_img.change(lambda x: x, [fix_img], [fix_kp_right])
        fix_img.change(lambda x: x, [fix_img], [fix_kp_left])
        # Once an inpaint mask exists, unlock the keypoint and run widgets.
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_checkbox
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_right
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_right
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_right
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_left
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_left
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_left
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_run
        )
        fix_checkbox.select(
            set_visible,
            [fix_checkbox, fix_kpts, fix_img, fix_kp_right, fix_kp_left],
            [
                fix_kpts,
                fix_kp_right,
                fix_kp_left,
                fix_kp_right,
                fix_undo_right,
                fix_reset_right,
                fix_kp_left,
                fix_undo_left,
                fix_reset_left,
                fix_kp_r_info,
                fix_kp_l_info,
            ],
        )
        fix_kp_right.select(
            get_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]  # fix_img: (real_cropped_H, real_cropped_W, 3)
        )
        fix_undo_right.click(
            undo_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
        )
        fix_reset_right.click(
            reset_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
        )
        fix_kp_left.select(
            get_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
        )
        fix_undo_left.click(
            undo_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
        )
        fix_reset_left.click(
            reset_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
        )
        fix_run.click(
            ready_sample,
            [fix_ref, fix_inpaint_mask, fix_kpts],
            [
                fix_ref_cond,
                fix_target_cond,
                fix_latent,
                fix_inpaint_latent,
                fix_kpts_np,
                fix_vis_mask32,
                fix_vis_mask256,
            ],
        )
        # Sampling is chained off fix_kpts_np, which ready_sample populates.
        fix_kpts_np.change(
            sample_inpaint,
            [
                fix_ref_cond,
                fix_target_cond,
                fix_latent,
                fix_inpaint_latent,
                fix_kpts_np,
                fix_original,
                fix_crop_coord,
                fix_n_generation,
                fix_seed,
                fix_cfg,
                fix_quality,
            ],
            [fix_result, fix_result_pose, fix_result_original],
        )
        fix_clear.click(
            fix_clear_all,
            [],
            [
                fix_crop,
                fix_crop_coord,
                fix_ref,
                fix_checkbox,
                fix_kp_right,
                fix_kp_left,
                fix_result,
                fix_result_pose,
                fix_result_original,
                fix_inpaint_mask,
                fix_original,
                fix_img,
                fix_vis_mask32,
                fix_vis_mask256,
                fix_kpts,
                fix_kpts_np,
                fix_ref_cond,
                fix_target_cond,
                fix_latent,
                fix_inpaint_latent,
                fix_n_generation,
                fix_seed,
                fix_cfg,
                fix_quality,
            ],
        )
        fix_clear.click(
            fix_set_unvisible,
            [],
            [
                fix_kp_right,
                fix_kp_left,
                fix_kp_r_info,
                fix_kp_l_info,
                fix_undo_left,
                fix_undo_right,
                fix_reset_left,
                fix_reset_right
            ]
        )
    with gr.Tab("Demo 2. Repose Hands", elem_id="repose-tab"):
        dump = gr.State(value=None)
        # ref states
        ref_img = gr.State(value=None)
        ref_im_raw = gr.State(value=None)
        ref_kp_raw = gr.State(value=0)
        ref_is_user = gr.State(value=True)
        ref_kp_got = gr.State(value=None)
        ref_manual_cond = gr.State(value=None)
        ref_auto_cond = gr.State(value=None)
        ref_cond = gr.State(value=None)
        # target states
        target_img = gr.State(value=None)
        target_im_raw = gr.State(value=None)
        target_kp_raw = gr.State(value=0)
        target_is_user = gr.State(value=True)
        target_kp_got = gr.State(value=None)
        target_manual_keypts = gr.State(value=None)
        target_auto_keypts = gr.State(value=None)
        target_keypts = gr.State(value=None)
        target_manual_cond = gr.State(value=None)
        target_auto_cond = gr.State(value=None)
        target_cond = gr.State(value=None)
        # main tab
        with gr.Row():
            # ref column
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 18px; font-weight: bold;">1. Upload a hand image to repose 📥</p>"""
                )
                gr.Markdown(
                    """<p style="text-align: center;">Optionally crop the image</p>"""
                )
                ref = gr.ImageEditor(
                    type="numpy",
                    label="Reference",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    brush=False,
                    layers=False,
                    crop_size="1:1",
                )
                gr.Examples(example_ref_imgs, [ref], examples_per_page=20)
                with gr.Accordion(label="See hand pose and more options", open=False):
                    with gr.Tab("Automatic hand keypoints"):
                        ref_pose = gr.Image(
                            type="numpy",
                            label="Reference Pose",
                            show_label=True,
                            height=LENGTH,
                            width=LENGTH,
                            interactive=False,
                        )
                        ref_use_auto = gr.Button(value="Click here to use automatic, not manual", interactive=False, visible=True)
                    with gr.Tab("Manual hand keypoints"):
                        ref_manual_checkbox_info = gr.Markdown(
                            """<p style="text-align: center;"><b>Step 1.</b> Tell us if this is right, left, or both hands.</p>""",
                            visible=True,
                        )
                        ref_manual_checkbox = gr.CheckboxGroup(
                            ["Right hand", "Left hand"],
                            show_label=False,
                            visible=True,
                            interactive=True,
                        )
                        ref_manual_kp_r_info = gr.Markdown(
                            """<p style="text-align: center;"><b>Step 2.</b> Click on image to provide hand keypoints for <b>right</b> hand. See \"OpenPose Keypoint Convention\" for guidance.</p>""",
                            visible=False,
                        )
                        ref_manual_kp_right = gr.Image(
                            type="numpy",
                            label="Keypoint Selection (right hand)",
                            show_label=True,
                            height=LENGTH,
                            width=LENGTH,
                            interactive=False,
                            visible=False,
                            sources=[],
                        )
                        with gr.Row():
                            ref_manual_undo_right = gr.Button(
                                value="Undo", interactive=True, visible=False
                            )
                            ref_manual_reset_right = gr.Button(
                                value="Reset", interactive=True, visible=False
                            )
                        ref_manual_kp_l_info = gr.Markdown(
                            """<p style="text-align: center;"><b>Step 2.</b> Click on image to provide hand keypoints for <b>left</b> hand. See \"OpenPose keypoint convention\" for guidance.</p>""",
                            visible=False
                        )
                        ref_manual_kp_left = gr.Image(
                            type="numpy",
                            label="Keypoint Selection (left hand)",
                            show_label=True,
                            height=LENGTH,
                            width=LENGTH,
                            interactive=False,
                            visible=False,
                            sources=[],
                        )
                        with gr.Row():
                            ref_manual_undo_left = gr.Button(
                                value="Undo", interactive=True, visible=False
                            )
                            ref_manual_reset_left = gr.Button(
                                value="Reset", interactive=True, visible=False
                            )
                        ref_manual_done_info = gr.Markdown(
                            """<p style="text-align: center;"><b>Step 3.</b> Hit \"Done\" button to confirm.</p>""",
                            visible=False,
                        )
                        ref_manual_done = gr.Button(value="Done", interactive=True, visible=False)
                        ref_manual_pose = gr.Image(
                            type="numpy",
                            label="Reference Pose",
                            show_label=True,
                            height=LENGTH,
                            width=LENGTH,
                            interactive=False,
                            visible=False
                        )
                        ref_use_manual = gr.Button(value="Click here to use manual, not automatic", interactive=True, visible=False)
                        ref_manual_instruct = gr.Markdown(
                            value="""<p style="text-align: left; font-weight: bold; ">OpenPose Keypoints Convention</p>""",
                            visible=True
                        )
                        ref_manual_openpose = gr.Image(
                            value="openpose.png",
                            type="numpy",
                            show_label=False,
                            height=LENGTH // 2,
                            width=LENGTH // 2,
                            interactive=False,
                            visible=True
                        )
                    gr.Markdown(
                        """<p style="text-align: center;">Optionally flip the hand</p>"""
                    )
                    ref_flip = gr.Checkbox(
                        value=False, label="Flip Handedness (Reference)", interactive=False
                    )
            # target column
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 18px; font-weight: bold;">2. Upload a hand image for target hand pose 📥</p>"""
                )
                gr.Markdown(
                    """<p style="text-align: center;">Optionally crop the image</p>"""
                )
                target = gr.ImageEditor(
                    type="numpy",
                    label="Target",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    brush=False,
                    layers=False,
                    crop_size="1:1",
                )
                gr.Examples(example_target_imgs, [target], examples_per_page=20)
                with gr.Accordion(label="See hand pose and more options", open=False):
                    with gr.Tab("Automatic hand keypoints"):
                        target_pose = gr.Image(
                            type="numpy",
                            label="Target Pose",
                            show_label=True,
                            height=LENGTH,
                            width=LENGTH,
                            interactive=False,
                        )
                        target_use_auto = gr.Button(value="Click here to use automatic, not manual", interactive=False, visible=True)
                    with gr.Tab("Manual hand keypoints"):
                        target_manual_checkbox_info = gr.Markdown(
                            """<p style="text-align: center;"><b>Step 1.</b> Tell us if this is right, left, or both hands.</p>""",
                            visible=True,
                        )
                        target_manual_checkbox = gr.CheckboxGroup(
                            ["Right hand", "Left hand"],
                            show_label=False,
                            visible=True,
                            interactive=True,
                        )
                        target_manual_kp_r_info = gr.Markdown(
                            """<p style="text-align: center;"><b>Step 2.</b> Click on image to provide hand keypoints for <b>right</b> hand. See \"OpenPose Keypoint Convention\" for guidance.</p>""",
                            visible=False,
                        )
                        target_manual_kp_right = gr.Image(
                            type="numpy",
                            label="Keypoint Selection (right hand)",
                            show_label=True,
                            height=LENGTH,
                            width=LENGTH,
                            interactive=False,
                            visible=False,
                            sources=[],
                        )
                        with gr.Row():
                            target_manual_undo_right = gr.Button(
                                value="Undo", interactive=True, visible=False
                            )
                            target_manual_reset_right = gr.Button(
                                value="Reset", interactive=True, visible=False
                            )
                        target_manual_kp_l_info = gr.Markdown(
                            """<p style="text-align: center;"><b>Step 2.</b> Click on image to provide hand keypoints for <b>left</b> hand. See \"OpenPose keypoint convention\" for guidance.</p>""",
                            visible=False
                        )
                        target_manual_kp_left = gr.Image(
                            type="numpy",
                            label="Keypoint Selection (left hand)",
                            show_label=True,
                            height=LENGTH,
                            width=LENGTH,
                            interactive=False,
                            visible=False,
                            sources=[],
                        )
                        with gr.Row():
                            target_manual_undo_left = gr.Button(
                                value="Undo", interactive=True, visible=False
                            )
                            target_manual_reset_left = gr.Button(
                                value="Reset", interactive=True, visible=False
                            )
                        target_manual_done_info = gr.Markdown(
                            """<p style="text-align: center;"><b>Step 3.</b> Hit \"Done\" button to confirm.</p>""",
                            visible=False,
                        )
                        target_manual_done = gr.Button(value="Done", interactive=True, visible=False)
                        target_manual_pose = gr.Image(
                            type="numpy",
                            label="Target Pose",
                            show_label=True,
                            height=LENGTH,
                            width=LENGTH,
                            interactive=False,
                            visible=False
                        )
                        target_use_manual = gr.Button(value="Click here to use manual, not automatic", interactive=True, visible=False)
                        target_manual_instruct = gr.Markdown(
                            value="""<p style="text-align: left; font-weight: bold; ">OpenPose Keypoints Convention</p>""",
                            visible=True
                        )
                        target_manual_openpose = gr.Image(
                            value="openpose.png",
                            type="numpy",
                            show_label=False,
                            height=LENGTH // 2,
                            width=LENGTH // 2,
                            interactive=False,
                            visible=True
                        )
                    gr.Markdown(
                        """<p style="text-align: center;">Optionally flip the hand</p>"""
                    )
                    target_flip = gr.Checkbox(
                        value=False, label="Flip Handedness (Target)", interactive=False
                    )
            # result column
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 18px; font-weight: bold;">3. Press &quot;Run&quot; to get the reposed results 🎯</p>"""
                )
                run = gr.Button(value="Run", interactive=False)
                gr.Markdown(
                    """<p style="text-align: center;">⚠️ ~20s per generation with RTX3090. ~50s with A100. <br>(For example, if you set Number of generations as 2, it would take around 40s)</p>"""
                )
                results = gr.Gallery(
                    type="numpy",
                    label="Results",
                    show_label=True,
                    height=LENGTH,
                    min_width=LENGTH,
                    columns=MAX_N,
                    interactive=False,
                    preview=True,
                )
                with gr.Accordion(label="Results with pose", open=False):
                    results_pose = gr.Gallery(
                        type="numpy",
                        label="Results Pose",
                        show_label=True,
                        height=LENGTH,
                        min_width=LENGTH,
                        columns=MAX_N,
                        interactive=False,
                        preview=True,
                    )
                gr.Markdown(
                    """<p style="text-align: center;">✨ Hit &quot;Clear&quot; to restart from the beginning</p>"""
                )
                clear = gr.ClearButton()
                with gr.Accordion(label="More options", open=False):
                    with gr.Row():
                        n_generation = gr.Slider(
                            label="Number of generations",
                            value=1,
                            minimum=1,
                            maximum=MAX_N,
                            step=1,
                            randomize=False,
                            interactive=True,
                        )
                        seed = gr.Slider(
                            label="Seed",
                            value=42,
                            minimum=0,
                            maximum=10000,
                            step=1,
                            randomize=False,
                            interactive=True,
                        )
                        cfg = gr.Slider(
                            label="Classifier free guidance scale",
                            value=2.5,
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            randomize=False,
                            interactive=True,
                        )
        # reference listeners
        ref.change(prepare_anno, [ref, ref_is_user], [ref_im_raw, ref_kp_raw])
        ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_right)
        ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_left)
        ref_kp_raw.change(get_ref_anno, [ref_im_raw, ref_kp_raw], [ref_img, ref_pose, ref_auto_cond, ref, ref_is_user])
        ref_pose.change(enable_component, [ref_kp_raw, ref_pose], ref_use_auto)
        ref_pose.change(enable_component, [ref_img, ref_pose], ref_flip)
        ref_auto_cond.change(lambda x: x, ref_auto_cond, ref_cond)
        ref_use_auto.click(lambda x: x, ref_auto_cond, ref_cond)
        # Toast handlers are registered without inputs, so Gradio calls them
        # with zero arguments; they must not declare a positional parameter.
        ref_use_auto.click(lambda: gr.Info("Automatic hand keypoints will be used for 'Reference'", duration=3))
        ref_manual_checkbox.select(
            set_visible,
            [ref_manual_checkbox, ref_kp_got, ref_im_raw, ref_manual_kp_right, ref_manual_kp_left, ref_manual_done],
            [
                ref_kp_got,
                ref_manual_kp_right,
                ref_manual_kp_left,
                ref_manual_kp_right,
                ref_manual_undo_right,
                ref_manual_reset_right,
                ref_manual_kp_left,
                ref_manual_undo_left,
                ref_manual_reset_left,
                ref_manual_kp_r_info,
                ref_manual_kp_l_info,
                ref_manual_done,
                ref_manual_done_info
            ]
        )
        ref_manual_kp_right.select(
            get_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got]
        )
        ref_manual_undo_right.click(
            undo_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got]
        )
        ref_manual_reset_right.click(
            reset_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got]
        )
        ref_manual_kp_left.select(
            get_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
        )
        ref_manual_undo_left.click(
            undo_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
        )
        ref_manual_reset_left.click(
            reset_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
        )
        ref_manual_done.click(visible_component, [gr.State(0), ref_manual_pose], ref_manual_pose)
        ref_manual_done.click(visible_component, [gr.State(0), ref_use_manual], ref_use_manual)
        ref_manual_done.click(get_ref_anno, [ref_im_raw, ref_kp_got], [ref_img, ref_manual_pose, ref_manual_cond])
        ref_manual_pose.change(enable_component, [ref_manual_pose, ref_manual_pose], ref_manual_done)
        ref_manual_pose.change(enable_component, [ref_img, ref_manual_pose], ref_flip)
        ref_manual_cond.change(lambda x: x, ref_manual_cond, ref_cond)
        ref_use_manual.click(lambda x: x, ref_manual_cond, ref_cond)
        ref_use_manual.click(lambda: gr.Info("Manual hand keypoints will be used for 'Reference'", duration=3))
        ref_flip.select(
            flip_hand,
            [ref, ref_im_raw, ref_pose, ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left, ref_cond, ref_auto_cond, ref_manual_cond],
            [ref, ref_im_raw, ref_pose, ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left, ref_cond, ref_auto_cond, ref_manual_cond]
        )
        # target listeners
        target.change(prepare_anno, [target, target_is_user], [target_im_raw, target_kp_raw])
        target_kp_raw.change(lambda x: x, target_im_raw, target_manual_kp_right)
        target_kp_raw.change(lambda x: x, target_im_raw, target_manual_kp_left)
        target_kp_raw.change(get_target_anno, [target_im_raw, target_kp_raw], [target_img, target_pose, target_auto_cond, target_auto_keypts, target, target_is_user])
        target_pose.change(enable_component, [target_kp_raw, target_pose], target_use_auto)
        target_pose.change(enable_component, [target_img, target_pose], target_flip)
        target_auto_cond.change(lambda x: x, target_auto_cond, target_cond)
        target_auto_keypts.change(lambda x: x, target_auto_keypts, target_keypts)
        target_use_auto.click(lambda x: x, target_auto_cond, target_cond)
        target_use_auto.click(lambda x: x, target_auto_keypts, target_keypts)
        target_use_auto.click(lambda: gr.Info("Automatic hand keypoints will be used for 'Target'", duration=3))
        target_manual_checkbox.select(
            set_visible,
            [target_manual_checkbox, target_kp_got, target_im_raw, target_manual_kp_right, target_manual_kp_left, target_manual_done],
            [
                target_kp_got,
                target_manual_kp_right,
                target_manual_kp_left,
                target_manual_kp_right,
                target_manual_undo_right,
                target_manual_reset_right,
                target_manual_kp_left,
                target_manual_undo_left,
                target_manual_reset_left,
                target_manual_kp_r_info,
                target_manual_kp_l_info,
                target_manual_done,
                target_manual_done_info
            ]
        )
        target_manual_kp_right.select(
            get_kps, [target_im_raw, target_kp_got, gr.State("right")], [target_manual_kp_right, target_kp_got]
        )
        target_manual_undo_right.click(
            undo_kps, [target_im_raw, target_kp_got, gr.State("right")], [target_manual_kp_right, target_kp_got]
        )
        target_manual_reset_right.click(
            reset_kps, [target_im_raw, target_kp_got, gr.State("right")], [target_manual_kp_right, target_kp_got]
        )
        target_manual_kp_left.select(
            get_kps, [target_im_raw, target_kp_got, gr.State("left")], [target_manual_kp_left, target_kp_got]
        )
        target_manual_undo_left.click(
            undo_kps, [target_im_raw, target_kp_got, gr.State("left")], [target_manual_kp_left, target_kp_got]
        )
        target_manual_reset_left.click(
            reset_kps, [target_im_raw, target_kp_got, gr.State("left")], [target_manual_kp_left, target_kp_got]
        )
        target_manual_done.click(visible_component, [gr.State(0), target_manual_pose], target_manual_pose)
        target_manual_done.click(visible_component, [gr.State(0), target_use_manual], target_use_manual)
        target_manual_done.click(get_target_anno, [target_im_raw, target_kp_got], [target_img, target_manual_pose, target_manual_cond, target_manual_keypts])
        target_manual_pose.change(enable_component, [target_manual_pose, target_manual_pose], target_manual_done)
        target_manual_pose.change(enable_component, [target_img, target_manual_pose], target_flip)
        target_manual_cond.change(lambda x: x, target_manual_cond, target_cond)
        target_manual_keypts.change(lambda x: x, target_manual_keypts, target_keypts)
        target_use_manual.click(lambda x: x, target_manual_cond, target_cond)
        target_use_manual.click(lambda x: x, target_manual_keypts, target_keypts)
        # This toast refers to the target image (was mislabelled 'Reference').
        target_use_manual.click(lambda: gr.Info("Manual hand keypoints will be used for 'Target'", duration=3))
        target_flip.select(
            flip_hand,
            [target, target_im_raw, target_pose, target_manual_pose, target_manual_kp_right, target_manual_kp_left, target_cond, target_auto_cond, target_manual_cond, target_keypts, target_auto_keypts, target_manual_keypts],
            [target, target_im_raw, target_pose, target_manual_pose, target_manual_kp_right, target_manual_kp_left, target_cond, target_auto_cond, target_manual_cond, target_keypts, target_auto_keypts, target_manual_keypts],
        )
        # run listeners: "Run" is (re)evaluated whenever either condition changes.
        ref_cond.change(enable_component, [ref_cond, target_cond], run)
        target_cond.change(enable_component, [ref_cond, target_cond], run)
        run.click(
            sample_diff,
            [ref_cond, target_cond, target_keypts, n_generation, seed, cfg],
            [results, results_pose],
        )
        clear.click(
            clear_all,
            [],
            [
                ref,
                ref_manual_checkbox,
                ref_manual_kp_right,
                ref_manual_kp_left,
                ref_img,
                ref_pose,
                ref_manual_pose,
                ref_cond,
                ref_flip,
                target,
                target_keypts,
                target_manual_checkbox,
                target_manual_kp_right,
                target_manual_kp_left,
                target_img,
                target_pose,
                target_manual_pose,
                target_cond,
                target_flip,
                results,
                results_pose,
                n_generation,
                seed,
                cfg,
                ref_kp_raw,
            ],
        )
        clear.click(
            set_unvisible,
            [],
            [
                ref_manual_kp_l_info,
                ref_manual_kp_r_info,
                ref_manual_kp_left,
                ref_manual_kp_right,
                ref_manual_undo_left,
                ref_manual_undo_right,
                ref_manual_reset_left,
                ref_manual_reset_right,
                ref_manual_done,
                ref_manual_done_info,
                ref_manual_pose,
                ref_use_manual,
                target_manual_kp_l_info,
                target_manual_kp_r_info,
                target_manual_kp_left,
                target_manual_kp_right,
                target_manual_undo_left,
                target_manual_undo_right,
                target_manual_reset_left,
                target_manual_reset_right,
                target_manual_done,
                target_manual_done_info,
                target_manual_pose,
                target_use_manual,
            ]
        )
    gr.Markdown("<h1>Acknowledgement</h1>")
    gr.Markdown(_ACK_)
    gr.Markdown("<h1>Trouble Shooting</h1>")
    gr.Markdown("If something doesn't work, <br>1. Try refreshing the page and do it again. <br>2. Leave a message at our HuggingFace Spaces's \"Community\" tab on the top right, at our Github repo's Issues, or email us.<br>3. The problem might be from compatibility issue to HuggingFace or GPU vram limitations. If that's possible, we highly recommend you to clone this repo and try with your own gpu.")
    gr.Markdown("<h1>Citation</h1>")
    gr.Markdown(
        """<p style="text-align: left;">If this was useful, please cite us! ❤️</p>"""
    )
    gr.Markdown(_CITE_)
# Start the server only when this file is executed as a script, so importing
# the module (e.g. to reuse `demo` or the helpers) does not launch the app.
if __name__ == "__main__":
    demo.launch(share=True)