Chaerin5 committed on
Commit 6aa317f · 1 Parent(s): fe8403c

instruction renovation; allow manual keypoints at edit hands

Files changed (2)
  1. app_regular_gpu.py +2003 -0
  2. no_hands.png +3 -0
app_regular_gpu.py ADDED
@@ -0,0 +1,2003 @@
1
+ import os
2
+ import torch
3
+ from dataclasses import dataclass
4
+ import gradio as gr
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import cv2
8
+ import mediapipe as mp
9
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
10
+ import vqvae
11
+ import vit
12
+ from typing import Literal
13
+ from diffusion import create_diffusion
14
+ from utils import scale_keypoint, keypoint_heatmap, check_keypoints_validity
15
+ from segment_hoi import init_sam
16
+ from io import BytesIO
17
+ from PIL import Image
18
+ import random
19
+ from copy import deepcopy
20
+ from typing import Optional
21
+ import requests
22
+ from huggingface_hub import hf_hub_download
23
+ # import spaces
24
+
25
+ MAX_N = 6
26
+ FIX_MAX_N = 6
27
+
28
+ placeholder = cv2.cvtColor(cv2.imread("placeholder.png"), cv2.COLOR_BGR2RGB)
29
+ NEW_MODEL = True
30
+ MODEL_EPOCH = 6
31
+ REF_POSE_MASK = True
32
+
33
+ def set_seed(seed):
34
+ seed = int(seed)
35
+ torch.manual_seed(seed)
36
+ np.random.seed(seed)
37
+ torch.cuda.manual_seed_all(seed)
38
+ random.seed(seed)
39
+
40
+ # if torch.cuda.is_available():
41
+ device = "cuda"
42
+ # else:
43
+ # device = "cpu"
44
+
45
+ def remove_prefix(text, prefix):
46
+ if text.startswith(prefix):
47
+ return text[len(prefix) :]
48
+ return text
49
+
50
+
51
+ def unnormalize(x):
52
+ return (((x + 1) / 2) * 255).astype(np.uint8)
53
+
54
+
55
+ def visualize_hand(all_joints, img, side=["right", "left"], n_avail_joints=21):
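+ # Draw the 21-keypoint hand skeleton(s) over img with matplotlib and return the rendering as a numpy image.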
56
+ # Define the connections between joints for drawing lines and their corresponding colors
57
+ connections = [
58
+ ((0, 1), "red"),
59
+ ((1, 2), "green"),
60
+ ((2, 3), "blue"),
61
+ ((3, 4), "purple"),
62
+ ((0, 5), "orange"),
63
+ ((5, 6), "pink"),
64
+ ((6, 7), "brown"),
65
+ ((7, 8), "cyan"),
66
+ ((0, 9), "yellow"),
67
+ ((9, 10), "magenta"),
68
+ ((10, 11), "lime"),
69
+ ((11, 12), "indigo"),
70
+ ((0, 13), "olive"),
71
+ ((13, 14), "teal"),
72
+ ((14, 15), "navy"),
73
+ ((15, 16), "gray"),
74
+ ((0, 17), "lavender"),
75
+ ((17, 18), "silver"),
76
+ ((18, 19), "maroon"),
77
+ ((19, 20), "fuchsia"),
78
+ ]
79
+ H, W, C = img.shape
80
+
81
+ # Create a figure and axis
82
+ plt.figure()
83
+ ax = plt.gca()
84
+ # Plot joints as points
85
+ ax.imshow(img)
86
+ start_is = []
87
+ if "right" in side:
88
+ start_is.append(0)
89
+ if "left" in side:
90
+ start_is.append(21)
91
+ for start_i in start_is:
92
+ joints = all_joints[start_i : start_i + n_avail_joints]
93
+ if len(joints) == 1:
94
+ ax.scatter(joints[0][0], joints[0][1], color="red", s=10)
95
+ else:
96
+ for connection, color in connections[: len(joints) - 1]:
97
+ joint1 = joints[connection[0]]
98
+ joint2 = joints[connection[1]]
99
+ ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)
100
+
101
+ ax.set_xlim([0, W])
102
+ ax.set_ylim([0, H])
103
+ ax.grid(False)
104
+ ax.set_axis_off()
105
+ ax.invert_yaxis()
106
+ # plt.subplots_adjust(wspace=0.01)
107
+ # plt.show()
108
+ buf = BytesIO()
109
+ plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
110
+ plt.close()
111
+
112
+ # Convert BytesIO object to numpy array
113
+ buf.seek(0)
114
+ img_pil = Image.open(buf)
115
+ img_pil = img_pil.resize((W, H))  # PIL resize takes (width, height)
116
+ numpy_img = np.array(img_pil)
117
+
118
+ return numpy_img
119
+
120
+
121
+ def mask_image(image, mask, color=[0, 0, 0], alpha=0.6, transparent=True):
122
+ """Overlay mask on image for visualization purpose.
123
+ Args:
124
+ image (H, W, 3) or (H, W): input image
125
+ mask (H, W): mask to be overlaid
126
+ color: the color of overlaid mask
127
+ alpha: the transparency of the mask
128
+ """
129
+ out = deepcopy(image)
130
+ img = deepcopy(image)
131
+ img[mask == 1] = color
132
+ if transparent:
133
+ out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out)
134
+ else:
135
+ out = img
136
+ return out
137
+
138
+
139
+ def scale_keypoint(keypoint, original_size, target_size):
140
+ """Scale a keypoint based on the resizing of the image."""
141
+ keypoint_copy = keypoint.copy()
142
+ keypoint_copy[:, 0] *= target_size[0] / original_size[0]
143
+ keypoint_copy[:, 1] *= target_size[1] / original_size[1]
144
+ return keypoint_copy
145
+
146
+
147
+ print("Configure...")
148
+
149
+
150
+ @dataclass
151
+ class HandDiffOpts:
152
+ run_name: str = "ViT_256_handmask_heatmap_nvs_b25_lr1e-5"
153
+ sd_path: str = "/users/kchen157/scratch/weights/SD/sd-v1-4.ckpt"
154
+ log_dir: str = "/users/kchen157/scratch/log"
155
+ data_root: str = "/users/kchen157/data/users/kchen157/dataset/handdiff"
156
+ image_size: tuple = (256, 256)
157
+ latent_size: tuple = (32, 32)
158
+ latent_dim: int = 4
159
+ mask_bg: bool = False
160
+ kpts_form: str = "heatmap"
161
+ n_keypoints: int = 42
162
+ n_mask: int = 1
163
+ noise_steps: int = 1000
164
+ test_sampling_steps: int = 250
165
+ ddim_steps: int = 100
166
+ ddim_discretize: str = "uniform"
167
+ ddim_eta: float = 0.0
168
+ beta_start: float = 8.5e-4
169
+ beta_end: float = 0.012
170
+ latent_scaling_factor: float = 0.18215
171
+ cfg_pose: float = 5.0
172
+ cfg_appearance: float = 3.5
173
+ batch_size: int = 25
174
+ lr: float = 1e-5
175
+ max_epochs: int = 500
176
+ log_every_n_steps: int = 100
177
+ limit_val_batches: int = 1
178
+ n_gpu: int = 8
179
+ num_nodes: int = 1
180
+ precision: str = "16-mixed"
181
+ profiler: str = "simple"
182
+ swa_epoch_start: int = 10
183
+ swa_lrs: float = 1e-3
184
+ num_workers: int = 10
185
+ n_val_samples: int = 4
186
+
187
+ # load models
188
+ token = os.getenv("HF_TOKEN")
189
+ if NEW_MODEL:
190
+ opts = HandDiffOpts()
191
+ if MODEL_EPOCH == 7:
192
+ model_path = './DINO_EMA_11M_b50_lr1e-5_epoch7_step380k.ckpt'
193
+ elif MODEL_EPOCH == 6:
194
+ # model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt"
195
+ model_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt", token=token)
196
+ elif MODEL_EPOCH == 4:
197
+ model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch4_step210k.ckpt"
198
+ elif MODEL_EPOCH == 10:
199
+ model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch10_step550k.ckpt"
200
+ else:
201
+ raise ValueError(f"MODEL_EPOCH should be one of 4, 6, 7, or 10, got {MODEL_EPOCH}")
202
+ # vae_path = './vae-ft-mse-840000-ema-pruned.ckpt'
203
+ vae_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="vae-ft-mse-840000-ema-pruned.ckpt", token=token)
204
+ # sd_path = './sd-v1-4.ckpt'
205
+ print('Load diffusion model...')
206
+ diffusion = create_diffusion(str(opts.test_sampling_steps))
207
+ model = vit.DiT_XL_2(
208
+ input_size=opts.latent_size[0],
209
+ latent_dim=opts.latent_dim,
210
+ in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
211
+ learn_sigma=True,
212
+ ).to(device)
213
+ # ckpt_state_dict = torch.load(model_path)['model_state_dict']
214
+ ckpt_state_dict = torch.load(model_path, map_location='cpu')['ema_state_dict']
215
+ missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
216
+ model = model.to(device)
217
+ model.eval()
218
+ print(missing_keys, extra_keys)
219
+ assert len(missing_keys) == 0
220
+ vae_state_dict = torch.load(vae_path, map_location='cpu')['state_dict']
221
+ print(f"vae_state_dict encoder dtype: {vae_state_dict['encoder.conv_in.weight'].dtype}")
222
+ autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False)
223
+ print(f"autoencoder encoder dtype: {next(autoencoder.encoder.parameters()).dtype}")
224
+ print(f"encoder before load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
225
+ print(f"encoder before load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
226
+ missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
227
+ print(f"encoder after load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
228
+ print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
229
+ autoencoder = autoencoder.to(device)
230
+ autoencoder.eval()
231
+ print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
232
+ print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
233
+ print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
234
+ assert len(missing_keys) == 0
235
+ # else:
236
+ # opts = HandDiffOpts()
237
+ # model_path = './finetune_epoch=5-step=130000.ckpt'
238
+ # sd_path = './sd-v1-4.ckpt'
239
+ # print('Load diffusion model...')
240
+ # diffusion = create_diffusion(str(opts.test_sampling_steps))
241
+ # model = vit.DiT_XL_2(
242
+ # input_size=opts.latent_size[0],
243
+ # latent_dim=opts.latent_dim,
244
+ # in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
245
+ # learn_sigma=True,
246
+ # ).to(device)
247
+ # ckpt_state_dict = torch.load(model_path)['state_dict']
248
+ # dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
249
+ # vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
250
+ # missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
251
+ # model.eval()
252
+ # assert len(missing_keys) == 0 and len(extra_keys) == 0
253
+ # autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
254
+ # missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
255
+ # autoencoder.eval()
256
+ # assert len(missing_keys) == 0 and len(extra_keys) == 0
257
+ sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token)
258
+ sam_predictor = init_sam(ckpt_path=sam_path, device='cuda')
259
+
260
+
261
+ print("Mediapipe hand detector and SAM ready...")
262
+ mp_hands = mp.solutions.hands
263
+ hands = mp_hands.Hands(
264
+ static_image_mode=True, # Use False if image is part of a video stream
265
+ max_num_hands=2, # Maximum number of hands to detect
266
+ min_detection_confidence=0.1,
267
+ )
268
+
269
+ def prepare_ref_anno(ref):
270
+ if ref is None:
271
+ return None, None  # this callback feeds two outputs: ref_im_raw and ref_kp_raw
278
+ missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
279
+
280
+ img = ref["composite"][..., :3]
281
+ img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
282
+ keypts = np.zeros((42, 2))
283
+ # if REF_POSE_MASK:
284
+ mp_pose = hands.process(img)
285
+ # detected = np.array([0, 0])
286
+ # start_idx = 0
287
+ if mp_pose.multi_hand_landmarks:
288
+ # handedness labels are flipped because MediaPipe assumes a mirrored input image
289
+ for hand_landmarks, handedness in zip(
290
+ mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
291
+ ):
292
+ # actually right hand
293
+ if handedness.classification[0].label == "Left":
294
+ start_idx = 0
295
+ # detected[0] = 1
296
+ # actually left hand
297
+ elif handedness.classification[0].label == "Right":
298
+ start_idx = 21
299
+ # detected[1] = 1
300
+ for i, landmark in enumerate(hand_landmarks.landmark):
301
+ keypts[start_idx + i] = [
302
+ landmark.x * opts.image_size[1],
303
+ landmark.y * opts.image_size[0],
304
+ ]
305
+
306
+ # sam_predictor.set_image(img)
307
+ # l = keypts[:21].shape[0]
308
+ # if keypts[0].sum() != 0 and keypts[21].sum() != 0:
309
+ # input_point = np.array([keypts[0], keypts[21]])
310
+ # input_label = np.array([1, 1])
311
+ # elif keypts[0].sum() != 0:
312
+ # input_point = np.array(keypts[:1])
313
+ # input_label = np.array([1])
314
+ # elif keypts[21].sum() != 0:
315
+ # input_point = np.array(keypts[21:22])
316
+ # input_label = np.array([1])
317
+ # masks, _, _ = sam_predictor.predict(
318
+ # point_coords=input_point,
319
+ # point_labels=input_label,
320
+ # multimask_output=False,
321
+ # )
322
+ # hand_mask = masks[0]
323
+ # masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
324
+ # ref_pose = visualize_hand(keypts, masked_img)
325
+ print(f"keypts.max(): {keypts.max()}, keypts.min(): {keypts.min()}")
326
+ return img, keypts
327
+ else:
328
+ return img, None
329
+ # raise gr.Error("No hands detected in the reference image.")
330
+ # else:
331
+ # hand_mask = np.zeros_like(img[:,:, 0])
332
+ # ref_pose = np.zeros_like(img)
333
+
334
+ def get_ref_anno(img, keypts):
335
+ if keypts is None:
336
+ no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
337
+ return None, no_hands, None
338
+ if isinstance(keypts, list):
339
+ if len(keypts[0]) == 0:
340
+ keypts[0] = np.zeros((21, 2))
341
+ elif len(keypts[0]) == 21:
342
+ keypts[0] = np.array(keypts[0], dtype=np.float32)
343
+ else:
344
+ gr.Info("Number of right hand keypoints should be either 0 or 21.")
345
+ return None, None, None  # img, ref_pose, ref_cond
346
+
347
+ if len(keypts[1]) == 0:
348
+ keypts[1] = np.zeros((21, 2))
349
+ elif len(keypts[1]) == 21:
350
+ keypts[1] = np.array(keypts[1], dtype=np.float32)
351
+ else:
352
+ gr.Info("Number of left hand keypoints should be either 0 or 21.")
353
+ return None, None, None  # img, ref_pose, ref_cond
354
+
355
+ keypts = np.concatenate(keypts, axis=0)
356
+ # keypts = scale_keypoint(keypts, (LENGTH, LENGTH), opts.image_size)
357
+ if REF_POSE_MASK:
358
+ sam_predictor.set_image(img)
359
+ # l = keypts[:21].shape[0]
360
+ if keypts[0].sum() != 0 and keypts[21].sum() != 0:
361
+ input_point = np.array([keypts[0], keypts[21]])
362
+ input_label = np.array([1, 1])
363
+ elif keypts[0].sum() != 0:
364
+ input_point = np.array(keypts[:1])
365
+ input_label = np.array([1])
366
+ elif keypts[21].sum() != 0:
367
+ input_point = np.array(keypts[21:22])
368
+ input_label = np.array([1])
369
+ masks, _, _ = sam_predictor.predict(
370
+ point_coords=input_point,
371
+ point_labels=input_label,
372
+ multimask_output=False,
373
+ )
374
+ hand_mask = masks[0]
375
+ masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
376
+ ref_pose = visualize_hand(keypts, masked_img)
377
+ else:
378
+ hand_mask = np.zeros_like(img[:,:, 0])
379
+ ref_pose = np.zeros_like(img)
380
+ def make_ref_cond(
381
+ img,
382
+ keypts,
383
+ hand_mask,
384
+ device="cuda",
385
+ target_size=(256, 256),
386
+ latent_size=(32, 32),
387
+ ):
388
+ image_transform = Compose(
389
+ [
390
+ ToTensor(),
391
+ Resize(target_size),
392
+ Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
393
+ ]
394
+ )
395
+ image = image_transform(img).to(device)
396
+ kpts_valid = check_keypoints_validity(keypts, target_size)
397
+ heatmaps = torch.tensor(
398
+ keypoint_heatmap(
399
+ scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
400
+ )
401
+ * kpts_valid[:, None, None],
402
+ dtype=torch.float,
403
+ device=device
404
+ )[None, ...]
405
+ mask = torch.tensor(
406
+ cv2.resize(
407
+ hand_mask.astype(int),
408
+ dsize=latent_size,
409
+ interpolation=cv2.INTER_NEAREST,
410
+ ),
411
+ dtype=torch.float,
412
+ device=device,
413
+ ).unsqueeze(0)[None, ...]
414
+ return image[None, ...], heatmaps, mask
415
+
416
+ print(f"img.max(): {img.max()}, img.min(): {img.min()}")
417
+ image, heatmaps, mask = make_ref_cond(
418
+ img,
419
+ keypts,
420
+ hand_mask,
421
+ device="cuda",
422
+ target_size=opts.image_size,
423
+ latent_size=opts.latent_size,
424
+ )
425
+ print(f"image.max(): {image.max()}, image.min(): {image.min()}")
426
+ print(f"opts.latent_scaling_factor: {opts.latent_scaling_factor}")
427
+ print(f"autoencoder encoder before operating min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
428
+ print(f"autoencoder encoder before operating max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
429
+ print(f"autoencoder encoder before operating dtype: {next(autoencoder.encoder.parameters()).dtype}")
430
+ latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
431
+ print(f"latent.max(): {latent.max()}, latent.min(): {latent.min()}")
432
+ if not REF_POSE_MASK:
433
+ heatmaps = torch.zeros_like(heatmaps)
434
+ mask = torch.zeros_like(mask)
435
+ print(f"heatmaps.max(): {heatmaps.max()}, heatmaps.min(): {heatmaps.min()}")
436
+ print(f"mask.max(): {mask.max()}, mask.min(): {mask.min()}")
437
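+ # Reference conditioning = VAE latent + 42 keypoint heatmaps + hand mask, matching the DiT's extra input channels.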
+ ref_cond = torch.cat([latent, heatmaps, mask], 1)
438
+ print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}")
439
+
440
+ return img, ref_pose, ref_cond
441
+
442
+ def get_target_anno(target):
443
+ if target is None:
444
+ return (
445
+ gr.State.update(value=None),
446
+ gr.Image.update(value=None),
447
+ gr.State.update(value=None),
448
+ gr.State.update(value=None),
449
+ )
450
+ pose_img = target["composite"][..., :3]
451
+ pose_img = cv2.resize(pose_img, opts.image_size, interpolation=cv2.INTER_AREA)
452
+ # detect keypoints
453
+ mp_pose = hands.process(pose_img)
454
+ target_keypts = np.zeros((42, 2))
455
+ detected = np.array([0, 0])
456
+ start_idx = 0
457
+ if mp_pose.multi_hand_landmarks:
458
+ # handedness labels are flipped because MediaPipe assumes a mirrored input image
459
+ for hand_landmarks, handedness in zip(
460
+ mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
461
+ ):
462
+ # actually right hand
463
+ if handedness.classification[0].label == "Left":
464
+ start_idx = 0
465
+ detected[0] = 1
466
+ # actually left hand
467
+ elif handedness.classification[0].label == "Right":
468
+ start_idx = 21
469
+ detected[1] = 1
470
+ for i, landmark in enumerate(hand_landmarks.landmark):
471
+ target_keypts[start_idx + i] = [
472
+ landmark.x * opts.image_size[1],
473
+ landmark.y * opts.image_size[0],
474
+ ]
475
+
476
+ target_pose = visualize_hand(target_keypts, pose_img)
477
+ kpts_valid = check_keypoints_validity(target_keypts, opts.image_size)
478
+ target_heatmaps = torch.tensor(
479
+ keypoint_heatmap(
480
+ scale_keypoint(target_keypts, opts.image_size, opts.latent_size),
481
+ opts.latent_size,
482
+ var=1.0,
483
+ )
484
+ * kpts_valid[:, None, None],
485
+ dtype=torch.float,
486
+ # device=device,
487
+ )[None, ...]
488
+ target_cond = torch.cat(
489
+ [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
490
+ )
491
+ else:
492
+ raise gr.Error("No hands detected in the target image.")
493
+
494
+ return pose_img, target_pose, target_cond, target_keypts
495
+
496
+
497
+ def get_mask_inpaint(ref):
498
+ inpaint_mask = np.array(ref["layers"][0])[..., -1]
499
+ inpaint_mask = cv2.resize(
500
+ inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA
501
+ )
502
+ inpaint_mask = (inpaint_mask >= 128).astype(np.uint8)
503
+ return inpaint_mask
504
+
505
+
506
+ def visualize_ref(crop, brush):
507
+ if crop is None or brush is None:
508
+ return None
509
+ inpainted = brush["layers"][0][..., -1]
510
+ img = crop["background"][..., :3]
511
+ img = cv2.resize(img, inpainted.shape[::-1], interpolation=cv2.INTER_AREA)
512
+ mask = inpainted < 128
513
+ # img = img.astype(np.int32)
514
+ # img[mask, :] = img[mask, :] - 50
515
+ # img[np.any(img<0, axis=-1)]=0
516
+ # img = img.astype(np.uint8)
517
+ img = mask_image(img, mask)
518
+ return img
519
+
520
+
521
+ def get_kps(img, keypoints, side: Literal["right", "left"], evt: gr.SelectData):
522
+ if keypoints is None:
523
+ keypoints = [[], []]
524
+ kps = np.zeros((42, 2))
525
+ if side == "right":
526
+ if len(keypoints[0]) == 21:
527
+ gr.Info("21 keypoints for right hand already selected. Try reset if something looks wrong.")
528
+ else:
529
+ keypoints[0].append(list(evt.index))
530
+ len_kps = len(keypoints[0])
531
+ kps[:len_kps] = np.array(keypoints[0])
532
+ elif side == "left":
533
+ if len(keypoints[1]) == 21:
534
+ gr.Info("21 keypoints for left hand already selected. Try reset if something looks wrong.")
535
+ else:
536
+ keypoints[1].append(list(evt.index))
537
+ len_kps = len(keypoints[1])
538
+ kps[21 : 21 + len_kps] = np.array(keypoints[1])
539
+ vis_hand = visualize_hand(kps, img, side, len_kps)
540
+ return vis_hand, keypoints
541
+
542
+
543
+ def undo_kps(img, keypoints, side: Literal["right", "left"]):
544
+ if keypoints is None:
545
+ return img, None
546
+ kps = np.zeros((42, 2))
547
+ if side == "right":
548
+ if len(keypoints[0]) == 0:
549
+ return img, keypoints
550
+ keypoints[0].pop()
551
+ len_kps = len(keypoints[0])
552
+ kps[:len_kps] = np.array(keypoints[0])
553
+ elif side == "left":
554
+ if len(keypoints[1]) == 0:
555
+ return img, keypoints
556
+ keypoints[1].pop()
557
+ len_kps = len(keypoints[1])
558
+ kps[21 : 21 + len_kps] = np.array(keypoints[1])
559
+ vis_hand = visualize_hand(kps, img, side, len_kps)
560
+ return vis_hand, keypoints
561
+
562
+
563
+ def reset_kps(img, keypoints, side: Literal["right", "left"]):
564
+ if keypoints is None:
565
+ return img, None
566
+ if side == "right":
567
+ keypoints[0] = []
568
+ elif side == "left":
569
+ keypoints[1] = []
570
+ return img, keypoints
571
+
572
+ # @spaces.GPU(duration=60)
573
+ def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
574
+ set_seed(seed)
575
+ z = torch.randn(
576
+ (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
577
+ device=device,
578
+ )
579
+ print(f"z.device: {z.device}")
580
+ target_cond = target_cond.repeat(num_gen, 1, 1, 1).to(z.device)
581
+ ref_cond = ref_cond.repeat(num_gen, 1, 1, 1).to(z.device)
582
+ print(f"target_cond.max(): {target_cond.max()}, target_cond.min(): {target_cond.min()}")
583
+ print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}")
584
+ # novel view synthesis mode = off
585
+ nvs = torch.zeros(num_gen, dtype=torch.int, device=device)
586
+ z = torch.cat([z, z], 0)
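+ # Classifier-free guidance: duplicate the noise and pair each conditional input with a zeroed (unconditional) one.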
587
+ model_kwargs = dict(
588
+ target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]),
589
+ ref_cond=torch.cat([ref_cond, torch.zeros_like(ref_cond)]),
590
+ nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
591
+ cfg_scale=cfg,
592
+ )
593
+
594
+ samples, _ = diffusion.p_sample_loop(
595
+ model.forward_with_cfg,
596
+ z.shape,
597
+ z,
598
+ clip_denoised=False,
599
+ model_kwargs=model_kwargs,
600
+ progress=True,
601
+ device=device,
602
+ ).chunk(2)
603
+ sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
604
+ sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
605
+ sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())
606
+
607
+ results = []
608
+ results_pose = []
609
+ for i in range(MAX_N):
610
+ if i < num_gen:
611
+ results.append(sampled_images[i])
612
+ results_pose.append(visualize_hand(target_keypts, sampled_images[i]))
613
+ else:
614
+ results.append(placeholder)
615
+ results_pose.append(placeholder)
616
+ print(f"results[0].max(): {results[0].max()}")
617
+ return results, results_pose
618
+
619
+ # @spaces.GPU(duration=120)
620
+ def ready_sample(img_ori, inpaint_mask, keypts):
621
+ img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
622
+ sam_predictor.set_image(img)
623
+ if len(keypts[0]) == 0:
624
+ keypts[0] = np.zeros((21, 2))
625
+ elif len(keypts[0]) == 21:
626
+ keypts[0] = np.array(keypts[0], dtype=np.float32)
627
+ else:
628
+ gr.Info("Number of right hand keypoints should be either 0 or 21.")
629
+ return None, None, None, None, None, None, None  # match the seven values returned on success
630
+
631
+ if len(keypts[1]) == 0:
632
+ keypts[1] = np.zeros((21, 2))
633
+ elif len(keypts[1]) == 21:
634
+ keypts[1] = np.array(keypts[1], dtype=np.float32)
635
+ else:
636
+ gr.Info("Number of left hand keypoints should be either 0 or 21.")
637
+ return None, None, None, None, None, None, None  # match the seven values returned on success
638
+
639
+ keypts = np.concatenate(keypts, axis=0)
640
+ keypts = scale_keypoint(keypts, (LENGTH, LENGTH), opts.image_size)
641
+ # if keypts[0].sum() != 0 and keypts[21].sum() != 0:
642
+ # input_point = np.array([keypts[0], keypts[21]])
643
+ # # input_point = keypts
644
+ # input_label = np.array([1, 1])
645
+ # # input_label = np.ones_like(input_point[:, 0])
646
+ # elif keypts[0].sum() != 0:
647
+ # input_point = np.array(keypts[:1])
648
+ # # input_point = keypts[:21]
649
+ # input_label = np.array([1])
650
+ # # input_label = np.ones_like(input_point[:21, 0])
651
+ # elif keypts[21].sum() != 0:
652
+ # input_point = np.array(keypts[21:22])
653
+ # # input_point = keypts[21:]
654
+ # input_label = np.array([1])
655
+ # # input_label = np.ones_like(input_point[21:, 0])
656
+
657
+ box_shift_ratio = 0.5
658
+ box_size_factor = 1.2
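+ # SAM prompt: clicked keypoints as positive points plus their bounding box, enlarged by box_size_factor around its center.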
659
+
660
+ if keypts[0].sum() != 0 and keypts[21].sum() != 0:
661
+ input_point = np.array(keypts)
662
+ input_box = np.stack([keypts.min(axis=0), keypts.max(axis=0)])
663
+ elif keypts[0].sum() != 0:
664
+ input_point = np.array(keypts[:21])
665
+ input_box = np.stack([keypts[:21].min(axis=0), keypts[:21].max(axis=0)])
666
+ elif keypts[21].sum() != 0:
667
+ input_point = np.array(keypts[21:])
668
+ input_box = np.stack([keypts[21:].min(axis=0), keypts[21:].max(axis=0)])
669
+ else:
670
+ raise ValueError(
671
+ "Something went wrong: if no hands were detected, this point should not be reached."
672
+ )
673
+
674
+ input_label = np.ones_like(input_point[:, 0]).astype(np.int32)
675
+ box_trans = input_box[0] * box_shift_ratio + input_box[1] * (1 - box_shift_ratio)
676
+ input_box = ((input_box - box_trans) * box_size_factor + box_trans).reshape(-1)
677
+
678
+ masks, _, _ = sam_predictor.predict(
679
+ point_coords=input_point,
680
+ point_labels=input_label,
681
+ box=input_box[None, :],
682
+ multimask_output=False,
683
+ )
684
+ hand_mask = masks[0]
685
+
686
+ inpaint_latent_mask = torch.tensor(
687
+ cv2.resize(
688
+ inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
689
+ ),
690
+ dtype=torch.float,
691
+ # device=device,
692
+ ).unsqueeze(0)[None, ...]
693
+
694
+ def make_ref_cond(
695
+ img,
696
+ keypts,
697
+ hand_mask,
698
+ device=device,
699
+ target_size=(256, 256),
700
+ latent_size=(32, 32),
701
+ ):
702
+ image_transform = Compose(
703
+ [
704
+ ToTensor(),
705
+ Resize(target_size),
706
+ Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
707
+ ]
708
+ )
709
+ image = image_transform(img)
710
+ kpts_valid = check_keypoints_validity(keypts, target_size)
711
+ heatmaps = torch.tensor(
712
+ keypoint_heatmap(
713
+ scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
714
+ )
715
+ * kpts_valid[:, None, None],
716
+ dtype=torch.float,
717
+ # device=device,
718
+ )[None, ...]
719
+ mask = torch.tensor(
720
+ cv2.resize(
721
+ hand_mask.astype(int),
722
+ dsize=latent_size,
723
+ interpolation=cv2.INTER_NEAREST,
724
+ ),
725
+ dtype=torch.float,
726
+ # device=device,
727
+ ).unsqueeze(0)[None, ...]
728
+ return image[None, ...], heatmaps, mask
729
+
730
+ image, heatmaps, mask = make_ref_cond(
731
+ img,
732
+ keypts,
733
+ hand_mask * (1 - inpaint_mask),
734
+ device=device,
735
+ target_size=opts.image_size,
736
+ latent_size=opts.latent_size,
737
+ )
738
+ latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
739
+ target_cond = torch.cat([heatmaps, torch.zeros_like(mask)], 1)
740
+ ref_cond = torch.cat([latent, heatmaps, mask], 1)
741
+ ref_cond = torch.zeros_like(ref_cond)
742
+
743
+ img32 = cv2.resize(img, opts.latent_size, interpolation=cv2.INTER_NEAREST)
744
+ assert mask.max() == 1
745
+ vis_mask32 = mask_image(
746
+ img32, inpaint_latent_mask[0,0].cpu().numpy(), (255,255,255), transparent=False
747
+ ).astype(np.uint8) # 1.0 - mask[0, 0].cpu().numpy()
748
+
749
+ assert np.unique(inpaint_mask).shape[0] <= 2
750
+ assert hand_mask.dtype == bool
751
+ mask256 = inpaint_mask # hand_mask * (1 - inpaint_mask)
752
+ vis_mask256 = mask_image(img, mask256, (255,255,255), transparent=False).astype(
753
+ np.uint8
754
+ ) # 1 - mask256
755
+
756
+ return (
757
+ ref_cond,
758
+ target_cond,
759
+ latent,
760
+ inpaint_latent_mask,
761
+ keypts,
762
+ vis_mask32,
763
+ vis_mask256,
764
+ )
765
+
766
+
767
+ def switch_mask_size(radio):
768
+ if radio == "256x256":
769
+ out = (gr.update(visible=False), gr.update(visible=True))
770
+ elif radio == "latent size (32x32)":
771
+ out = (gr.update(visible=True), gr.update(visible=False))
772
+ return out
773
+
774
+ # @spaces.GPU(duration=300)
775
+ def sample_inpaint(
776
+ ref_cond,
777
+ target_cond,
778
+ latent,
779
+ inpaint_latent_mask,
780
+ keypts,
781
+ num_gen,
782
+ seed,
783
+ cfg,
784
+ quality,
785
+ ):
786
+ set_seed(seed)
787
+ N = num_gen
788
+ jump_length = 10
789
+ jump_n_sample = quality
790
+ cfg_scale = cfg
791
+ z = torch.randn(
792
+ (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device=device
793
+ )
794
+ target_cond_N = target_cond.repeat(N, 1, 1, 1).to(z.device)
795
+ ref_cond_N = ref_cond.repeat(N, 1, 1, 1).to(z.device)
796
+ # novel view synthesis mode = off
797
+ nvs = torch.zeros(N, dtype=torch.int, device=device)
798
+ z = torch.cat([z, z], 0)
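+ # Same classifier-free guidance batching as in sample_diff: conditional and unconditional halves share the same noise.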
799
+ model_kwargs = dict(
800
+ target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
801
+ ref_cond=torch.cat([ref_cond_N, torch.zeros_like(ref_cond_N)]),
802
+ nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
803
+ cfg_scale=cfg_scale,
804
+ )
805
+
806
+ samples, _ = diffusion.inpaint_p_sample_loop(
807
+ model.forward_with_cfg,
808
+ z.shape,
809
+ latent.to(z.device),
810
+ inpaint_latent_mask.to(z.device),
811
+ z,
812
+ clip_denoised=False,
813
+ model_kwargs=model_kwargs,
814
+ progress=True,
815
+ device=z.device,
816
+ jump_length=jump_length,
817
+ jump_n_sample=jump_n_sample,
818
+ ).chunk(2)
819
+ sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
820
+ sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
821
+ sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())
822
+
823
+ # visualize
824
+ results = []
825
+ results_pose = []
826
+ for i in range(FIX_MAX_N):
827
+ if i < num_gen:
828
+ results.append(sampled_images[i])
829
+ results_pose.append(visualize_hand(keypts, sampled_images[i]))
830
+ else:
831
+ results.append(placeholder)
832
+ results_pose.append(placeholder)
833
+ return results, results_pose
834
+
835
+
836
+ def flip_hand(
837
+ img, pose_img, cond: Optional[torch.Tensor], keypts: Optional[torch.Tensor] = None, pose_manual_img = None,
838
+ manual_kp_right=None, manual_kp_left=None
839
+ ):
840
+ if cond is None: # clear clicked
841
+ return None, None, None, None
842
+ img["composite"] = img["composite"][:, ::-1, :]
843
+ img["background"] = img["background"][:, ::-1, :]
844
+ img["layers"] = [layer[:, ::-1, :] for layer in img["layers"]]
845
+ pose_img = pose_img[:, ::-1, :]
846
+ cond = cond.flip(-1)
847
+ if keypts is not None: # cond is target_cond
848
+ if keypts[:21, :].sum() != 0:
849
+ keypts[:21, 0] = opts.image_size[1] - keypts[:21, 0]
850
+ # keypts[:21, 1] = opts.image_size[0] - keypts[:21, 1]
851
+ if keypts[21:, :].sum() != 0:
852
+ keypts[21:, 0] = opts.image_size[1] - keypts[21:, 0]
853
+ # keypts[21:, 1] = opts.image_size[0] - keypts[21:, 1]
854
+ if pose_manual_img is not None:
855
+ pose_manual_img = pose_manual_img[:, ::-1, :]
856
+ manual_kp_right = manual_kp_right[:, ::-1, :]
857
+ manual_kp_left = manual_kp_left[:, ::-1, :]
858
+ return img, pose_img, cond, keypts, pose_manual_img, manual_kp_right, manual_kp_left
859
+
860
+
861
+ def resize_to_full(img):
862
+ img["background"] = cv2.resize(img["background"], (LENGTH, LENGTH))
863
+ img["composite"] = cv2.resize(img["composite"], (LENGTH, LENGTH))
864
+ img["layers"] = [cv2.resize(layer, (LENGTH, LENGTH)) for layer in img["layers"]]
865
+ return img
866
+
867
+
868
+ def clear_all():
869
+ return (
870
+ None,
871
+ None,
872
+ None,
873
+ None,
874
+ None,
875
+ False,
876
+ None,
877
+ None,
878
+ False,
879
+ None,
880
+ None,
881
+ None,
882
+ None,
883
+ None,
884
+ None,
885
+ None,
886
+ 1,
887
+ 42,
888
+ 3.0,
889
+ gr.update(interactive=False),
890
+ []
891
+ )
892
+
893
+
894
+ def fix_clear_all():
895
+ return (
896
+ None,
897
+ None,
898
+ None,
899
+ None,
900
+ None,
901
+ None,
902
+ None,
903
+ None,
904
+ None,
905
+ None,
906
+ None,
907
+ None,
908
+ None,
909
+ None,
910
+ None,
911
+ None,
912
+ None,
913
+ 1,
914
+ # (0,0),
915
+ 42,
916
+ 3.0,
917
+ 10,
918
+ )
919
+
920
+
921
+ def enable_component(image1, image2):
922
+ if image1 is None or image2 is None:
923
+ return gr.update(interactive=False)
924
+ if "background" in image1 and "layers" in image1 and "composite" in image1:
925
+ if (
926
+ image1["background"].sum() == 0
927
+ and (sum([im.sum() for im in image1["layers"]]) == 0)
928
+ and image1["composite"].sum() == 0
929
+ ):
930
+ return gr.update(interactive=False)
931
+ if "background" in image2 and "layers" in image2 and "composite" in image2:
932
+ if (
933
+ image2["background"].sum() == 0
934
+ and (sum([im.sum() for im in image2["layers"]]) == 0)
935
+ and image2["composite"].sum() == 0
936
+ ):
937
+ return gr.update(interactive=False)
938
+ return gr.update(interactive=True)
939
+
940
+
941
+ def set_visible(checkbox, kpts, img_clean, img_pose_right, img_pose_left, done=None, done_info=None):
942
+ if kpts is None:
943
+ kpts = [[], []]
944
+ if "Right hand" not in checkbox:
945
+ kpts[0] = []
946
+ vis_right = img_clean
947
+ update_right = gr.update(visible=False)
948
+ update_r_info = gr.update(visible=False)
949
+ else:
950
+ vis_right = img_pose_right
951
+ update_right = gr.update(visible=True)
952
+ update_r_info = gr.update(visible=True)
953
+
954
+ if "Left hand" not in checkbox:
955
+ kpts[1] = []
956
+ vis_left = img_clean
957
+ update_left = gr.update(visible=False)
958
+ update_l_info = gr.update(visible=False)
959
+ else:
960
+ vis_left = img_pose_left
961
+ update_left = gr.update(visible=True)
962
+ update_l_info = gr.update(visible=True)
963
+
964
+ ret = [
965
+ kpts,
966
+ vis_right,
967
+ vis_left,
968
+ update_right,
969
+ update_right,
970
+ update_right,
971
+ update_left,
972
+ update_left,
973
+ update_left,
974
+ update_r_info,
975
+ update_l_info,
976
+ ]
977
+ if done is not None:
978
+ if not checkbox:
979
+ ret.append(gr.update(visible=False))
980
+ ret.append(gr.update(visible=False))
981
+ else:
982
+ ret.append(gr.update(visible=True))
983
+ ret.append(gr.update(visible=True))
984
+ return tuple(ret)
985
+
986
+ def set_unvisible():
987
+ return (
988
+ gr.update(visible=False),
989
+ gr.update(visible=False),
990
+ gr.update(visible=False),
991
+ gr.update(visible=False),
992
+ gr.update(visible=False),
993
+ gr.update(visible=False),
994
+ gr.update(visible=False),
995
+ gr.update(visible=False),
996
+ gr.update(visible=False),
997
+ gr.update(visible=False),
998
+ gr.update(visible=False),
999
+ gr.update(visible=False)
1000
+ )
1001
+
1002
+ def set_no_hands(decider, component):
1003
+ if decider is None:
1004
+ no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
1005
+ return no_hands
1006
+ else:
1007
+ return component
1008
+
1009
+ # def visible_component(decider, component):
1010
+ # if decider is not None:
1011
+ # update_component = gr.update(visible=True)
1012
+ # else:
1013
+ # update_component = gr.update(visible=False)
1014
+ # return update_component
1015
+
1016
+ def unvisible_component(decider, component):
1017
+ if decider is not None:
1018
+ update_component = gr.update(visible=False)
1019
+ else:
1020
+ update_component = gr.update(visible=True)
1021
+ return update_component
1022
+
1023
+ def make_change(decider, state):
1024
+ '''
1025
+ If decider is not None, toggle the state's value; the resulting True/False does not matter, only that the state changes.
1026
+ '''
1027
+ if decider is not None:
1028
+ if state:
1029
+ state = False
1030
+ else:
1031
+ state = True
1032
+ return state
1033
+ else:
1034
+ return state
1035
+
1036
+ LENGTH = 480
1037
+
1038
+ example_ref_imgs = [
1039
+ [
1040
+ "sample_images/sample1.jpg",
1041
+ ],
1042
+ [
1043
+ "sample_images/sample2.jpg",
1044
+ ],
1045
+ [
1046
+ "sample_images/sample3.jpg",
1047
+ ],
1048
+ [
1049
+ "sample_images/sample4.jpg",
1050
+ ],
1051
+ # [
1052
+ # "sample_images/sample5.jpg",
1053
+ # ],
1054
+ [
1055
+ "sample_images/sample6.jpg",
1056
+ ],
1057
+ # [
1058
+ # "sample_images/sample7.jpg",
1059
+ # ],
1060
+ # [
1061
+ # "sample_images/sample8.jpg",
1062
+ # ],
1063
+ # [
1064
+ # "sample_images/sample9.jpg",
1065
+ # ],
1066
+ # [
1067
+ # "sample_images/sample10.jpg",
1068
+ # ],
1069
+ # [
1070
+ # "sample_images/sample11.jpg",
1071
+ # ],
1072
+ # ["pose_images/pose1.jpg"],
1073
+ # ["pose_images/pose2.jpg"],
1074
+ # ["pose_images/pose3.jpg"],
1075
+ # ["pose_images/pose4.jpg"],
1076
+ # ["pose_images/pose5.jpg"],
1077
+ # ["pose_images/pose6.jpg"],
1078
+ # ["pose_images/pose7.jpg"],
1079
+ # ["pose_images/pose8.jpg"],
1080
+ ]
1081
+ example_target_imgs = [
1082
+ # [
1083
+ # "sample_images/sample1.jpg",
1084
+ # ],
1085
+ # [
1086
+ # "sample_images/sample2.jpg",
1087
+ # ],
1088
+ # [
1089
+ # "sample_images/sample3.jpg",
1090
+ # ],
1091
+ # [
1092
+ # "sample_images/sample4.jpg",
1093
+ # ],
1094
+ [
1095
+ "sample_images/sample5.jpg",
1096
+ ],
1097
+ # [
1098
+ # "sample_images/sample6.jpg",
1099
+ # ],
1100
+ # [
1101
+ # "sample_images/sample7.jpg",
1102
+ # ],
1103
+ # [
1104
+ # "sample_images/sample8.jpg",
1105
+ # ],
1106
+ [
1107
+ "sample_images/sample9.jpg",
1108
+ ],
1109
+ [
1110
+ "sample_images/sample10.jpg",
1111
+ ],
1112
+ [
1113
+ "sample_images/sample11.jpg",
1114
+ ],
1115
+ ["pose_images/pose1.jpg"],
1116
+ # ["pose_images/pose2.jpg"],
1117
+ # ["pose_images/pose3.jpg"],
1118
+ # ["pose_images/pose4.jpg"],
1119
+ # ["pose_images/pose5.jpg"],
1120
+ # ["pose_images/pose6.jpg"],
1121
+ # ["pose_images/pose7.jpg"],
1122
+ # ["pose_images/pose8.jpg"],
1123
+ ]
1124
+ fix_example_imgs = [
1125
+ ["bad_hands/1.jpg"], # "bad_hands/1_mask.jpg"],
1126
+ # ["bad_hands/2.jpg"], # "bad_hands/2_mask.jpg"],
1127
+ ["bad_hands/3.jpg"], # "bad_hands/3_mask.jpg"],
1128
+ # ["bad_hands/4.jpg"], # "bad_hands/4_mask.jpg"],
1129
+ ["bad_hands/5.jpg"], # "bad_hands/5_mask.jpg"],
1130
+ ["bad_hands/6.jpg"], # "bad_hands/6_mask.jpg"],
1131
+ ["bad_hands/7.jpg"], # "bad_hands/7_mask.jpg"],
1132
+ # ["bad_hands/8.jpg"], # "bad_hands/8_mask.jpg"],
1133
+ # ["bad_hands/9.jpg"], # "bad_hands/9_mask.jpg"],
1134
+ # ["bad_hands/10.jpg"], # "bad_hands/10_mask.jpg"],
1135
+ # ["bad_hands/11.jpg"], # "bad_hands/11_mask.jpg"],
1136
+ # ["bad_hands/12.jpg"], # "bad_hands/12_mask.jpg"],
1137
+ # ["bad_hands/13.jpg"], # "bad_hands/13_mask.jpg"],
1138
+ ["bad_hands/14.jpg"],
1139
+ ["bad_hands/15.jpg"],
1140
+ ]
1141
+ custom_css = """
1142
+ .gradio-container .examples img {
1143
+ width: 240px !important;
1144
+ height: 240px !important;
1145
+ }
1146
+ """
1147
+
1148
+ _HEADER_ = '''
1149
+ <div style="text-align: center;">
1150
+ <h1><b>FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation</b></h1>
1151
+ <h2 style="color: #777777;">CVPR 2025</h2>
1152
+ <style>
1153
+ .link-spacing {
1154
+ margin-right: 20px;
1155
+ }
1156
+ </style>
1157
+ <p style="font-size: 15px;">
1158
+ <span style="display: inline-block; margin-right: 30px;">Brown University</span>
1159
+ <span style="display: inline-block;">Meta Reality Labs</span>
1160
+ </p>
1161
+ <h3>
1162
+ <a href='https://arxiv.org/abs/2412.02690' target='_blank' class="link-spacing">Paper</a>
1163
+ <a href='https://ivl.cs.brown.edu/research/foundhand.html' target='_blank' class="link-spacing">Project Page</a>
1164
+ <a href='' target='_blank' class="link-spacing">Code</a>
1165
+ <a href='' target='_blank'>Model Weights</a>
1166
+ </h3>
1167
+ <p>Below are two key abilities of our model. First, we can <b>edit hand poses</b> given two hand images: one is the image to edit, and the other provides the target hand pose. Second, we can automatically <b>fix malformed hand images</b>, following a user-provided target hand pose and the area to fix.</p>
1168
+ </div>
1169
+ '''
1170
+
1171
+ _CITE_ = r"""
1172
+ ```
1173
+ @article{chen2024foundhand,
1174
+ title={FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation},
1175
+ author={Chen, Kefan and Min, Chaerin and Zhang, Linguang and Hampali, Shreyas and Keskin, Cem and Sridhar, Srinath},
1176
+ journal={arXiv preprint arXiv:2412.02690},
1177
+ year={2024}
1178
+ }
1179
+ ```
1180
+ """
1181
+
1182
+ with gr.Blocks(css=custom_css, theme="soft") as demo:
1183
+ gr.Markdown(_HEADER_)
1184
+ with gr.Tab("Edit Hand Poses"):
1185
+ ref_img = gr.State(value=None)
1186
+ ref_im_raw = gr.State(value=None)
1187
+ ref_kp_raw = gr.State(value=0)
1188
+ ref_kp_got = gr.State(value=None)
1189
+ dump = gr.State(value=None)
1190
+ ref_cond = gr.State(value=None)
1191
+ ref_manual_cond = gr.State(value=None)
1192
+ ref_auto_cond = gr.State(value=None)
1193
+ keypts = gr.State(value=None)
1194
+ target_img = gr.State(value=None)
1195
+ target_cond = gr.State(value=None)
1196
+ target_keypts = gr.State(value=None)
1197
+ dump = gr.State(value=None)
1198
+ with gr.Row():
1199
+ with gr.Column():
1200
+ gr.Markdown(
1201
+ """<p style="text-align: center; font-size: 20px; font-weight: bold;">1. Upload a hand image to edit 📥</p>"""
1202
+ )
1203
+ gr.Markdown(
1204
+ """<p style="text-align: center;">&#9312; Optionally crop the image</p>"""
1205
+ )
1206
+ # gr.Markdown("""<p style="text-align: center;"><br></p>""")
1207
+ ref = gr.ImageEditor(
1208
+ type="numpy",
1209
+ label="Reference",
1210
+ show_label=True,
1211
+ height=LENGTH,
1212
+ width=LENGTH,
1213
+ brush=False,
1214
+ layers=False,
1215
+ crop_size="1:1",
1216
+ )
1217
+ gr.Examples(example_ref_imgs, [ref], examples_per_page=20)
1218
+ gr.Markdown(
1219
+ """<p style="text-align: center;">&#9313; Hit the &quot;Finish Cropping&quot; button to get the hand pose</p>"""
1220
+ )
1221
+ ref_finish_crop = gr.Button(value="Finish Cropping", interactive=False)
1222
+ with gr.Tab("Automatic hand keypoints"):
1223
+ ref_pose = gr.Image(
1224
+ type="numpy",
1225
+ label="Reference Pose",
1226
+ show_label=True,
1227
+ height=LENGTH,
1228
+ width=LENGTH,
1229
+ interactive=False,
1230
+ )
1231
+ ref_use_auto = gr.Button(value="Click here to use automatic keypoints instead of manual ones", interactive=False, visible=True)
1232
+ with gr.Tab("Manual hand keypoints"):
1233
+ ref_manual_checkbox_info = gr.Markdown(
1234
+ """<p style="text-align: center;"><b>Step 1.</b> Tell us whether this image shows the right hand, the left hand, or both.</p>""",
1235
+ visible=True,
1236
+ )
1237
+ ref_manual_checkbox = gr.CheckboxGroup(
1238
+ ["Right hand", "Left hand"],
1239
+ # label="Hand side",
1240
+ # info="Hand pose failed to automatically detected. Now let's enable user-provided hand pose. First of all, please tell us if this is right, left, or both hands",
1241
+ show_label=False,
1242
+ visible=True,
1243
+ interactive=True,
1244
+ )
1245
+ ref_manual_kp_r_info = gr.Markdown(
1246
+ """<p style="text-align: center;"><b>Step 2.</b> Click on the image to provide hand keypoints for the <b>right</b> hand. See \"OpenPose Keypoints Convention\" below for guidance.</p>""",
1247
+ visible=False,
1248
+ )
1249
+ ref_manual_kp_right = gr.Image(
1250
+ type="numpy",
1251
+ label="Keypoint Selection (right hand)",
1252
+ show_label=True,
1253
+ height=LENGTH,
1254
+ width=LENGTH,
1255
+ interactive=False,
1256
+ visible=False,
1257
+ sources=[],
1258
+ )
1259
+ with gr.Row():
1260
+ ref_manual_undo_right = gr.Button(
1261
+ value="Undo", interactive=True, visible=False
1262
+ )
1263
+ ref_manual_reset_right = gr.Button(
1264
+ value="Reset", interactive=True, visible=False
1265
+ )
1266
+ ref_manual_kp_l_info = gr.Markdown(
1267
+ """<p style="text-align: center;"><b>Step 2.</b> Click on the image to provide hand keypoints for the <b>left</b> hand. See \"OpenPose Keypoints Convention\" below for guidance.</p>""",
1268
+ visible=False
1269
+ )
1270
+ ref_manual_kp_left = gr.Image(
1271
+ type="numpy",
1272
+ label="Keypoint Selection (left hand)",
1273
+ show_label=True,
1274
+ height=LENGTH,
1275
+ width=LENGTH,
1276
+ interactive=False,
1277
+ visible=False,
1278
+ sources=[],
1279
+ )
1280
+ with gr.Row():
1281
+ ref_manual_undo_left = gr.Button(
1282
+ value="Undo", interactive=True, visible=False
1283
+ )
1284
+ ref_manual_reset_left = gr.Button(
1285
+ value="Reset", interactive=True, visible=False
1286
+ )
1287
+ ref_manual_done_info = gr.Markdown(
1288
+ """<p style="text-align: center;"><b>Step 3.</b> Hit the \"Done\" button to confirm.</p>""",
1289
+ visible=False,
1290
+ )
1291
+ ref_manual_done = gr.Button(value="Done", interactive=True, visible=False)
1292
+ ref_manual_pose = gr.Image(
1293
+ type="numpy",
1294
+ label="Reference Pose",
1295
+ show_label=True,
1296
+ height=LENGTH,
1297
+ width=LENGTH,
1298
+ interactive=False,
1299
+ visible=False
1300
+ )
1301
+ ref_use_manual = gr.Button(value="Click here to use manual keypoints instead of automatic ones", interactive=True, visible=False)
1302
+ ref_manual_instruct = gr.Markdown(
1303
+ value="""<p style="text-align: left; font-weight: bold; ">OpenPose Keypoints Convention</p>""",
1304
+ visible=True
1305
+ )
1306
+ ref_manual_openpose = gr.Image(
1307
+ value="openpose.png",
1308
+ type="numpy",
1309
+ # label="OpenPose keypoints convention",
1310
+ show_label=False,
1311
+ height=LENGTH // 2,
1312
+ width=LENGTH // 2,
1313
+ interactive=False,
1314
+ visible=True
1315
+ )
1316
+ gr.Markdown(
1317
+ """<p style="text-align: center;">&#9314; Optionally flip the hand</p>"""
1318
+ )
1319
+ ref_flip = gr.Checkbox(
1320
+ value=False, label="Flip Handedness (Reference)", interactive=False
1321
+ )
1322
+ with gr.Column():
1323
+ gr.Markdown(
1324
+ """<p style="text-align: center; font-size: 20px; font-weight: bold;">2. Upload a hand image for target hand pose 📥</p>"""
1325
+ )
1326
+ gr.Markdown(
1327
+ """<p style="text-align: center;">&#9312; Optionally crop the image</p>"""
1328
+ )
1329
+ target = gr.ImageEditor(
1330
+ type="numpy",
1331
+ label="Target",
1332
+ show_label=True,
1333
+ height=LENGTH,
1334
+ width=LENGTH,
1335
+ brush=False,
1336
+ layers=False,
1337
+ crop_size="1:1",
1338
+ )
1339
+ gr.Examples(example_target_imgs, [target], examples_per_page=20)
1340
+ gr.Markdown(
1341
+ """<p style="text-align: center;">&#9313; Hit the &quot;Finish Cropping&quot; button to get the hand pose</p>"""
1342
+ )
1343
+ target_finish_crop = gr.Button(
1344
+ value="Finish Cropping", interactive=False
1345
+ )
1346
+ target_pose = gr.Image(
1347
+ type="numpy",
1348
+ label="Target Pose",
1349
+ show_label=True,
1350
+ height=LENGTH,
1351
+ width=LENGTH,
1352
+ interactive=False,
1353
+ )
1354
+ gr.Markdown(
1355
+ """<p style="text-align: center;">&#9314; Optionally flip the hand</p>"""
1356
+ )
1357
+ target_flip = gr.Checkbox(
1358
+ value=False, label="Flip Handedness (Target)", interactive=False
1359
+ )
1360
+ with gr.Column():
1361
+ gr.Markdown(
1362
+ """<p style="text-align: center; font-size: 20px; font-weight: bold;">3. Press &quot;Run&quot; to get the edited results 🎯</p>"""
1363
+ )
1364
+ # gr.Markdown(
1365
+ # """<p style="text-align: center;">[NOTE] Run will be enabled after the previous steps have been completed</p>"""
1366
+ # )
1367
+ run = gr.Button(value="Run", interactive=False)
1368
+ gr.Markdown(
1369
+ """<p style="text-align: center;">⚠️ ~20s per generation with RTX3090. ~50s with A100. <br>(For example, if you set Number of generations to 2, it would take around 40s.)</p>"""
1370
+ )
1371
+ results = gr.Gallery(
1372
+ type="numpy",
1373
+ label="Results",
1374
+ show_label=True,
1375
+ height=LENGTH,
1376
+ min_width=LENGTH,
1377
+ columns=MAX_N,
1378
+ interactive=False,
1379
+ preview=True,
1380
+ )
1381
+ results_pose = gr.Gallery(
1382
+ type="numpy",
1383
+ label="Results Pose",
1384
+ show_label=True,
1385
+ height=LENGTH,
1386
+ min_width=LENGTH,
1387
+ columns=MAX_N,
1388
+ interactive=False,
1389
+ preview=True,
1390
+ )
1391
+ gr.Markdown(
1392
+ """<p style="text-align: center;">✨ Hit &quot;Clear&quot; to restart from the beginning</p>"""
1393
+ )
1394
+ clear = gr.ClearButton()
1395
+
1396
+ # gr.Markdown(
1397
+ # """<p style="text-align: left; font-size: 25px;"><b>More options</b></p>"""
1398
+ # )
1399
+ with gr.Tab("More options"):
1400
+ with gr.Row():
1401
+ n_generation = gr.Slider(
1402
+ label="Number of generations",
1403
+ value=1,
1404
+ minimum=1,
1405
+ maximum=MAX_N,
1406
+ step=1,
1407
+ randomize=False,
1408
+ interactive=True,
1409
+ )
1410
+ seed = gr.Slider(
1411
+ label="Seed",
1412
+ value=42,
1413
+ minimum=0,
1414
+ maximum=10000,
1415
+ step=1,
1416
+ randomize=False,
1417
+ interactive=True,
1418
+ )
1419
+ cfg = gr.Slider(
1420
+ label="Classifier free guidance scale",
1421
+ value=2.5,
1422
+ minimum=0.0,
1423
+ maximum=10.0,
1424
+ step=0.1,
1425
+ randomize=False,
1426
+ interactive=True,
1427
+ )
1428
+
1429
+ ref.change(enable_component, [ref, ref], ref_finish_crop)
1430
+ # ref_finish_crop.click(get_ref_anno, [ref], [ref_img, ref_pose, ref_cond])
1431
+ ref_finish_crop.click(prepare_ref_anno, [ref], [ref_im_raw, ref_kp_raw])
1432
+ # ref_kp_raw.change(make_change, [ref_kp_raw, ref_kp_watcher], ref_kp_watcher)
1433
+ # ref_kp_raw.change(set_no_hands, [ref_kp_raw, ref_pose], ref_pose)
1434
+ ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_right)
1435
+ ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_left)
1436
+ # ref_kp_raw.change(unvisible_component, [ref_kp_raw, ref_manual_checkbox], ref_manual_checkbox)
1437
+ # ref_kp_raw.change(unvisible_component, [ref_kp_raw, ref_manual_checkbox_info], ref_manual_checkbox_info)
1438
+ # ref_kp_raw.change(unvisible_component, [ref_kp_raw, ref_manual_openpose], ref_manual_openpose)
1439
+ # ref_kp_raw.change(unvisible_component, [ref_kp_raw, ref_manual_instruct], ref_manual_instruct)
1440
+ # ref_kp_raw.change(lambda x: x, ref_kp_raw, ref_kp_got)
1441
+ ref_manual_checkbox.select(
1442
+ set_visible,
1443
+ [ref_manual_checkbox, ref_kp_got, ref_im_raw, ref_manual_kp_right, ref_manual_kp_left, ref_manual_done],
1444
+ [
1445
+ ref_kp_got,
1446
+ ref_manual_kp_right,
1447
+ ref_manual_kp_left,
1448
+ ref_manual_kp_right,
1449
+ ref_manual_undo_right,
1450
+ ref_manual_reset_right,
1451
+ ref_manual_kp_left,
1452
+ ref_manual_undo_left,
1453
+ ref_manual_reset_left,
1454
+ ref_manual_kp_r_info,
1455
+ ref_manual_kp_l_info,
1456
+ ref_manual_done,
1457
+ ref_manual_done_info
1458
+ ]
1459
+ )
1460
+ ref_manual_kp_right.select(
1461
+ get_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got]
1462
+ )
1463
+ ref_manual_undo_right.click(
1464
+ undo_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got]
1465
+ )
1466
+ ref_manual_reset_right.click(
1467
+ reset_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got]
1468
+ )
1469
+ ref_manual_kp_left.select(
1470
+ get_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
1471
+ )
1472
+ ref_manual_undo_left.click(
1473
+ undo_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
1474
+ )
1475
+ ref_manual_reset_left.click(
1476
+ reset_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
1477
+ )
1478
+ # ref_manual_done.click(lambda x: ~x, ref_kp_watcher, ref_kp_watcher)
1479
+ ref_manual_done.click(get_ref_anno, [ref_im_raw, ref_kp_got], [ref_img, ref_manual_pose, ref_manual_cond])
1480
+ ref_manual_cond.change(lambda x: x, ref_manual_cond, ref_cond)
1481
+ ref_use_manual.click(lambda x: x, ref_manual_cond, ref_cond)
1482
+ ref_use_manual.click(lambda x: gr.Info("Manual hand keypoints will be used for 'Reference'", duration=3))
1483
+ ref_manual_done.click(lambda x: gr.update(visible=True), ref_manual_pose, ref_manual_pose)
1484
+ ref_manual_done.click(lambda x: gr.update(visible=True), ref_use_manual, ref_use_manual)
1485
+ ref_manual_pose.change(enable_component, [ref_manual_pose, ref_manual_pose], ref_manual_done)
1486
+ # ref_pose.change(enable_component, [ref_pose, gr.State(value=True)], ref_ok)
+ ref_kp_raw.change(get_ref_anno, [ref_im_raw, ref_kp_raw], [ref_img, ref_pose, ref_auto_cond])
+ ref_auto_cond.change(lambda x: x, ref_auto_cond, ref_cond)
+ ref_use_auto.click(lambda x: x, ref_auto_cond, ref_cond)
+ ref_use_auto.click(lambda x: gr.Info("Automatic hand keypoints will be used for 'Reference'", duration=3))
+ ref_pose.change(enable_component, [ref_kp_raw, ref_pose], ref_use_auto)
+ ref_pose.change(enable_component, [ref_img, ref_pose], ref_flip)
+ ref_manual_pose.change(enable_component, [ref_img, ref_manual_pose], ref_flip)
+ ref_flip.select(
+ flip_hand, [ref, ref_pose, ref_cond, gr.State(value=None), ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left], [ref, ref_pose, ref_cond, dump, ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left]
+ )
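+ # Target image wiring: finishing the crop runs get_target_anno to extract the
+ # target pose, conditioning, and keypoints; the flip toggle mirrors them.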
+ target.change(enable_component, [target, target], target_finish_crop)
+ target_finish_crop.click(
+ get_target_anno,
+ [target],
+ [target_img, target_pose, target_cond, target_keypts],
+ )
+ target_pose.change(enable_component, [target_img, target_pose], target_flip)
+ target_flip.select(
+ flip_hand,
+ [target, target_pose, target_cond, target_keypts],
+ [target, target_pose, target_cond, target_keypts],
+ )
+ ref_pose.change(enable_component, [ref_pose, target_pose], run)
+ ref_manual_pose.change(enable_component, [ref_manual_pose, target_pose], run)
+ target_pose.change(enable_component, [ref_pose, target_pose], run)
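+ # "Run" becomes clickable once both a reference pose and a target pose exist,
+ # and triggers generation (sample_diff) with the chosen options.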
+ run.click(
+ sample_diff,
+ [ref_cond, target_cond, target_keypts, n_generation, seed, cfg],
+ [results, results_pose],
+ )
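+ # "Clear" resets every component of this tab and hides the manual-keypoint UI.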
+ clear.click(
+ clear_all,
+ [],
+ [
+ ref,
+ ref_manual_kp_right,
+ ref_manual_kp_left,
+ ref_pose,
+ ref_manual_pose,
+ ref_flip,
+ target,
+ target_pose,
+ target_flip,
+ results,
+ results_pose,
+ ref_img,
+ ref_cond,
+ # mask,
+ target_img,
+ target_cond,
+ target_keypts,
+ n_generation,
+ seed,
+ cfg,
+ ref_kp_raw,
+ ref_manual_checkbox
+ ],
+ )
+ clear.click(
+ set_unvisible,
+ [],
+ [
+ # ref_manual_checkbox,
+ # ref_manual_instruct,
+ # ref_manual_openpose,
+ ref_manual_kp_r_info,
+ ref_manual_kp_l_info,
+ ref_manual_undo_left,
+ ref_manual_undo_right,
+ ref_manual_reset_left,
+ ref_manual_reset_right,
+ ref_manual_done,
+ ref_manual_done_info,
+ ref_manual_pose,
+ ref_use_manual,
+ ref_manual_kp_right,
+ ref_manual_kp_left
+ ]
+ )
+
+ # gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""")
+ # with gr.Tab("Reference"):
+ # with gr.Row():
+ # gr.Examples(example_imgs, [ref], examples_per_page=20)
+ # with gr.Tab("Target"):
+ # with gr.Row():
+ # gr.Examples(example_imgs, [target], examples_per_page=20)
+ with gr.Tab("Fix Hands"):
+ fix_inpaint_mask = gr.State(value=None)
+ fix_original = gr.State(value=None)
+ fix_img = gr.State(value=None)
+ fix_kpts = gr.State(value=None)
+ fix_kpts_np = gr.State(value=None)
+ fix_ref_cond = gr.State(value=None)
+ fix_target_cond = gr.State(value=None)
+ fix_latent = gr.State(value=None)
+ fix_inpaint_latent = gr.State(value=None)
+ # fix_size_memory = gr.State(value=(0, 0))
+ # gr.Markdown("""<p style="text-align: center; font-size: 25px; font-weight: bold; ">⚠️ Note</p>""")
+ # gr.Markdown("""<p>"Fix Hands" with A100 needs around 6 mins, which is beyond the ZeroGPU quota (5 mins). Please either purchase additional gpus from Hugging Face or wait for us to open-source our code soon so that you can use your own gpus🙏 </p>""")
+ with gr.Row():
+ with gr.Column():
+ # gr.Markdown(
+ # """<p style="text-align: center; font-size: 25px; font-weight: bold; ">1. Image Cropping & Brushing</p>"""
+ # )
+ gr.Markdown(
+ """<p style="text-align: center; font-size: 20px; font-weight: bold;">1. Upload a malformed hand image to fix 📥</p>"""
+ )
+ gr.Markdown(
+ """<p style="text-align: center;">&#9312; Optionally crop the image around the hand</p>"""
+ )
+ # gr.Markdown(
+ # """<p style="text-align: center; font-size: 20px; font-weight: bold; ">A. Crop</p>"""
+ # )
+ fix_crop = gr.ImageEditor(
+ type="numpy",
+ sources=["upload", "webcam", "clipboard"],
+ label="Image crop",
+ show_label=True,
+ height=LENGTH,
+ width=LENGTH,
+ layers=False,
+ crop_size="1:1",
+ brush=False,
+ image_mode="RGBA",
+ container=False,
+ )
+ fix_example = gr.Examples(
+ fix_example_imgs,
+ inputs=[fix_crop],
+ examples_per_page=20,
+ )
+ gr.Markdown(
+ """<p style="text-align: center;">&#9313; Brush over the area that needs to be fixed (e.g., a wrong finger). This will serve as the inpaint mask</p>"""
+ )
+ # gr.Markdown(
+ # """<p style="text-align: center; font-size: 20px; font-weight: bold; ">B. Brush</p>"""
+ # )
+ fix_ref = gr.ImageEditor(
+ type="numpy",
+ label="Image brush",
+ sources=(),
+ show_label=True,
+ height=LENGTH,
+ width=LENGTH,
+ layers=False,
+ transforms=("brush"),
+ brush=gr.Brush(
+ colors=["rgb(255, 255, 255)"], default_size=20
+ ), # 204, 50, 50
+ image_mode="RGBA",
+ container=False,
+ interactive=False,
+ )
+ fix_finish_crop = gr.Button(
+ value="Finish Cropping & Brushing", interactive=False
+ )
+ with gr.Column():
+ # gr.Markdown(
+ # """<p style="text-align: center; font-size: 25px; font-weight: bold; ">2. Keypoint Selection</p>"""
+ # )
+ gr.Markdown(
+ """<p style="text-align: center; font-size: 20px; font-weight: bold;">2. Click on the hand to specify the target hand pose</p>"""
+ )
+ # gr.Markdown(
+ # """<p style="text-align: center;">On the hand, select 21 keypoints that you hope the output to be. <br>Please see the \"OpenPose keypoints convention\"</p>"""
+ # )
+ gr.Markdown(
+ """<p style="text-align: center;">&#9312; Tell us whether this is a right hand, a left hand, or both</p>"""
+ )
+ fix_checkbox = gr.CheckboxGroup(
+ ["Right hand", "Left hand"],
+ # value=["Right hand", "Left hand"],
+ # label="Hand side",
+ # info="Which side this hand is? Could be both.",
+ show_label=False,
+ interactive=False,
+ )
+ gr.Markdown(
+ """<p style="text-align: center;">&#9313; On the image, click 21 hand keypoints. These will serve as the target hand pose. See the \"OpenPose keypoints convention\" below for guidance.</p>"""
+ )
+ fix_kp_r_info = gr.Markdown(
+ """<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select keypoints for the right hand only</p>""",
+ visible=False,
+ )
+ fix_kp_right = gr.Image(
+ type="numpy",
+ label="Keypoint Selection (right hand)",
+ show_label=True,
+ height=LENGTH,
+ width=LENGTH,
+ interactive=False,
+ visible=False,
+ sources=[],
+ )
+ with gr.Row():
+ fix_undo_right = gr.Button(
+ value="Undo", interactive=False, visible=False
+ )
+ fix_reset_right = gr.Button(
+ value="Reset", interactive=False, visible=False
+ )
+ fix_kp_l_info = gr.Markdown(
+ """<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select keypoints for the left hand only</p>""",
+ visible=False
+ )
+ fix_kp_left = gr.Image(
+ type="numpy",
+ label="Keypoint Selection (left hand)",
+ show_label=True,
+ height=LENGTH,
+ width=LENGTH,
+ interactive=False,
+ visible=False,
+ sources=[],
+ )
+ with gr.Row():
+ fix_undo_left = gr.Button(
+ value="Undo", interactive=False, visible=False
+ )
+ fix_reset_left = gr.Button(
+ value="Reset", interactive=False, visible=False
+ )
+ gr.Markdown(
+ """<p style="text-align: left; font-weight: bold; ">OpenPose keypoints convention</p>"""
+ )
+ fix_openpose = gr.Image(
+ value="openpose.png",
+ type="numpy",
+ # label="OpenPose keypoints convention",
+ show_label=False,
+ height=LENGTH // 2,
+ width=LENGTH // 2,
+ interactive=False,
+ )
+ with gr.Column():
+ # gr.Markdown(
+ # """<p style="text-align: center; font-size: 25px; font-weight: bold; ">3. Prepare Mask</p>"""
+ # )
+ gr.Markdown(
+ """<p style="text-align: center; font-size: 20px; font-weight: bold;">3. Press &quot;Ready&quot; to start pre-processing</p>"""
+ )
+ fix_ready = gr.Button(value="Ready", interactive=False)
+ # fix_mask_size = gr.Radio(
+ # ["256x256", "latent size (32x32)"],
+ # label="Visualized inpaint mask size",
+ # interactive=False,
+ # value="256x256",
+ # )
+ gr.Markdown(
+ """<p style="text-align: center; font-weight: bold; ">Visualized (256, 256) Inpaint Mask</p>"""
+ )
+ fix_vis_mask32 = gr.Image(
+ type="numpy",
+ label=f"Visualized {opts.latent_size} Inpaint Mask",
+ show_label=True,
+ height=opts.latent_size,
+ width=opts.latent_size,
+ interactive=False,
+ visible=False,
+ )
+ fix_vis_mask256 = gr.Image(
+ type="numpy",
+ # label=f"Visualized {opts.image_size} Inpaint Mask",
+ visible=True,
+ show_label=False,
+ height=opts.image_size,
+ width=opts.image_size,
+ interactive=False,
+ )
+ gr.Markdown(
+ """<p style="text-align: center;">[NOTE] The image above should be the inpaint mask that you brushed, NOT a segmentation mask of the entire hand.</p>"""
+ )
+ with gr.Column():
+ # gr.Markdown(
+ # """<p style="text-align: center; font-size: 25px; font-weight: bold; ">4. Results</p>"""
+ # )
+ gr.Markdown(
+ """<p style="text-align: center; font-size: 20px; font-weight: bold;">4. Press &quot;Run&quot; to get the fixed hand image 🎯</p>"""
+ )
+ fix_run = gr.Button(value="Run", interactive=False)
+ gr.Markdown(
+ """<p style="text-align: center;">⚠️ Each generation takes &gt;3 min and ~24GB of GPU memory</p>"""
+ )
+ fix_result = gr.Gallery(
+ type="numpy",
+ label="Results",
+ show_label=True,
+ height=LENGTH,
+ min_width=LENGTH,
+ columns=FIX_MAX_N,
+ interactive=False,
+ preview=True,
+ )
+ fix_result_pose = gr.Gallery(
+ type="numpy",
+ label="Results Pose",
+ show_label=True,
+ height=LENGTH,
+ min_width=LENGTH,
+ columns=FIX_MAX_N,
+ interactive=False,
+ preview=True,
+ )
+ gr.Markdown(
+ """<p style="text-align: center;">✨ Hit &quot;Clear&quot; to restart from the beginning</p>"""
+ )
+ fix_clear = gr.ClearButton()
+
+ gr.Markdown(
+ """<p style="text-align: left; font-size: 25px;"><b>More options</b></p>"""
+ )
+ gr.Markdown(
+ "⚠️ Currently, setting the number of generations > 1 can lead to out-of-memory errors"
+ )
+ with gr.Row():
+ fix_n_generation = gr.Slider(
+ label="Number of generations",
+ value=1,
+ minimum=1,
+ maximum=FIX_MAX_N,
+ step=1,
+ randomize=False,
+ interactive=True,
+ )
+ fix_seed = gr.Slider(
+ label="Seed",
+ value=42,
+ minimum=0,
+ maximum=10000,
+ step=1,
+ randomize=False,
+ interactive=True,
+ )
+ fix_cfg = gr.Slider(
+ label="Classifier-free guidance scale",
+ value=3.0,
+ minimum=0.0,
+ maximum=10.0,
+ step=0.1,
+ randomize=False,
+ interactive=True,
+ )
+ fix_quality = gr.Slider(
+ label="Quality",
+ value=10,
+ minimum=1,
+ maximum=10,
+ step=1,
+ randomize=False,
+ interactive=True,
+ )
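+ # Event wiring for "Fix Hands": cropping enables the brush editor, brushing
+ # enables "Finish Cropping & Brushing", which extracts the inpaint mask
+ # (get_mask_inpaint), keeps the un-brushed original, and visualizes the
+ # brushed reference (visualize_ref).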
+ fix_crop.change(enable_component, [fix_crop, fix_crop], fix_ref)
+ fix_crop.change(resize_to_full, fix_crop, fix_ref)
+ fix_ref.change(enable_component, [fix_ref, fix_ref], fix_finish_crop)
+ fix_finish_crop.click(get_mask_inpaint, [fix_ref], [fix_inpaint_mask])
+ # fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_right])
+ # fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_left])
+ fix_finish_crop.click(lambda x: x["background"], [fix_crop], [fix_original])
+ fix_finish_crop.click(visualize_ref, [fix_crop, fix_ref], [fix_img])
+ fix_img.change(lambda x: x, [fix_img], [fix_kp_right])
+ fix_img.change(lambda x: x, [fix_img], [fix_kp_left])
+ fix_inpaint_mask.change(
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_checkbox
+ )
+ fix_inpaint_mask.change(
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_right
+ )
+ fix_inpaint_mask.change(
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_right
+ )
+ fix_inpaint_mask.change(
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_right
+ )
+ fix_inpaint_mask.change(
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_left
+ )
+ fix_inpaint_mask.change(
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_left
+ )
+ fix_inpaint_mask.change(
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_left
+ )
+ fix_inpaint_mask.change(
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_ready
+ )
+ # fix_inpaint_mask.change(
+ # enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_run
+ # )
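+ # Selecting the hand side reveals the corresponding keypoint canvases; clicks
+ # add keypoints (get_kps) and "Undo"/"Reset" edit them, as in the reference tab.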
+ fix_checkbox.select(
+ set_visible,
+ [fix_checkbox, fix_kpts, fix_img, fix_kp_right, fix_kp_left],
+ [
+ fix_kpts,
+ fix_kp_right,
+ fix_kp_left,
+ fix_kp_right,
+ fix_undo_right,
+ fix_reset_right,
+ fix_kp_left,
+ fix_undo_left,
+ fix_reset_left,
+ fix_kp_r_info,
+ fix_kp_l_info,
+ ],
+ )
+ fix_kp_right.select(
+ get_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
+ )
+ fix_undo_right.click(
+ undo_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
+ )
+ fix_reset_right.click(
+ reset_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
+ )
+ fix_kp_left.select(
+ get_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
+ )
+ fix_undo_left.click(
+ undo_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
+ )
+ fix_reset_left.click(
+ reset_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
+ )
+ # fix_kpts.change(check_keypoints, [fix_kpts], [fix_kp_right, fix_kp_left, fix_run])
+ # fix_run.click(lambda x:gr.update(value=None), [], [fix_result, fix_result_pose])
+ fix_vis_mask32.change(
+ enable_component, [fix_vis_mask32, fix_vis_mask256], fix_run
+ )
+ # fix_vis_mask32.change(
+ # enable_component, [fix_vis_mask32, fix_vis_mask256], fix_mask_size
+ # )
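+ # "Ready" pre-computes everything sampling needs: reference/target conditioning,
+ # the image latent, the inpainting latent, the keypoint array, and the latent-size
+ # and full-size visualizations of the inpaint mask.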
+ fix_ready.click(
+ ready_sample,
+ [fix_original, fix_inpaint_mask, fix_kpts],
+ [
+ fix_ref_cond,
+ fix_target_cond,
+ fix_latent,
+ fix_inpaint_latent,
+ fix_kpts_np,
+ fix_vis_mask32,
+ fix_vis_mask256,
+ ],
+ )
+ # fix_mask_size.select(
+ # switch_mask_size, [fix_mask_size], [fix_vis_mask32, fix_vis_mask256]
+ # )
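+ # "Run" becomes clickable once the mask visualizations exist and calls
+ # sample_inpaint with the pre-computed latents and the generation options.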
+ fix_run.click(
+ sample_inpaint,
+ [
+ fix_ref_cond,
+ fix_target_cond,
+ fix_latent,
+ fix_inpaint_latent,
+ fix_kpts_np,
+ fix_n_generation,
+ fix_seed,
+ fix_cfg,
+ fix_quality,
+ ],
+ [fix_result, fix_result_pose],
+ )
+ fix_clear.click(
+ fix_clear_all,
+ [],
+ [
+ fix_crop,
+ fix_ref,
+ fix_kp_right,
+ fix_kp_left,
+ fix_result,
+ fix_result_pose,
+ fix_inpaint_mask,
+ fix_original,
+ fix_img,
+ fix_vis_mask32,
+ fix_vis_mask256,
+ fix_kpts,
+ fix_kpts_np,
+ fix_ref_cond,
+ fix_target_cond,
+ fix_latent,
+ fix_inpaint_latent,
+ fix_n_generation,
+ # fix_size_memory,
+ fix_seed,
+ fix_cfg,
+ fix_quality,
+ ],
+ )
+
+ # gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""")
+ # fix_dump_ex = gr.Image(value=None, label="Original Image", visible=False)
+ # fix_dump_ex_masked = gr.Image(value=None, label="After Brushing", visible=False)
+ # with gr.Column():
+ # fix_example = gr.Examples(
+ # fix_example_imgs,
+ # # run_on_click=True,
+ # # fn=parse_fix_example,
+ # # inputs=[fix_dump_ex, fix_dump_ex_masked],
+ # # outputs=[fix_original, fix_ref, fix_img, fix_inpaint_mask],
+ # inputs=[fix_crop],
+ # examples_per_page=20,
+ # )
+
+ gr.Markdown("<h1>Citation</h1>")
+ gr.Markdown(
+ """<p style="text-align: left;">If this was useful, please cite us! ❤️</p>"""
+ )
+ gr.Markdown(_CITE_)
+
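+ # Launch with request queueing enabled; a shareable URL is created and the app
+ # is served on all interfaces at port 7739.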
+ print("Ready to launch..")
+ _, _, shared_url = demo.queue().launch(
+ share=True, server_name="0.0.0.0", server_port=7739
+ )
+ # demo.launch(share=True)
no_hands.png ADDED

Git LFS Details

  • SHA256: 80a0c25741912005bc5619e9965d89cfa28832a898292320a8929bdb64a95229
  • Pointer size: 129 Bytes
  • Size of remote file: 9.31 kB