Chaerin5 committed on
Commit 6a1d13f · 1 Parent(s): 133d942

allow manual keypoints in Edit Hand Poses; paste the fixed hand back into the original image

Files changed (7)
  1. .gitignore +6 -0
  2. README.md +1 -1
  3. app.py +563 -254
  4. brown_logo.png +3 -0
  5. meta_logo.png +3 -0
  6. sbatch/sbatch_demo.sh +38 -0
  7. vqvae.py +4 -1
.gitignore ADDED
@@ -0,0 +1,6 @@
1
+ settings.json
2
+ sbatch/err/
3
+ sbatch/out/
4
+ __pycache__/
5
+ diffusion/__pycache__/
6
+ *.pyc
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: ✋
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.44.1
8
  app_file: app.py
9
  pinned: false
10
  short_description: FoundHand
 
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.40.1
8
  app_file: app.py
9
  pinned: false
10
  short_description: FoundHand
app.py CHANGED
@@ -20,7 +20,10 @@ from copy import deepcopy
20
  from typing import Optional
21
  import requests
22
  from huggingface_hub import hf_hub_download
23
- import spaces
24
 
25
  MAX_N = 6
26
  FIX_MAX_N = 6
@@ -29,6 +32,12 @@ placeholder = cv2.cvtColor(cv2.imread("placeholder.png"), cv2.COLOR_BGR2RGB)
29
  NEW_MODEL = True
30
  MODEL_EPOCH = 6
31
  REF_POSE_MASK = True
32
 
33
  def set_seed(seed):
34
  seed = int(seed)
@@ -112,7 +121,7 @@ def visualize_hand(all_joints, img, side=["right", "left"], n_avail_joints=21):
112
  # Convert BytesIO object to numpy array
113
  buf.seek(0)
114
  img_pil = Image.open(buf)
115
- img_pil = img_pil.resize((H, W))
116
  numpy_img = np.array(img_pil)
117
 
118
  return numpy_img
@@ -232,31 +241,9 @@ if NEW_MODEL:
232
  print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
233
  print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
234
  assert len(missing_keys) == 0
235
- # else:
236
- # opts = HandDiffOpts()
237
- # model_path = './finetune_epoch=5-step=130000.ckpt'
238
- # sd_path = './sd-v1-4.ckpt'
239
- # print('Load diffusion model...')
240
- # diffusion = create_diffusion(str(opts.test_sampling_steps))
241
- # model = vit.DiT_XL_2(
242
- # input_size=opts.latent_size[0],
243
- # latent_dim=opts.latent_dim,
244
- # in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
245
- # learn_sigma=True,
246
- # ).to(device)
247
- # ckpt_state_dict = torch.load(model_path)['state_dict']
248
- # dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
249
- # vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
250
- # missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
251
- # model.eval()
252
- # assert len(missing_keys) == 0 and len(extra_keys) == 0
253
- # autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
254
- # missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
255
- # autoencoder.eval()
256
- # assert len(missing_keys) == 0 and len(extra_keys) == 0
257
- sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token)
258
- sam_predictor = init_sam(ckpt_path=sam_path, device='cpu')
259
260
 
261
  print("Mediapipe hand detector and SAM ready...")
262
  mp_hands = mp.solutions.hands
@@ -266,17 +253,12 @@ hands = mp_hands.Hands(
266
  min_detection_confidence=0.1,
267
  )
268
 
269
- def prepare_ref_anno(ref):
270
  if ref is None:
271
  return (
272
- None,
273
- None,
274
- None,
275
- None,
276
- None,
277
  )
278
- missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
279
-
280
  img = ref["composite"][..., :3]
281
  img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
282
  keypts = np.zeros((42, 2))
@@ -307,6 +289,7 @@ def get_ref_anno(img, keypts):
307
  if keypts is None:
308
  no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
309
  return None, no_hands, None
 
310
  if isinstance(keypts, list):
311
  if len(keypts[0]) == 0:
312
  keypts[0] = np.zeros((21, 2))
@@ -315,7 +298,6 @@ def get_ref_anno(img, keypts):
315
  else:
316
  gr.Info("Number of right hand keypoints should be either 0 or 21.")
317
  return None, None, None
318
-
319
  if len(keypts[1]) == 0:
320
  keypts[1] = np.zeros((21, 2))
321
  elif len(keypts[1]) == 21:
@@ -323,7 +305,6 @@ def get_ref_anno(img, keypts):
323
  else:
324
  gr.Info("Number of left hand keypoints should be either 0 or 21.")
325
  return None, None, None
326
-
327
  keypts = np.concatenate(keypts, axis=0)
328
  if REF_POSE_MASK:
329
  sam_predictor.set_image(img)
@@ -362,7 +343,7 @@ def get_ref_anno(img, keypts):
362
  Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
363
  ]
364
  )
365
- image = image_transform(img) # .to(device)
366
  kpts_valid = check_keypoints_validity(keypts, target_size)
367
  heatmaps = torch.tensor(
368
  keypoint_heatmap(
@@ -370,7 +351,7 @@ def get_ref_anno(img, keypts):
370
  )
371
  * kpts_valid[:, None, None],
372
  dtype=torch.float,
373
- # device=device
374
  )[None, ...]
375
  mask = torch.tensor(
376
  cv2.resize(
@@ -379,7 +360,7 @@ def get_ref_anno(img, keypts):
379
  interpolation=cv2.INTER_NEAREST,
380
  ),
381
  dtype=torch.float,
382
- # device=device,
383
  ).unsqueeze(0)[None, ...]
384
  return image[None, ...], heatmaps, mask
385
 
@@ -388,7 +369,7 @@ def get_ref_anno(img, keypts):
388
  img,
389
  keypts,
390
  hand_mask,
391
- device="cuda",
392
  target_size=opts.image_size,
393
  latent_size=opts.latent_size,
394
  )
@@ -409,62 +390,49 @@ def get_ref_anno(img, keypts):
409
 
410
  return img, ref_pose, ref_cond
411
 
412
- def get_target_anno(target):
413
- if target is None:
414
- return (
415
- gr.State.update(value=None),
416
- gr.Image.update(value=None),
417
- gr.State.update(value=None),
418
- gr.State.update(value=None),
419
- )
420
- pose_img = target["composite"][..., :3]
421
- pose_img = cv2.resize(pose_img, opts.image_size, interpolation=cv2.INTER_AREA)
422
- # detect keypoints
423
- mp_pose = hands.process(pose_img)
424
- target_keypts = np.zeros((42, 2))
425
- detected = np.array([0, 0])
426
- start_idx = 0
427
- if mp_pose.multi_hand_landmarks:
428
- # handedness is flipped assuming the input image is mirrored in MediaPipe
429
- for hand_landmarks, handedness in zip(
430
- mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
431
- ):
432
- # actually right hand
433
- if handedness.classification[0].label == "Left":
434
- start_idx = 0
435
- detected[0] = 1
436
- # actually left hand
437
- elif handedness.classification[0].label == "Right":
438
- start_idx = 21
439
- detected[1] = 1
440
- for i, landmark in enumerate(hand_landmarks.landmark):
441
- target_keypts[start_idx + i] = [
442
- landmark.x * opts.image_size[1],
443
- landmark.y * opts.image_size[0],
444
- ]
445
-
446
- target_pose = visualize_hand(target_keypts, pose_img)
447
- kpts_valid = check_keypoints_validity(target_keypts, opts.image_size)
448
- target_heatmaps = torch.tensor(
449
- keypoint_heatmap(
450
- scale_keypoint(target_keypts, opts.image_size, opts.latent_size),
451
- opts.latent_size,
452
- var=1.0,
453
- )
454
- * kpts_valid[:, None, None],
455
- dtype=torch.float,
456
- # device=device,
457
- )[None, ...]
458
- target_cond = torch.cat(
459
- [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
460
  )
461
- else:
462
- raise gr.Error("No hands detected in the target image.")
463
 
464
- return pose_img, target_pose, target_cond, target_keypts
465
 
466
 
467
  def get_mask_inpaint(ref):
468
  inpaint_mask = np.array(ref["layers"][0])[..., -1]
469
  inpaint_mask = cv2.resize(
470
  inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA
@@ -473,12 +441,12 @@ def get_mask_inpaint(ref):
473
  return inpaint_mask
474
 
475
 
476
- def visualize_ref(crop, brush):
477
- if crop is None or brush is None:
478
  return None
479
  inpainted = brush["layers"][0][..., -1]
480
- img = crop["background"][..., :3]
481
- img = cv2.resize(img, inpainted.shape[::-1], interpolation=cv2.INTER_AREA)
482
  mask = inpainted < 128
483
  # img = img.astype(np.int32)
484
  # img[mask, :] = img[mask, :] - 50
@@ -539,7 +507,39 @@ def reset_kps(img, keypoints, side: Literal["right", "left"]):
539
  keypoints[1] = []
540
  return img, keypoints
541
 
542
- @spaces.GPU(duration=60)
543
  def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
544
  set_seed(seed)
545
  z = torch.randn(
@@ -586,14 +586,17 @@ def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
586
  print(f"results[0].max(): {results[0].max()}")
587
  return results, results_pose
588
 
589
- @spaces.GPU(duration=120)
590
- def ready_sample(img_ori, inpaint_mask, keypts):
591
- img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
 
592
  sam_predictor.set_image(img)
593
  if len(keypts[0]) == 0:
594
  keypts[0] = np.zeros((21, 2))
595
  elif len(keypts[0]) == 21:
596
  keypts[0] = np.array(keypts[0], dtype=np.float32)
 
 
597
  else:
598
  gr.Info("Number of right hand keypoints should be either 0 or 21.")
599
  return None, None
@@ -602,12 +605,14 @@ def ready_sample(img_ori, inpaint_mask, keypts):
602
  keypts[1] = np.zeros((21, 2))
603
  elif len(keypts[1]) == 21:
604
  keypts[1] = np.array(keypts[1], dtype=np.float32)
 
 
605
  else:
606
  gr.Info("Number of left hand keypoints should be either 0 or 21.")
607
  return None, None
608
 
609
  keypts = np.concatenate(keypts, axis=0)
610
- keypts = scale_keypoint(keypts, (LENGTH, LENGTH), opts.image_size)
611
 
612
  box_shift_ratio = 0.5
613
  box_size_factor = 1.2
@@ -643,7 +648,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
643
  inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
644
  ),
645
  dtype=torch.float,
646
- # device=device,
647
  ).unsqueeze(0)[None, ...]
648
 
649
  def make_ref_cond(
@@ -661,7 +666,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
661
  Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
662
  ]
663
  )
664
- image = image_transform(img)
665
  kpts_valid = check_keypoints_validity(keypts, target_size)
666
  heatmaps = torch.tensor(
667
  keypoint_heatmap(
@@ -669,7 +674,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
669
  )
670
  * kpts_valid[:, None, None],
671
  dtype=torch.float,
672
- # device=device,
673
  )[None, ...]
674
  mask = torch.tensor(
675
  cv2.resize(
@@ -678,7 +683,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
678
  interpolation=cv2.INTER_NEAREST,
679
  ),
680
  dtype=torch.float,
681
- # device=device,
682
  ).unsqueeze(0)[None, ...]
683
  return image[None, ...], heatmaps, mask
684
 
@@ -686,7 +691,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
686
  img,
687
  keypts,
688
  hand_mask * (1 - inpaint_mask),
689
- device=device,
690
  target_size=opts.image_size,
691
  latent_size=opts.latent_size,
692
  )
@@ -726,13 +731,15 @@ def switch_mask_size(radio):
726
  out = (gr.update(visible=True), gr.update(visible=False))
727
  return out
728
 
729
- @spaces.GPU(duration=300)
730
  def sample_inpaint(
731
  ref_cond,
732
  target_cond,
733
  latent,
734
  inpaint_latent_mask,
735
  keypts,
 
 
736
  num_gen,
737
  seed,
738
  cfg,
@@ -778,39 +785,76 @@ def sample_inpaint(
778
  # visualize
779
  results = []
780
  results_pose = []
 
781
  for i in range(FIX_MAX_N):
782
  if i < num_gen:
783
- results.append(sampled_images[i])
784
- results_pose.append(visualize_hand(keypts, sampled_images[i]))
785
  else:
786
  results.append(placeholder)
787
  results_pose.append(placeholder)
788
- return results, results_pose
 
789
 
790
 
791
  def flip_hand(
792
- img, pose_img, cond: Optional[torch.Tensor], keypts: Optional[torch.Tensor] = None, pose_manual_img = None,
793
- manual_kp_right=None, manual_kp_left=None
 
 
794
  ):
795
  if cond is None: # clear clicked
796
- return None, None, None, None
797
  img["composite"] = img["composite"][:, ::-1, :]
798
  img["background"] = img["background"][:, ::-1, :]
799
  img["layers"] = [layer[:, ::-1, :] for layer in img["layers"]]
 
 
 
 
 
 
 
 
800
  pose_img = pose_img[:, ::-1, :]
801
  cond = cond.flip(-1)
802
- if keypts is not None: # cond is target_cond
803
  if keypts[:21, :].sum() != 0:
804
  keypts[:21, 0] = opts.image_size[1] - keypts[:21, 0]
805
- # keypts[:21, 1] = opts.image_size[0] - keypts[:21, 1]
806
  if keypts[21:, :].sum() != 0:
807
  keypts[21:, 0] = opts.image_size[1] - keypts[21:, 0]
808
- # keypts[21:, 1] = opts.image_size[0] - keypts[21:, 1]
809
- if pose_manual_img is not None:
810
- pose_manual_img = pose_manual_img[:, ::-1, :]
811
- manual_kp_right = manual_kp_right[:, ::-1, :]
812
- manual_kp_left = manual_kp_left[:, ::-1, :]
813
- return img, pose_img, cond, keypts, pose_manual_img, manual_kp_right, manual_kp_left
814
 
815
 
816
  def resize_to_full(img):
@@ -823,26 +867,30 @@ def resize_to_full(img):
823
  def clear_all():
824
  return (
825
  None,
 
826
  None,
827
  None,
828
  None,
829
  None,
830
- False,
831
  None,
832
  None,
833
  False,
834
  None,
835
  None,
836
  None,
837
  None,
838
  None,
 
839
  None,
840
  None,
841
  1,
842
  42,
843
  3.0,
844
  gr.update(interactive=False),
845
- []
846
  )
847
 
848
 
@@ -851,6 +899,9 @@ def fix_clear_all():
851
  None,
852
  None,
853
  None,
854
  None,
855
  None,
856
  None,
@@ -876,14 +927,14 @@ def fix_clear_all():
876
  def enable_component(image1, image2):
877
  if image1 is None or image2 is None:
878
  return gr.update(interactive=False)
879
- if "background" in image1 and "layers" in image1 and "composite" in image1:
880
  if (
881
  image1["background"].sum() == 0
882
  and (sum([im.sum() for im in image1["layers"]]) == 0)
883
  and image1["composite"].sum() == 0
884
  ):
885
  return gr.update(interactive=False)
886
- if "background" in image2 and "layers" in image2 and "composite" in image2:
887
  if (
888
  image2["background"].sum() == 0
889
  and (sum([im.sum() for im in image2["layers"]]) == 0)
@@ -940,6 +991,18 @@ def set_visible(checkbox, kpts, img_clean, img_pose_right, img_pose_left, done=N
940
 
941
  def set_unvisible():
942
  return (
943
  gr.update(visible=False),
944
  gr.update(visible=False),
945
  gr.update(visible=False),
@@ -954,6 +1017,18 @@ def set_unvisible():
954
  gr.update(visible=False)
955
  )
956
957
  def set_no_hands(decider, component):
958
  if decider is None:
959
  no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
@@ -975,19 +1050,6 @@ def unvisible_component(decider, component):
975
  update_component = gr.update(visible=True)
976
  return update_component
977
 
978
- # def make_change(decider, state):
979
- # '''
980
- # if decider is not None, change the state's value. True/False does not matter.
981
- # '''
982
- # if decider is not None:
983
- # if state:
984
- # state = False
985
- # else:
986
- # state = True
987
- # return state
988
- # else:
989
- # return state
990
-
991
  LENGTH = 480
992
 
993
  example_ref_imgs = [
@@ -1083,7 +1145,7 @@ fix_example_imgs = [
1083
  # ["bad_hands/4.jpg"], # "bad_hands/4_mask.jpg"],
1084
  ["bad_hands/5.jpg"], # "bad_hands/5_mask.jpg"],
1085
  ["bad_hands/6.jpg"], # "bad_hands/6_mask.jpg"],
1086
- ["bad_hands/7.jpg"], # "bad_hands/7_mask.jpg"],
1087
  # ["bad_hands/8.jpg"], # "bad_hands/8_mask.jpg"],
1088
  # ["bad_hands/9.jpg"], # "bad_hands/9_mask.jpg"],
1089
  # ["bad_hands/10.jpg"], # "bad_hands/10_mask.jpg"],
@@ -1137,20 +1199,32 @@ _CITE_ = r"""
1137
  with gr.Blocks(css=custom_css, theme="soft") as demo:
1138
  gr.Markdown(_HEADER_)
1139
  with gr.Tab("Edit Hand Poses"):
 
 
 
1140
  ref_img = gr.State(value=None)
1141
  ref_im_raw = gr.State(value=None)
1142
  ref_kp_raw = gr.State(value=0)
1143
  ref_kp_got = gr.State(value=None)
1144
- dump = gr.State(value=None)
1145
- ref_cond = gr.State(value=None)
1146
  ref_manual_cond = gr.State(value=None)
1147
  ref_auto_cond = gr.State(value=None)
1148
- keypts = gr.State(value=None)
 
 
1149
  target_img = gr.State(value=None)
1150
- target_cond = gr.State(value=None)
1151
  target_keypts = gr.State(value=None)
1152
- dump = gr.State(value=None)
1153
  with gr.Row():
 
1154
  with gr.Column():
1155
  gr.Markdown(
1156
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">1. Upload a hand image to edit 📥</p>"""
@@ -1270,6 +1344,8 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1270
  ref_flip = gr.Checkbox(
1271
  value=False, label="Flip Handedness (Reference)", interactive=False
1272
  )
 
 
1273
  with gr.Column():
1274
  gr.Markdown(
1275
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">2. Upload a hand image for target hand pose 📥</p>"""
@@ -1294,20 +1370,105 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1294
  target_finish_crop = gr.Button(
1295
  value="Finish Cropping", interactive=False
1296
  )
1297
- target_pose = gr.Image(
1298
- type="numpy",
1299
- label="Target Pose",
1300
- show_label=True,
1301
- height=LENGTH,
1302
- width=LENGTH,
1303
- interactive=False,
1304
- )
1305
  gr.Markdown(
1306
  """<p style="text-align: center;">&#9314; Optionally flip the hand</p>"""
1307
  )
1308
  target_flip = gr.Checkbox(
1309
  value=False, label="Flip Handedness (Target)", interactive=False
1310
  )
 
 
1311
  with gr.Column():
1312
  gr.Markdown(
1313
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">3. Press &quot;Run&quot; to get the edited results 🎯</p>"""
@@ -1371,10 +1532,18 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1371
  interactive=True,
1372
  )
1373
 
 
1374
  ref.change(enable_component, [ref, ref], ref_finish_crop)
1375
- ref_finish_crop.click(prepare_ref_anno, [ref], [ref_im_raw, ref_kp_raw])
1376
  ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_right)
1377
  ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_left)
1378
  ref_manual_checkbox.select(
1379
  set_visible,
1380
  [ref_manual_checkbox, ref_kp_got, ref_im_raw, ref_manual_kp_right, ref_manual_kp_left, ref_manual_done],
@@ -1412,38 +1581,94 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1412
  ref_manual_reset_left.click(
1413
  reset_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
1414
  )
 
 
1415
  ref_manual_done.click(get_ref_anno, [ref_im_raw, ref_kp_got], [ref_img, ref_manual_pose, ref_manual_cond])
1416
- ref_manual_cond.change(lambda x: x, ref_manual_cond, ref_cond)
1417
- ref_use_manual.click(lambda x: x, ref_manual_cond, ref_cond)
1418
- # ref_use_manual.click(lambda x: gr.Info("Manual hand keypoints will be used for 'Reference'", duration=3))
1419
- ref_manual_done.click(visible_component, [ref_manual_pose, ref_manual_pose], ref_manual_pose)
1420
- ref_manual_done.click(visible_component, [ref_use_manual, ref_use_manual], ref_use_manual)
1421
  ref_manual_pose.change(enable_component, [ref_manual_pose, ref_manual_pose], ref_manual_done)
1422
- ref_kp_raw.change(get_ref_anno, [ref_im_raw, ref_kp_raw], [ref_img, ref_pose, ref_auto_cond])
1423
- ref_auto_cond.change(lambda x: x, ref_auto_cond, ref_cond)
1424
- ref_use_auto.click(lambda x: x, ref_auto_cond, ref_cond)
1425
- # ref_use_auto.click(lambda x: gr.Info("Automatic hand keypoints will be used for 'Reference'", duration=3))
1426
- ref_pose.change(enable_component, [ref_kp_raw, ref_pose], ref_use_auto)
1427
- ref_pose.change(enable_component, [ref_img, ref_pose], ref_flip)
1428
  ref_manual_pose.change(enable_component, [ref_img, ref_manual_pose], ref_flip)
1429
  ref_flip.select(
1430
- flip_hand, [ref, ref_pose, ref_cond, gr.State(value=None), ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left], [ref, ref_pose, ref_cond, dump, ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left]
 
 
1431
  )
 
 
1432
  target.change(enable_component, [target, target], target_finish_crop)
1433
- target_finish_crop.click(
1434
- get_target_anno,
1435
- [target],
1436
- [target_img, target_pose, target_cond, target_keypts],
1437
- )
1438
  target_pose.change(enable_component, [target_img, target_pose], target_flip)
1439
  target_flip.select(
1440
  flip_hand,
1441
- [target, target_pose, target_cond, target_keypts],
1442
- [target, target_pose, target_cond, target_keypts],
1443
  )
1444
- ref_pose.change(enable_component, [ref_pose, target_pose], run)
1445
- ref_manual_pose.change(enable_component, [ref_manual_pose, target_pose], run)
1446
- target_pose.change(enable_component, [ref_pose, target_pose], run)
1447
  run.click(
1448
  sample_diff,
1449
  [ref_cond, target_cond, target_keypts, n_generation, seed, cfg],
@@ -1454,34 +1679,40 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1454
  [],
1455
  [
1456
  ref,
 
1457
  ref_manual_kp_right,
1458
  ref_manual_kp_left,
 
1459
  ref_pose,
1460
  ref_manual_pose,
 
1461
  ref_flip,
1462
  target,
  target_pose,
 
 
1464
  target_flip,
1465
  results,
1466
  results_pose,
1467
- ref_img,
1468
- ref_cond,
1469
- target_img,
1470
- target_cond,
1471
- target_keypts,
1472
  n_generation,
1473
  seed,
1474
  cfg,
1475
  ref_kp_raw,
1476
- ref_manual_checkbox
1477
  ],
1478
  )
1479
  clear.click(
1480
  set_unvisible,
1481
  [],
1482
  [
1483
- ref_manual_kp_r_info,
1484
  ref_manual_kp_l_info,
1485
  ref_manual_undo_left,
1486
  ref_manual_undo_right,
1487
  ref_manual_reset_left,
@@ -1490,14 +1721,25 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1490
  ref_manual_done_info,
1491
  ref_manual_pose,
1492
  ref_use_manual,
1493
- ref_manual_kp_right,
1494
- ref_manual_kp_left
1495
  ]
1496
  )
1497
 
1498
  with gr.Tab("Fix Hands"):
1499
  fix_inpaint_mask = gr.State(value=None)
1500
  fix_original = gr.State(value=None)
 
1501
  fix_img = gr.State(value=None)
1502
  fix_kpts = gr.State(value=None)
1503
  fix_kpts_np = gr.State(value=None)
@@ -1506,37 +1748,62 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1506
  fix_latent = gr.State(value=None)
1507
  fix_inpaint_latent = gr.State(value=None)
1508
  with gr.Row():
 
1509
  with gr.Column():
1510
  gr.Markdown(
1511
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">1. Upload a malformed hand image to fix 📥</p>"""
1512
  )
1513
  gr.Markdown(
1514
- """<p style="text-align: center;">&#9312; Optionally crop the image around the hand</p>"""
1515
  )
1516
- fix_crop = gr.ImageEditor(
1517
  type="numpy",
1518
  sources=["upload", "webcam", "clipboard"],
1519
- label="Image crop",
1520
  show_label=True,
1521
  height=LENGTH,
1522
  width=LENGTH,
1523
- layers=False,
1524
- crop_size="1:1",
1525
- brush=False,
1526
- image_mode="RGBA",
1527
- container=False,
1528
  )
1529
  fix_example = gr.Examples(
1530
  fix_example_imgs,
1531
  inputs=[fix_crop],
1532
  examples_per_page=20,
1533
  )
1534
  gr.Markdown(
1535
- """<p style="text-align: center;">&#9313; Brush area (e.g., wrong finger) that needs to be fixed. This will serve as an inpaint mask</p>"""
1536
  )
1537
  fix_ref = gr.ImageEditor(
1538
  type="numpy",
1539
- label="Image brush",
1540
  sources=(),
1541
  show_label=True,
1542
  height=LENGTH,
@@ -1550,9 +1817,14 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1550
  container=False,
1551
  interactive=False,
1552
  )
1553
  fix_finish_crop = gr.Button(
1554
  value="Finish Croping & Brushing", interactive=False
1555
  )
 
 
1556
  with gr.Column():
1557
  gr.Markdown(
1558
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">2. Click on hand to get target hand pose</p>"""
@@ -1565,13 +1837,14 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1565
  show_label=False,
1566
  interactive=False,
1567
  )
1568
- gr.Markdown(
1569
- """<p style="text-align: center;">&#9313; On the image, click 21 hand keypoints. This will serve as target hand poses. See the \"OpenPose keypoints convention\" for guidance.</p>"""
1570
- )
1571
  fix_kp_r_info = gr.Markdown(
1572
- """<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select right only</p>""",
1573
- visible=False,
1574
  )
1575
  fix_kp_right = gr.Image(
1576
  type="numpy",
1577
  label="Keypoint Selection (right hand)",
@@ -1590,7 +1863,7 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1590
  value="Reset", interactive=False, visible=False
1591
  )
1592
  fix_kp_l_info = gr.Markdown(
1593
- """<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select left only</p>""",
1594
  visible=False
1595
  )
1596
  fix_kp_left = gr.Image(
@@ -1621,13 +1894,15 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1621
  width=LENGTH // 2,
1622
  interactive=False,
1623
  )
 
 
1624
  with gr.Column():
1625
  gr.Markdown(
1626
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">3. Press &quot;Ready&quot; to start pre-processing</p>"""
1627
  )
1628
  fix_ready = gr.Button(value="Ready", interactive=False)
1629
  gr.Markdown(
1630
- """<p style="text-align: center; font-weight: bold; ">Visualized (256, 256) Inpaint Mask</p>"""
1631
  )
1632
  fix_vis_mask32 = gr.Image(
1633
  type="numpy",
@@ -1646,9 +1921,11 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1646
  width=opts.image_size,
1647
  interactive=False,
1648
  )
1649
- gr.Markdown(
1650
- """<p style="text-align: center;">[NOTE] Above should be inpaint mask that you brushed, NOT the segmentation mask of the entire hand. </p>"""
1651
- )
 
 
1652
  with gr.Column():
1653
  gr.Markdown(
1654
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">4. Press &quot;Run&quot; to get the fixed hand image 🎯</p>"""
@@ -1657,6 +1934,16 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1657
  gr.Markdown(
1658
  """<p style="text-align: center;">⚠️ >3min and ~24GB per generation</p>"""
1659
  )
1660
  fix_result = gr.Gallery(
1661
  type="numpy",
1662
  label="Results",
@@ -1682,55 +1969,58 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1682
  )
1683
  fix_clear = gr.ClearButton()
1684
 
1685
- gr.Markdown(
1686
- """<p style="text-align: left; font-size: 25px;"><b>More options</b></p>"""
1687
- )
1688
- gr.Markdown(
1689
- "⚠️ Currently, Number of generation > 1 could lead to out-of-memory"
1690
- )
1691
- with gr.Row():
1692
- fix_n_generation = gr.Slider(
1693
- label="Number of generations",
1694
- value=1,
1695
- minimum=1,
1696
- maximum=FIX_MAX_N,
1697
- step=1,
1698
- randomize=False,
1699
- interactive=True,
1700
- )
1701
- fix_seed = gr.Slider(
1702
- label="Seed",
1703
- value=42,
1704
- minimum=0,
1705
- maximum=10000,
1706
- step=1,
1707
- randomize=False,
1708
- interactive=True,
1709
- )
1710
- fix_cfg = gr.Slider(
1711
- label="Classifier free guidance scale",
1712
- value=3.0,
1713
- minimum=0.0,
1714
- maximum=10.0,
1715
- step=0.1,
1716
- randomize=False,
1717
- interactive=True,
1718
- )
1719
- fix_quality = gr.Slider(
1720
- label="Quality",
1721
- value=10,
1722
- minimum=1,
1723
- maximum=10,
1724
- step=1,
1725
- randomize=False,
1726
- interactive=True,
1727
  )
1728
- fix_crop.change(enable_component, [fix_crop, fix_crop], fix_ref)
1729
- fix_crop.change(resize_to_full, fix_crop, fix_ref)
1730
- fix_ref.change(enable_component, [fix_ref, fix_ref], fix_finish_crop)
1731
- fix_finish_crop.click(get_mask_inpaint, [fix_ref], [fix_inpaint_mask])
1732
- fix_finish_crop.click(lambda x: x["background"], [fix_crop], [fix_original])
1733
- fix_finish_crop.click(visualize_ref, [fix_crop, fix_ref], [fix_img])
1734
  fix_img.change(lambda x: x, [fix_img], [fix_kp_right])
1735
  fix_img.change(lambda x: x, [fix_img], [fix_kp_left])
1736
  fix_inpaint_mask.change(
@@ -1775,7 +2065,7 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1775
  ],
1776
  )
1777
  fix_kp_right.select(
1778
- get_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
1779
  )
1780
  fix_undo_right.click(
1781
  undo_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
@@ -1797,7 +2087,7 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1797
  )
1798
  fix_ready.click(
1799
  ready_sample,
1800
- [fix_original, fix_inpaint_mask, fix_kpts],
1801
  [
1802
  fix_ref_cond,
1803
  fix_target_cond,
@@ -1816,23 +2106,28 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1816
  fix_latent,
1817
  fix_inpaint_latent,
1818
  fix_kpts_np,
 
 
1819
  fix_n_generation,
1820
  fix_seed,
1821
  fix_cfg,
1822
  fix_quality,
1823
  ],
1824
- [fix_result, fix_result_pose],
1825
  )
1826
  fix_clear.click(
1827
  fix_clear_all,
1828
  [],
1829
  [
1830
  fix_crop,
 
1831
  fix_ref,
 
1832
  fix_kp_right,
1833
  fix_kp_left,
1834
  fix_result,
1835
  fix_result_pose,
 
1836
  fix_inpaint_mask,
1837
  fix_original,
1838
  fix_img,
@@ -1850,6 +2145,20 @@ with gr.Blocks(css=custom_css, theme="soft") as demo:
1850
  fix_quality,
1851
  ],
1852
  )
1853
 
1854
  gr.Markdown("<h1>Citation</h1>")
1855
  gr.Markdown(
 
20
  from typing import Optional
21
  import requests
22
  from huggingface_hub import hf_hub_download
23
+ try:
24
+ import spaces
25
+ except ImportError:
26
+ pass
27
 
28
  MAX_N = 6
29
  FIX_MAX_N = 6
 
32
  NEW_MODEL = True
33
  MODEL_EPOCH = 6
34
  REF_POSE_MASK = True
35
+ HF = False
36
+ pre_device = "cpu" if HF else "cuda"
37
+ spaces_60_fn = spaces.GPU(duration=60) if HF else (lambda f: f)
38
+ spaces_120_fn = spaces.GPU(duration=120) if HF else (lambda f: f)
39
+ spaces_300_fn = spaces.GPU(duration=300) if HF else (lambda f: f)
40
+
41
 
42
  def set_seed(seed):
43
  seed = int(seed)
 
121
  # Convert BytesIO object to numpy array
122
  buf.seek(0)
123
  img_pil = Image.open(buf)
124
+ img_pil = img_pil.resize((W, H))
125
  numpy_img = np.array(img_pil)
126
 
127
  return numpy_img
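The `(H, W)` → `(W, H)` swap above is a genuine fix: `PIL.Image.resize` expects `(width, height)`, the reverse of NumPy's `(rows, cols)` shape order. A quick check:

```python
import numpy as np
from PIL import Image

H, W = 100, 200                       # NumPy order: (rows, cols)
img = Image.fromarray(np.zeros((H, W, 3), dtype=np.uint8))
resized = img.resize((W, H))          # PIL order: (width, height)
assert np.array(resized).shape == (H, W, 3)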
 
241
  print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
242
  print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
243
  assert len(missing_keys) == 0
244
 
245
+ sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token)
246
+ sam_predictor = init_sam(ckpt_path=sam_path, device=pre_device)
247
 
248
  print("Mediapipe hand detector and SAM ready...")
249
  mp_hands = mp.solutions.hands
 
253
  min_detection_confidence=0.1,
254
  )
255
 
256
+ def prepare_anno(ref):
257
  if ref is None:
258
  return (
259
+ gr.update(value=None),
260
+ gr.update(value=None),
261
  )
 
 
262
  img = ref["composite"][..., :3]
263
  img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
264
  keypts = np.zeros((42, 2))
 
289
  if keypts is None:
290
  no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
291
  return None, no_hands, None
292
+ missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
293
  if isinstance(keypts, list):
294
  if len(keypts[0]) == 0:
295
  keypts[0] = np.zeros((21, 2))
 
298
  else:
299
  gr.Info("Number of right hand keypoints should be either 0 or 21.")
300
  return None, None, None
 
301
  if len(keypts[1]) == 0:
302
  keypts[1] = np.zeros((21, 2))
303
  elif len(keypts[1]) == 21:
 
305
  else:
306
  gr.Info("Number of left hand keypoints should be either 0 or 21.")
307
  return None, None, None
 
308
  keypts = np.concatenate(keypts, axis=0)
309
  if REF_POSE_MASK:
310
  sam_predictor.set_image(img)
 
343
  Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
344
  ]
345
  )
346
+ image = image_transform(img).to(device)
347
  kpts_valid = check_keypoints_validity(keypts, target_size)
348
  heatmaps = torch.tensor(
349
  keypoint_heatmap(
 
351
  )
352
  * kpts_valid[:, None, None],
353
  dtype=torch.float,
354
+ device=device
355
  )[None, ...]
356
  mask = torch.tensor(
357
  cv2.resize(
 
360
  interpolation=cv2.INTER_NEAREST,
361
  ),
362
  dtype=torch.float,
363
+ device=device,
364
  ).unsqueeze(0)[None, ...]
365
  return image[None, ...], heatmaps, mask
366
 
 
369
  img,
370
  keypts,
371
  hand_mask,
372
+ device=pre_device,
373
  target_size=opts.image_size,
374
  latent_size=opts.latent_size,
375
  )
 
390
 
391
  return img, ref_pose, ref_cond
392
 
393
+ def get_target_anno(img, keypts):
394
+ if keypts is None:
395
+ no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
396
+ return None, no_hands, None, None
397
+ if isinstance(keypts, list):
398
+ if len(keypts[0]) == 0:
399
+ keypts[0] = np.zeros((21, 2))
400
+ elif len(keypts[0]) == 21:
401
+ keypts[0] = np.array(keypts[0], dtype=np.float32)
402
+ else:
403
+ gr.Info("Number of right hand keypoints should be either 0 or 21.")
404
+ return None, None, None, None
405
+ if len(keypts[1]) == 0:
406
+ keypts[1] = np.zeros((21, 2))
407
+ elif len(keypts[1]) == 21:
408
+ keypts[1] = np.array(keypts[1], dtype=np.float32)
409
+ else:
410
+ gr.Info("Number of left hand keypoints should be either 0 or 21.")
411
+ return None, None, None, None
412
+ keypts = np.concatenate(keypts, axis=0)
413
+ target_pose = visualize_hand(keypts, img)
414
+ kpts_valid = check_keypoints_validity(keypts, opts.image_size)
415
+ target_heatmaps = torch.tensor(
416
+ keypoint_heatmap(
417
+ scale_keypoint(keypts, opts.image_size, opts.latent_size),
418
+ opts.latent_size,
419
+ var=1.0,
420
  )
421
+ * kpts_valid[:, None, None],
422
+ dtype=torch.float,
423
+ device=pre_device,
424
+ )[None, ...]
425
+ target_cond = torch.cat(
426
+ [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
427
+ )
428
 
429
+ return img, target_pose, target_cond, keypts
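For context, the target condition built above concatenates per-keypoint Gaussian heatmaps with one extra zero mask channel. The repo's `keypoint_heatmap` is defined elsewhere; the following is only an illustrative minimal version of the idea, not the repo's implementation:

```python
import numpy as np

def gaussian_heatmaps(kpts, size, var=1.0):
    # kpts: (K, 2) array of (x, y) pixel coords;
    # returns (K, size, size) maps peaking at each keypoint.
    ys, xs = np.mgrid[0:size, 0:size]
    d2 = ((xs[None] - kpts[:, 0, None, None]) ** 2
          + (ys[None] - kpts[:, 1, None, None]) ** 2)
    return np.exp(-d2 / (2.0 * var)).astype(np.float32)

maps = gaussian_heatmaps(np.array([[16.0, 8.0]]), 32)
print(maps.shape, maps.max())  # (1, 32, 32) 1.0
```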
430
 
431
 
432
  def get_mask_inpaint(ref):
433
+ # inpaint_mask = np.zeros_like(img_original[:, :, 0])
434
+ # cropped_mask = np.array(ref["layers"][0])[..., -1]
435
+ # inpaint_mask[crop_coord[0][1]:crop_coord[1][1], crop_coord[0][0]:crop_coord[1][0]] = cropped_mask
436
  inpaint_mask = np.array(ref["layers"][0])[..., -1]
437
  inpaint_mask = cv2.resize(
438
  inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA
 
441
  return inpaint_mask
442
 
443
 
444
+ def visualize_ref(brush): # crop,
445
+ if brush is None: # crop is None or
446
  return None
447
  inpainted = brush["layers"][0][..., -1]
448
+ img = brush["background"][..., :3]
449
+ # img = cv2.resize(img, inpainted.shape[::-1], interpolation=cv2.INTER_AREA)
450
  mask = inpainted < 128
451
  # img = img.astype(np.int32)
452
  # img[mask, :] = img[mask, :] - 50
 
507
  keypoints[1] = []
508
  return img, keypoints
509
 
510
+ def stay_crop(img, crop_coord):
511
+ if img is not None:
512
+ crop_coord = [[0, 0], [img.shape[1], img.shape[0]]]
513
+ cropped = img.copy()
514
+ return crop_coord, cropped
515
+ else:
516
+ return None, None
517
+
518
+ def process_crop(img, crop_coord, evt: gr.SelectData):
519
+ if len(crop_coord) == 2:
520
+ crop_coord = [list(evt.index)]
521
+ cropped = img.copy()
522
+ elif len(crop_coord) == 1:
523
+ new_coord = list(evt.index)
524
+ if new_coord[0] <= crop_coord[0][0] or new_coord[1] <= crop_coord[0][1]:
525
+ gr.Warning("The second click should be below and to the right of the first click. Try the second click again.", duration=3)
526
+ cropped = img.copy()
527
+ else:
528
+ crop_coord.append(new_coord)
529
+ x1, y1 = crop_coord[0]
530
+ x2, y2 = crop_coord[1]
531
+ cropped = img.copy()[y1:y2, x1:x2]
532
+ else:
533
+ raise gr.Error("Something is wrong", duration=3)
534
+ return crop_coord, cropped
535
+
536
+ def disable_crop(crop_coord):
537
+ if len(crop_coord) == 2:
538
+ return gr.update(interactive=False)
539
+ else:
540
+ return gr.update(interactive=True)
541
+
542
+ @spaces_60_fn
543
  def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
544
  set_seed(seed)
545
  z = torch.randn(
 
586
  print(f"results[0].max(): {results[0].max()}")
587
  return results, results_pose
588
 
589
+ @spaces_120_fn
590
+ def ready_sample(img_cropped, inpaint_mask, keypts):
591
+ # img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
592
+ img = cv2.resize(img_cropped["background"][..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
593
  sam_predictor.set_image(img)
594
  if len(keypts[0]) == 0:
595
  keypts[0] = np.zeros((21, 2))
596
  elif len(keypts[0]) == 21:
597
  keypts[0] = np.array(keypts[0], dtype=np.float32)
598
+ # keypts[0][:, 0] = keypts[0][:, 0] + crop_coord[0][0]
599
+ # keypts[0][:, 1] = keypts[0][:, 1] + crop_coord[0][1]
600
  else:
601
  gr.Info("Number of right hand keypoints should be either 0 or 21.")
602
  return None, None
 
605
  keypts[1] = np.zeros((21, 2))
606
  elif len(keypts[1]) == 21:
607
  keypts[1] = np.array(keypts[1], dtype=np.float32)
608
+ # keypts[1][:, 0] = keypts[1][:, 0] + crop_coord[0][0]
609
+ # keypts[1][:, 1] = keypts[1][:, 1] + crop_coord[0][1]
610
  else:
611
  gr.Info("Number of left hand keypoints should be either 0 or 21.")
612
  return None, None
613
 
614
  keypts = np.concatenate(keypts, axis=0)
615
+ keypts = scale_keypoint(keypts, (img_cropped["background"].shape[1], img_cropped["background"].shape[0]), opts.image_size)
616
 
617
  box_shift_ratio = 0.5
618
  box_size_factor = 1.2
 
648
  inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
649
  ),
650
  dtype=torch.float,
651
+ device=pre_device,
652
  ).unsqueeze(0)[None, ...]
653
 
654
  def make_ref_cond(
 
666
  Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
667
  ]
668
  )
669
+ image = image_transform(img).to(device)
670
  kpts_valid = check_keypoints_validity(keypts, target_size)
671
  heatmaps = torch.tensor(
672
  keypoint_heatmap(
 
674
  )
675
  * kpts_valid[:, None, None],
676
  dtype=torch.float,
677
+ device=device,
678
  )[None, ...]
679
  mask = torch.tensor(
680
  cv2.resize(
 
683
  interpolation=cv2.INTER_NEAREST,
684
  ),
685
  dtype=torch.float,
686
+ device=device,
687
  ).unsqueeze(0)[None, ...]
688
  return image[None, ...], heatmaps, mask
689
 
 
691
  img,
692
  keypts,
693
  hand_mask * (1 - inpaint_mask),
694
+ device=pre_device,
695
  target_size=opts.image_size,
696
  latent_size=opts.latent_size,
697
  )
 
731
  out = (gr.update(visible=True), gr.update(visible=False))
732
  return out
733
 
734
+ @spaces_300_fn
735
  def sample_inpaint(
736
  ref_cond,
737
  target_cond,
738
  latent,
739
  inpaint_latent_mask,
740
  keypts,
741
+ img_original,
742
+ crop_coord,
743
  num_gen,
744
  seed,
745
  cfg,
 
785
  # visualize
786
  results = []
787
  results_pose = []
788
+ results_original = []
789
  for i in range(FIX_MAX_N):
790
  if i < num_gen:
791
+ res = sampled_images[i]
792
+ results.append(res)
793
+ results_pose.append(visualize_hand(keypts, res))
794
+ res = cv2.resize(res, (crop_coord[1][0]-crop_coord[0][0], crop_coord[1][1]-crop_coord[0][1]))
795
+ res_original = img_original.copy()
796
+ res_original[crop_coord[0][1]:crop_coord[1][1], crop_coord[0][0]:crop_coord[1][0], :] = res
797
+ results_original.append(res_original)
798
  else:
799
  results.append(placeholder)
800
  results_pose.append(placeholder)
801
+ results_original.append(placeholder)
802
+ return results, results_pose, results_original
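This loop implements the commit's "put fixed hand to original image": each generated crop is resized back to the crop-box size and written into a copy of the uncropped original. The core operation in isolation (a sketch; array sizes are illustrative):

```python
import cv2
import numpy as np

def paste_back(original, patch, crop_coord):
    # crop_coord = [[x1, y1], [x2, y2]] in original-image pixels
    (x1, y1), (x2, y2) = crop_coord
    patch = cv2.resize(patch, (x2 - x1, y2 - y1))  # cv2 dsize is (w, h)
    out = original.copy()
    out[y1:y2, x1:x2] = patch
    return out

original = np.zeros((100, 100, 3), dtype=np.uint8)
patch = np.full((256, 256, 3), 255, dtype=np.uint8)
print(paste_back(original, patch, [[10, 20], [60, 80]]).sum() > 0)  # True
```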
803
 
804
 
805
  def flip_hand(
806
+ img, img_raw, pose_img, pose_manual_img,
807
+ manual_kp_right, manual_kp_left,
808
+ cond, auto_cond, manual_cond,
809
+ keypts=None, auto_keypts=None, manual_keypts=None
810
  ):
811
  if cond is None: # clear clicked
812
+ return
813
  img["composite"] = img["composite"][:, ::-1, :]
814
  img["background"] = img["background"][:, ::-1, :]
815
  img["layers"] = [layer[:, ::-1, :] for layer in img["layers"]]
816
+ # for comp in [pose_img, pose_manual_img, manual_kp_right, manual_kp_left, cond, auto_cond, manual_cond]:
817
+ # if comp is not None:
818
+ # if isinstance(comp, torch.Tensor):
819
+ # comp = comp.flip(-1)
820
+ # else:
821
+ # comp = comp[:, ::-1, :]
822
+ if img_raw is not None:
823
+ img_raw = img_raw[:, ::-1, :]
824
  pose_img = pose_img[:, ::-1, :]
825
+ if pose_manual_img is not None:
826
+ pose_manual_img = pose_manual_img[:, ::-1, :]
827
+ if manual_kp_right is not None:
828
+ manual_kp_right = manual_kp_right[:, ::-1, :]
829
+ if manual_kp_left is not None:
830
+ manual_kp_left = manual_kp_left[:, ::-1, :]
831
  cond = cond.flip(-1)
832
+ if auto_cond is not None:
833
+ auto_cond = auto_cond.flip(-1)
834
+ if manual_cond is not None:
835
+ manual_cond = manual_cond.flip(-1)
836
+ # for comp in [keypts, auto_keypts, manual_keypts]:
837
+ # if comp is not None:
838
+ # if comp[:21, :].sum() != 0:
839
+ # comp[:21, 0] = opts.image_size[1] - comp[:21, 0]
840
+ # if comp[21:, :].sum() != 0:
841
+ # comp[21:, 0] = opts.image_size[1] - comp[21:, 0]
842
+ if keypts is not None:
843
  if keypts[:21, :].sum() != 0:
844
  keypts[:21, 0] = opts.image_size[1] - keypts[:21, 0]
 
845
  if keypts[21:, :].sum() != 0:
846
  keypts[21:, 0] = opts.image_size[1] - keypts[21:, 0]
847
+ if auto_keypts is not None:
848
+ if auto_keypts[:21, :].sum() != 0:
849
+ auto_keypts[:21, 0] = opts.image_size[1] - auto_keypts[:21, 0]
850
+ if auto_keypts[21:, :].sum() != 0:
851
+ auto_keypts[21:, 0] = opts.image_size[1] - auto_keypts[21:, 0]
852
+ if manual_keypts is not None:
853
+ if manual_keypts[:21, :].sum() != 0:
854
+ manual_keypts[:21, 0] = opts.image_size[1] - manual_keypts[:21, 0]
855
+ if manual_keypts[21:, :].sum() != 0:
856
+ manual_keypts[21:, 0] = opts.image_size[1] - manual_keypts[21:, 0]
857
+ return img, img_raw, pose_img, pose_manual_img, manual_kp_right, manual_kp_left, cond, auto_cond, manual_cond, keypts, auto_keypts, manual_keypts
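Flipping handedness mirrors every image buffer with `[:, ::-1, :]` and every condition tensor with `.flip(-1)`; keypoints only need their x-coordinate reflected as `x' = W - x`, and only for hands that are actually annotated (an all-zero block means "absent"). The keypoint part as a compact sketch:

```python
import numpy as np

def flip_keypoints_x(keypts, width):
    # keypts: (42, 2); rows 0-20 are the right hand, 21-41 the left.
    # An all-zero block marks a missing hand and must stay zero.
    out = keypts.copy()
    for block in (slice(0, 21), slice(21, 42)):
        if out[block].sum() != 0:
            out[block, 0] = width - out[block, 0]
    return out
```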
858
 
859
 
860
  def resize_to_full(img):
 
867
  def clear_all():
868
  return (
869
  None,
870
+ [],
871
  None,
872
  None,
873
  None,
874
  None,
 
875
  None,
876
  None,
877
  False,
878
  None,
879
  None,
880
+ [],
881
+ None,
882
+ None,
883
+ None,
884
  None,
885
  None,
886
  None,
887
+ False,
888
  None,
889
  None,
890
  1,
891
  42,
892
  3.0,
893
  gr.update(interactive=False),
 
894
  )
895
 
896
 
 
899
  None,
900
  None,
901
  None,
902
+ [],
903
+ None,
904
+ None,
905
  None,
906
  None,
907
  None,
 
927
  def enable_component(image1, image2):
928
  if image1 is None or image2 is None:
929
  return gr.update(interactive=False)
930
+ if isinstance(image1, dict) and "background" in image1 and "layers" in image1 and "composite" in image1:
931
  if (
932
  image1["background"].sum() == 0
933
  and (sum([im.sum() for im in image1["layers"]]) == 0)
934
  and image1["composite"].sum() == 0
935
  ):
936
  return gr.update(interactive=False)
937
+ if isinstance(image1, dict) and "background" in image2 and "layers" in image2 and "composite" in image2:
938
  if (
939
  image2["background"].sum() == 0
940
  and (sum([im.sum() for im in image2["layers"]]) == 0)
 
991
 
992
  def set_unvisible():
993
  return (
994
+ gr.update(visible=False),
995
+ gr.update(visible=False),
996
+ gr.update(visible=False),
997
+ gr.update(visible=False),
998
+ gr.update(visible=False),
999
+ gr.update(visible=False),
1000
+ gr.update(visible=False),
1001
+ gr.update(visible=False),
1002
+ gr.update(visible=False),
1003
+ gr.update(visible=False),
1004
+ gr.update(visible=False),
1005
+ gr.update(visible=False),
1006
  gr.update(visible=False),
1007
  gr.update(visible=False),
1008
  gr.update(visible=False),
 
1017
  gr.update(visible=False)
1018
  )
1019
 
1020
+ def fix_set_unvisible():
1021
+ return (
1022
+ gr.update(visible=False),
1023
+ gr.update(visible=False),
1024
+ gr.update(visible=False),
1025
+ gr.update(visible=False),
1026
+ gr.update(visible=False),
1027
+ gr.update(visible=False),
1028
+ gr.update(visible=False),
1029
+ gr.update(visible=False)
1030
+ )
1031
+
1032
  def set_no_hands(decider, component):
1033
  if decider is None:
1034
  no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
 
1050
  update_component = gr.update(visible=True)
1051
  return update_component
1052
1053
  LENGTH = 480
1054
 
1055
  example_ref_imgs = [
 
1145
  # ["bad_hands/4.jpg"], # "bad_hands/4_mask.jpg"],
1146
  ["bad_hands/5.jpg"], # "bad_hands/5_mask.jpg"],
1147
  ["bad_hands/6.jpg"], # "bad_hands/6_mask.jpg"],
1148
+ # ["bad_hands/7.jpg"], # "bad_hands/7_mask.jpg"],
1149
  # ["bad_hands/8.jpg"], # "bad_hands/8_mask.jpg"],
1150
  # ["bad_hands/9.jpg"], # "bad_hands/9_mask.jpg"],
1151
  # ["bad_hands/10.jpg"], # "bad_hands/10_mask.jpg"],
 
1199
  with gr.Blocks(css=custom_css, theme="soft") as demo:
1200
  gr.Markdown(_HEADER_)
1201
  with gr.Tab("Edit Hand Poses"):
1202
+ dump = gr.State(value=None)
1203
+
1204
+ # ref states
1205
  ref_img = gr.State(value=None)
1206
  ref_im_raw = gr.State(value=None)
1207
  ref_kp_raw = gr.State(value=0)
1208
  ref_kp_got = gr.State(value=None)
 
 
1209
  ref_manual_cond = gr.State(value=None)
1210
  ref_auto_cond = gr.State(value=None)
1211
+ ref_cond = gr.State(value=None)
1212
+
1213
+ # target states
1214
  target_img = gr.State(value=None)
1215
+ target_im_raw = gr.State(value=None)
1216
+ target_kp_raw = gr.State(value=0)
1217
+ target_kp_got = gr.State(value=None)
1218
+ target_manual_keypts = gr.State(value=None)
1219
+ target_auto_keypts = gr.State(value=None)
1220
  target_keypts = gr.State(value=None)
1221
+ target_manual_cond = gr.State(value=None)
1222
+ target_auto_cond = gr.State(value=None)
1223
+ target_cond = gr.State(value=None)
1224
+
1225
+ # main tab
1226
  with gr.Row():
1227
+ # ref column
1228
  with gr.Column():
1229
  gr.Markdown(
1230
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">1. Upload a hand image to edit 📥</p>"""
 
1344
  ref_flip = gr.Checkbox(
1345
  value=False, label="Flip Handedness (Reference)", interactive=False
1346
  )
1347
+
1348
+ # target column
1349
  with gr.Column():
1350
  gr.Markdown(
1351
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">2. Upload a hand image for target hand pose 📥</p>"""
 
1370
  target_finish_crop = gr.Button(
1371
  value="Finish Cropping", interactive=False
1372
  )
1373
+ with gr.Tab("Automatic hand keypoints"):
1374
+ target_pose = gr.Image(
1375
+ type="numpy",
1376
+ label="Target Pose",
1377
+ show_label=True,
1378
+ height=LENGTH,
1379
+ width=LENGTH,
1380
+ interactive=False,
1381
+ )
1382
+ target_use_auto = gr.Button(value="Click here to use automatic, not manual", interactive=False, visible=True)
1383
+ with gr.Tab("Manual hand keypoints"):
1384
+ target_manual_checkbox_info = gr.Markdown(
1385
+ """<p style="text-align: center;"><b>Step 1.</b> Tell us if this is right, left, or both hands.</p>""",
1386
+ visible=True,
1387
+ )
1388
+ target_manual_checkbox = gr.CheckboxGroup(
1389
+ ["Right hand", "Left hand"],
1390
+ show_label=False,
1391
+ visible=True,
1392
+ interactive=True,
1393
+ )
1394
+ target_manual_kp_r_info = gr.Markdown(
1395
+ """<p style="text-align: center;"><b>Step 2.</b> Click on image to provide hand keypoints for <b>right</b> hand. See \"OpenPose Keypoint Convention\" for guidance.</p>""",
1396
+ visible=False,
1397
+ )
1398
+ target_manual_kp_right = gr.Image(
1399
+ type="numpy",
1400
+ label="Keypoint Selection (right hand)",
1401
+ show_label=True,
1402
+ height=LENGTH,
1403
+ width=LENGTH,
1404
+ interactive=False,
1405
+ visible=False,
1406
+ sources=[],
1407
+ )
1408
+ with gr.Row():
1409
+ target_manual_undo_right = gr.Button(
1410
+ value="Undo", interactive=True, visible=False
1411
+ )
1412
+ target_manual_reset_right = gr.Button(
1413
+ value="Reset", interactive=True, visible=False
1414
+ )
1415
+ target_manual_kp_l_info = gr.Markdown(
1416
+ """<p style="text-align: center;"><b>Step 2.</b> Click on image to provide hand keypoints for <b>left</b> hand. See \"OpenPose keypoint convention\" for guidance.</p>""",
1417
+ visible=False
1418
+ )
1419
+ target_manual_kp_left = gr.Image(
1420
+ type="numpy",
1421
+ label="Keypoint Selection (left hand)",
1422
+ show_label=True,
1423
+ height=LENGTH,
1424
+ width=LENGTH,
1425
+ interactive=False,
1426
+ visible=False,
1427
+ sources=[],
1428
+ )
1429
+ with gr.Row():
1430
+ target_manual_undo_left = gr.Button(
1431
+ value="Undo", interactive=True, visible=False
1432
+ )
1433
+ target_manual_reset_left = gr.Button(
1434
+ value="Reset", interactive=True, visible=False
1435
+ )
1436
+ target_manual_done_info = gr.Markdown(
1437
+ """<p style="text-align: center;"><b>Step 3.</b> Hit \"Done\" button to confirm.</p>""",
1438
+ visible=False,
1439
+ )
1440
+ target_manual_done = gr.Button(value="Done", interactive=True, visible=False)
1441
+ target_manual_pose = gr.Image(
1442
+ type="numpy",
1443
+ label="Target Pose",
1444
+ show_label=True,
1445
+ height=LENGTH,
1446
+ width=LENGTH,
1447
+ interactive=False,
1448
+ visible=False
1449
+ )
1450
+ target_use_manual = gr.Button(value="Click here to use manual, not automatic", interactive=True, visible=False)
1451
+ target_manual_instruct = gr.Markdown(
1452
+ value="""<p style="text-align: left; font-weight: bold; ">OpenPose Keypoints Convention</p>""",
1453
+ visible=True
1454
+ )
1455
+ target_manual_openpose = gr.Image(
1456
+ value="openpose.png",
1457
+ type="numpy",
1458
+ show_label=False,
1459
+ height=LENGTH // 2,
1460
+ width=LENGTH // 2,
1461
+ interactive=False,
1462
+ visible=True
1463
+ )
1464
  gr.Markdown(
1465
  """<p style="text-align: center;">&#9314; Optionally flip the hand</p>"""
1466
  )
1467
  target_flip = gr.Checkbox(
1468
  value=False, label="Flip Handedness (Target)", interactive=False
1469
  )
1470
+
1471
+ # result column
1472
  with gr.Column():
1473
  gr.Markdown(
1474
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">3. Press &quot;Run&quot; to get the edited results 🎯</p>"""
 
1532
  interactive=True,
1533
  )
1534
 
1535
+ # reference listeners
1536
  ref.change(enable_component, [ref, ref], ref_finish_crop)
1537
+ ref_finish_crop.click(prepare_anno, [ref], [ref_im_raw, ref_kp_raw])
1538
  ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_right)
1539
  ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_left)
1540
+ ref_kp_raw.change(get_ref_anno, [ref_im_raw, ref_kp_raw], [ref_img, ref_pose, ref_auto_cond])
1541
+ ref_pose.change(enable_component, [ref_kp_raw, ref_pose], ref_use_auto)
1542
+ ref_pose.change(enable_component, [ref_img, ref_pose], ref_flip)
1543
+ ref_auto_cond.change(lambda x: x, ref_auto_cond, ref_cond)
1544
+ ref_use_auto.click(lambda x: x, ref_auto_cond, ref_cond)
1545
+ ref_use_auto.click(lambda: gr.Info("Automatic hand keypoints will be used for 'Reference'", duration=3))
1546
+
1547
  ref_manual_checkbox.select(
1548
  set_visible,
1549
  [ref_manual_checkbox, ref_kp_got, ref_im_raw, ref_manual_kp_right, ref_manual_kp_left, ref_manual_done],
 
1581
  ref_manual_reset_left.click(
1582
  reset_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
1583
  )
1584
+ ref_manual_done.click(visible_component, [gr.State(0), ref_manual_pose], ref_manual_pose)
1585
+ ref_manual_done.click(visible_component, [gr.State(0), ref_use_manual], ref_use_manual)
1586
  ref_manual_done.click(get_ref_anno, [ref_im_raw, ref_kp_got], [ref_img, ref_manual_pose, ref_manual_cond])
1587
  ref_manual_pose.change(enable_component, [ref_manual_pose, ref_manual_pose], ref_manual_done)
1588
  ref_manual_pose.change(enable_component, [ref_img, ref_manual_pose], ref_flip)
1589
+ ref_manual_cond.change(lambda x: x, ref_manual_cond, ref_cond)
1590
+ ref_use_manual.click(lambda x: x, ref_manual_cond, ref_cond)
1591
+ ref_use_manual.click(lambda: gr.Info("Manual hand keypoints will be used for 'Reference'", duration=3))
1592
+
1593
  ref_flip.select(
1594
+ flip_hand,
1595
+ [ref, ref_im_raw, ref_pose, ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left, ref_cond, ref_auto_cond, ref_manual_cond],
1596
+ [ref, ref_im_raw, ref_pose, ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left, ref_cond, ref_auto_cond, ref_manual_cond]
1597
  )
1598
+
1599
+ # target listeners
1600
  target.change(enable_component, [target, target], target_finish_crop)
1601
+ target_finish_crop.click(prepare_anno, [target], [target_im_raw, target_kp_raw])
1602
+ target_kp_raw.change(lambda x: x, target_im_raw, target_manual_kp_right)
1603
+ target_kp_raw.change(lambda x: x, target_im_raw, target_manual_kp_left)
1604
+ target_kp_raw.change(get_target_anno, [target_im_raw, target_kp_raw], [target_img, target_pose, target_auto_cond, target_auto_keypts])
1605
+ target_pose.change(enable_component, [target_kp_raw, target_pose], target_use_auto)
1606
  target_pose.change(enable_component, [target_img, target_pose], target_flip)
1607
+ target_auto_cond.change(lambda x: x, target_auto_cond, target_cond)
1608
+ target_auto_keypts.change(lambda x: x, target_auto_keypts, target_keypts)
1609
+ target_use_auto.click(lambda x: x, target_auto_cond, target_cond)
1610
+ target_use_auto.click(lambda x: x, target_auto_keypts, target_keypts)
1611
+ target_use_auto.click(lambda: gr.Info("Automatic hand keypoints will be used for 'Target'", duration=3))
1612
+
1613
+ target_manual_checkbox.select(
1614
+ set_visible,
1615
+ [target_manual_checkbox, target_kp_got, target_im_raw, target_manual_kp_right, target_manual_kp_left, target_manual_done],
1616
+ [
1617
+ target_kp_got,
1618
+ target_manual_kp_right,
1619
+ target_manual_kp_left,
1620
+ target_manual_kp_right,
1621
+ target_manual_undo_right,
1622
+ target_manual_reset_right,
1623
+ target_manual_kp_left,
1624
+ target_manual_undo_left,
1625
+ target_manual_reset_left,
1626
+ target_manual_kp_r_info,
1627
+ target_manual_kp_l_info,
1628
+ target_manual_done,
1629
+ target_manual_done_info
1630
+ ]
1631
+ )
1632
+ target_manual_kp_right.select(
1633
+ get_kps, [target_im_raw, target_kp_got, gr.State("right")], [target_manual_kp_right, target_kp_got]
1634
+ )
1635
+ target_manual_undo_right.click(
1636
+ undo_kps, [target_im_raw, target_kp_got, gr.State("right")], [target_manual_kp_right, target_kp_got]
1637
+ )
1638
+ target_manual_reset_right.click(
1639
+ reset_kps, [target_im_raw, target_kp_got, gr.State("right")], [target_manual_kp_right, target_kp_got]
1640
+ )
1641
+ target_manual_kp_left.select(
1642
+ get_kps, [target_im_raw, target_kp_got, gr.State("left")], [target_manual_kp_left, target_kp_got]
1643
+ )
1644
+ target_manual_undo_left.click(
1645
+ undo_kps, [target_im_raw, target_kp_got, gr.State("left")], [target_manual_kp_left, target_kp_got]
1646
+ )
1647
+ target_manual_reset_left.click(
1648
+ reset_kps, [target_im_raw, target_kp_got, gr.State("left")], [target_manual_kp_left, target_kp_got]
1649
+ )
1650
+ target_manual_done.click(visible_component, [gr.State(0), target_manual_pose], target_manual_pose)
1651
+ target_manual_done.click(visible_component, [gr.State(0), target_use_manual], target_use_manual)
1652
+ target_manual_done.click(get_target_anno, [target_im_raw, target_kp_got], [target_img, target_manual_pose, target_manual_cond, target_manual_keypts])
1653
+ target_manual_pose.change(enable_component, [target_manual_pose, target_manual_pose], target_manual_done)
1654
+ target_manual_pose.change(enable_component, [target_img, target_manual_pose], target_flip)
1655
+ target_manual_cond.change(lambda x: x, target_manual_cond, target_cond)
1656
+ target_manual_keypts.change(lambda x: x, target_manual_keypts, target_keypts)
1657
+ target_use_manual.click(lambda x: x, target_manual_cond, target_cond)
1658
+ target_use_manual.click(lambda x: x, target_manual_keypts, target_keypts)
1659
+ target_use_manual.click(lambda: gr.Info("Manual hand keypoints will be used for 'Target'", duration=3))
1660
+
1661
  target_flip.select(
1662
  flip_hand,
1663
+ [target, target_im_raw, target_pose, target_manual_pose, target_manual_kp_right, target_manual_kp_left, target_cond, target_auto_cond, target_manual_cond, target_keypts, target_auto_keypts, target_manual_keypts],
1664
+ [target, target_im_raw, target_pose, target_manual_pose, target_manual_kp_right, target_manual_kp_left, target_cond, target_auto_cond, target_manual_cond, target_keypts, target_auto_keypts, target_manual_keypts],
1665
  )
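`flip_hand` takes and returns the same component list: it mirrors the image, both pose visualizations, and every keypoint/condition state horizontally so a left hand can be processed with the right-hand convention (and vice versa). The core transform is an image flip plus an x-coordinate reflection; a generic sketch, with shapes assumed rather than taken from the app:

```python
import numpy as np

def flip_horizontal(img: np.ndarray, kpts: np.ndarray):
    # img: (H, W, 3) uint8 image; kpts: (N, 2) array of (x, y) pixel coords.
    flipped_img = np.fliplr(img).copy()
    flipped_kpts = kpts.astype(float).copy()
    flipped_kpts[:, 0] = img.shape[1] - 1 - flipped_kpts[:, 0]  # reflect x
    return flipped_img, flipped_kpts
```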
1666
+
1667
+ # run listeners
1668
+ ref_cond.change(enable_component, [ref_cond, target_cond], run)
1669
+ target_cond.change(enable_component, [ref_cond, target_cond], run)
1670
+ # ref_manual_pose.change(enable_component, [ref_manual_pose, target_manual_pose], run)
1671
+ # target_manual_pose.change(enable_component, [ref_manual_pose, target_manual_pose], run)
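`enable_component` is defined earlier in app.py; these call sites pass two inputs and one output component, which suggests a gate that only unlocks the Run button once both the reference and target conditions exist. One plausible shape, not the app's actual definition:

```python
import gradio as gr

def enable_component(a, b):
    """Enable the bound output component only when both inputs hold values."""
    if a is not None and b is not None:
        return gr.update(interactive=True)
    return gr.update(interactive=False)
```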
1672
  run.click(
1673
  sample_diff,
1674
  [ref_cond, target_cond, target_keypts, n_generation, seed, cfg],
 
1679
  [],
1680
  [
1681
  ref,
1682
+ ref_manual_checkbox,
1683
  ref_manual_kp_right,
1684
  ref_manual_kp_left,
1685
+ ref_img,
1686
  ref_pose,
1687
  ref_manual_pose,
1688
+ ref_cond,
1689
  ref_flip,
1690
  target,
1691
+ target_keypts,
1692
+ target_manual_checkbox,
1693
+ target_manual_kp_right,
1694
+ target_manual_kp_left,
1695
+ target_img,
1696
  target_pose,
1697
+ target_manual_pose,
1698
+ target_cond,
1699
  target_flip,
1700
  results,
1701
  results_pose,
1702
  n_generation,
1703
  seed,
1704
  cfg,
1705
  ref_kp_raw,
 
1706
  ],
1707
  )
1708
  clear.click(
1709
  set_unvisible,
1710
  [],
1711
  [
 
1712
  ref_manual_kp_l_info,
1713
+ ref_manual_kp_r_info,
1714
+ ref_manual_kp_left,
1715
+ ref_manual_kp_right,
1716
  ref_manual_undo_left,
1717
  ref_manual_undo_right,
1718
  ref_manual_reset_left,
 
1721
  ref_manual_done_info,
1722
  ref_manual_pose,
1723
  ref_use_manual,
1724
+ target_manual_kp_l_info,
1725
+ target_manual_kp_r_info,
1726
+ target_manual_kp_left,
1727
+ target_manual_kp_right,
1728
+ target_manual_undo_left,
1729
+ target_manual_undo_right,
1730
+ target_manual_reset_left,
1731
+ target_manual_reset_right,
1732
+ target_manual_done,
1733
+ target_manual_done_info,
1734
+ target_manual_pose,
1735
+ target_use_manual,
1736
  ]
1737
  )
1738
 
1739
  with gr.Tab("Fix Hands"):
1740
  fix_inpaint_mask = gr.State(value=None)
1741
  fix_original = gr.State(value=None)
1742
+ fix_crop_coord = gr.State(value=None)
1743
  fix_img = gr.State(value=None)
1744
  fix_kpts = gr.State(value=None)
1745
  fix_kpts_np = gr.State(value=None)
 
1748
  fix_latent = gr.State(value=None)
1749
  fix_inpaint_latent = gr.State(value=None)
1750
  with gr.Row():
1751
+ # crop & brush
1752
  with gr.Column():
1753
  gr.Markdown(
1754
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">1. Upload a malformed hand image to fix 📥</p>"""
1755
  )
1756
  gr.Markdown(
1757
+ """<p style="text-align: center;">&#9312; Optionally crop the image by clicking <b>top left</b> and <b>bottom right</b> of your desired bounding box around the hand. </p>"""
1758
  )
1759
+ # fix_crop = gr.ImageEditor(
1760
+ # type="numpy",
1761
+ # sources=["upload", "webcam", "clipboard"],
1762
+ # label="Image crop",
1763
+ # show_label=True,
1764
+ # height=LENGTH,
1765
+ # width=LENGTH,
1766
+ # layers=False,
1767
+ # # crop_size="1:1",
1768
+ # transforms=(),
1769
+ # brush=False,
1770
+ # image_mode="RGBA",
1771
+ # container=False,
1772
+ # )
1773
+ fix_crop = gr.Image(
1774
  type="numpy",
1775
  sources=["upload", "webcam", "clipboard"],
1776
+ label="Input Image",
1777
  show_label=True,
1778
  height=LENGTH,
1779
  width=LENGTH,
1780
+ interactive=True,
1781
+ visible=True,
1782
+ )
1783
+ gr.Markdown(
1784
+ """<p style="text-align: center;">💡 If you crop, the model can focus on more details of the cropped area. Square crops might work better than rectangle crops.</p>"""
1785
  )
1786
+ # fix_tmp = gr.Image(
1787
+ # type="numpy",
1788
+ # label="tmp",
1789
+ # show_label=True,
1790
+ # height=LENGTH,
1791
+ # width=LENGTH,
1792
+ # interactive=True,
1793
+ # visible=True,
1794
+ # sources=[],
1795
+ # )
1796
  fix_example = gr.Examples(
1797
  fix_example_imgs,
1798
  inputs=[fix_crop],
1799
  examples_per_page=20,
1800
  )
1801
  gr.Markdown(
1802
+ """<p style="text-align: center;">&#9313; Brush area (e.g., wrong finger) that needs to be fixed. Don't brush the entire hand!</p>"""
1803
  )
1804
  fix_ref = gr.ImageEditor(
1805
  type="numpy",
1806
+ label="Image Brushing",
1807
  sources=(),
1808
  show_label=True,
1809
  height=LENGTH,
 
1817
  container=False,
1818
  interactive=False,
1819
  )
1820
+ gr.Markdown(
1821
+ """<p style="text-align: center;">&#9314; Hit the \"Finish Cropping & Brushing\" button</p>"""
1822
+ )
1823
  fix_finish_crop = gr.Button(
1824
  value="Finish Croping & Brushing", interactive=False
1825
  )
1826
+
1827
+ # keypoint selection
1828
  with gr.Column():
1829
  gr.Markdown(
1830
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">2. Click on hand to get target hand pose</p>"""
 
1837
  show_label=False,
1838
  interactive=False,
1839
  )
1840
  fix_kp_r_info = gr.Markdown(
1841
+ """<p style="text-align: center;">&#9313; Click 21 keypoints on the image to provide the target hand pose of <b>right hand</b>. See the \"OpenPose keypoints convention\" for guidance.</p>""",
1842
+ visible=False
1843
  )
1844
+ # fix_kp_r_info = gr.Markdown(
1845
+ # """<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select right only</p>""",
1846
+ # visible=False,
1847
+ # )
1848
  fix_kp_right = gr.Image(
1849
  type="numpy",
1850
  label="Keypoint Selection (right hand)",
 
1863
  value="Reset", interactive=False, visible=False
1864
  )
1865
  fix_kp_l_info = gr.Markdown(
1866
+ """<p style="text-align: center;">&#9313; Click 21 keypoints on the image to provide the target hand pose of <b>left hand</b>. See the \"OpenPose keypoints convention\" for guidance.</p>""",
1867
  visible=False
1868
  )
1869
  fix_kp_left = gr.Image(
 
1894
  width=LENGTH // 2,
1895
  interactive=False,
1896
  )
1897
+
1898
+ # get latent
1899
  with gr.Column():
1900
  gr.Markdown(
1901
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">3. Press &quot;Ready&quot; to start pre-processing</p>"""
1902
  )
1903
  fix_ready = gr.Button(value="Ready", interactive=False)
1904
  gr.Markdown(
1905
+ """<p style="text-align: center; font-weight: bold; ">Visualized (256, 256)-resized, brushed image</p>"""
1906
  )
1907
  fix_vis_mask32 = gr.Image(
1908
  type="numpy",
 
1921
  width=opts.image_size,
1922
  interactive=False,
1923
  )
1924
+ # gr.Markdown(
1925
+ # """<p style="text-align: center;">[NOTE] Above should be inpaint mask that you brushed, NOT the segmentation mask of the entire hand. </p>"""
1926
+ # )
1927
+
1928
+ # result column
1929
  with gr.Column():
1930
  gr.Markdown(
1931
  """<p style="text-align: center; font-size: 20px; font-weight: bold;">4. Press &quot;Run&quot; to get the fixed hand image 🎯</p>"""
 
1934
  gr.Markdown(
1935
  """<p style="text-align: center;">⚠️ >3min and ~24GB per generation</p>"""
1936
  )
1937
+ fix_result_original = gr.Gallery(
1938
+ type="numpy",
1939
+ label="Results on original input",
1940
+ show_label=True,
1941
+ height=LENGTH,
1942
+ min_width=LENGTH,
1943
+ columns=FIX_MAX_N,
1944
+ interactive=False,
1945
+ preview=True,
1946
+ )
1947
  fix_result = gr.Gallery(
1948
  type="numpy",
1949
  label="Results",
 
1969
  )
1970
  fix_clear = gr.ClearButton()
1971
 
1972
+ with gr.Tab("More options"):
1973
+ gr.Markdown(
1974
+ "⚠️ Currently, Number of generation > 1 could lead to out-of-memory"
1975
  )
1976
+ with gr.Row():
1977
+ fix_n_generation = gr.Slider(
1978
+ label="Number of generations",
1979
+ value=1,
1980
+ minimum=1,
1981
+ maximum=FIX_MAX_N,
1982
+ step=1,
1983
+ randomize=False,
1984
+ interactive=True,
1985
+ )
1986
+ fix_seed = gr.Slider(
1987
+ label="Seed",
1988
+ value=42,
1989
+ minimum=0,
1990
+ maximum=10000,
1991
+ step=1,
1992
+ randomize=False,
1993
+ interactive=True,
1994
+ )
1995
+ fix_cfg = gr.Slider(
1996
+ label="Classifier free guidance scale",
1997
+ value=3.0,
1998
+ minimum=0.0,
1999
+ maximum=10.0,
2000
+ step=0.1,
2001
+ randomize=False,
2002
+ interactive=True,
2003
+ )
2004
+ fix_quality = gr.Slider(
2005
+ label="Quality",
2006
+ value=10,
2007
+ minimum=1,
2008
+ maximum=10,
2009
+ step=1,
2010
+ randomize=False,
2011
+ interactive=True,
2012
+ )
2013
+
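Among these options, the classifier-free guidance scale is the one that most directly trades pose fidelity against image naturalness. Assuming the demo follows the standard CFG formulation, the sampler mixes conditional and unconditional predictions like this (a generic sketch, not FoundHand-specific code):

```python
import torch

def cfg_mix(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, cfg: float = 3.0):
    # cfg = 0 ignores the condition entirely; larger values follow
    # the provided hand pose more strictly, at some cost to realism.
    return eps_uncond + cfg * (eps_cond - eps_uncond)
```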
2014
+ # listeners
2015
+ # fix_crop.change(resize_to_full, fix_crop, fix_ref)
2016
+ fix_crop.change(lambda x: x, fix_crop, fix_original) # fix_original: (real_H, real_W, 3)
2017
+ fix_crop.change(stay_crop, [fix_crop, fix_crop_coord], [fix_crop_coord, fix_ref])
2018
+ fix_crop.select(process_crop, [fix_crop, fix_crop_coord], [fix_crop_coord, fix_ref])
2019
+ # fix_ref.change(disable_crop, fix_crop_coord, fix_crop)
2020
+ fix_ref.change(enable_component, [fix_crop, fix_crop], fix_ref)
2021
+ fix_ref.change(enable_component, [fix_crop, fix_crop], fix_finish_crop)
2022
+ fix_finish_crop.click(visualize_ref, [fix_ref], [fix_img])
2023
+ fix_finish_crop.click(get_mask_inpaint, [fix_ref], [fix_inpaint_mask]) # fix_ref: (real_cropped_H, real_cropped_W, 3)
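`stay_crop` and `process_crop` implement the two-click cropping from step &#9312;: `fix_crop.select` fires on each click, the first click records the top-left corner in `fix_crop_coord`, and the second records the bottom-right and produces the cropped image for brushing. A simplified sketch of that click-accumulation logic (state layout assumed, not copied from the app):

```python
import gradio as gr

def process_crop(img, crop_coord, evt: gr.SelectData):
    # evt.index is the (x, y) pixel position the user clicked.
    crop_coord = (crop_coord or []) + [list(evt.index)]
    if len(crop_coord) < 2:
        return crop_coord, img            # waiting for the second corner
    (x1, y1), (x2, y2) = crop_coord[:2]
    return crop_coord, img[y1:y2, x1:x2]  # numpy (H, W, C) slice
```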
2024
  fix_img.change(lambda x: x, [fix_img], [fix_kp_right])
2025
  fix_img.change(lambda x: x, [fix_img], [fix_kp_left])
2026
  fix_inpaint_mask.change(
 
2065
  ],
2066
  )
2067
  fix_kp_right.select(
2068
+ get_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts] # fix_img: (real_cropped_H, real_cropped_W, 3)
2069
  )
2070
  fix_undo_right.click(
2071
  undo_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
 
2087
  )
2088
  fix_ready.click(
2089
  ready_sample,
2090
+ [fix_ref, fix_inpaint_mask, fix_kpts],
2091
  [
2092
  fix_ref_cond,
2093
  fix_target_cond,
 
2106
  fix_latent,
2107
  fix_inpaint_latent,
2108
  fix_kpts_np,
2109
+ fix_original,
2110
+ fix_crop_coord,
2111
  fix_n_generation,
2112
  fix_seed,
2113
  fix_cfg,
2114
  fix_quality,
2115
  ],
2116
+ [fix_result, fix_result_pose, fix_result_original],
2117
  )
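`ready_sample` caches the VQ-VAE latent of the brushed image (`fix_latent`) and a latent-resolution inpaint mask (`fix_inpaint_latent`); during sampling, only the brushed region is regenerated while the rest is re-injected at each denoising step. Assuming the ADM-style diffusion object created earlier (which exposes `q_sample`), a RePaint-style re-injection step looks like:

```python
import torch

def inpaint_step(x_t, known_latent, mask, t, diffusion):
    # x_t: current noisy latent; mask: 1 where the brush says "regenerate".
    # Re-noise the known (unbrushed) latent to timestep t and splice it in.
    noised_known = diffusion.q_sample(known_latent, t)
    return mask * x_t + (1.0 - mask) * noised_known
```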
2118
  fix_clear.click(
2119
  fix_clear_all,
2120
  [],
2121
  [
2122
  fix_crop,
2123
+ fix_crop_coord,
2124
  fix_ref,
2125
+ fix_checkbox,
2126
  fix_kp_right,
2127
  fix_kp_left,
2128
  fix_result,
2129
  fix_result_pose,
2130
+ fix_result_original,
2131
  fix_inpaint_mask,
2132
  fix_original,
2133
  fix_img,
 
2145
  fix_quality,
2146
  ],
2147
  )
2148
+ fix_clear.click(
2149
+ fix_set_unvisible,
2150
+ [],
2151
+ [
2152
+ fix_kp_right,
2153
+ fix_kp_left,
2154
+ fix_kp_r_info,
2155
+ fix_kp_l_info,
2156
+ fix_undo_left,
2157
+ fix_undo_right,
2158
+ fix_reset_left,
2159
+ fix_reset_right
2160
+ ]
2161
+ )
2162
 
2163
  gr.Markdown("<h1>Citation</h1>")
2164
  gr.Markdown(
brown_logo.png ADDED

Git LFS Details

  • SHA256: 654ac3b7a615ed09cfaaf1cb0bc1d8a53051a42598fb3cba3e5620ba255e6a7c
  • Pointer size: 130 Bytes
  • Size of remote file: 35.8 kB
meta_logo.png ADDED

Git LFS Details

  • SHA256: d573af322a5fd721558b0d677dc963213d2696696b23e999179df3144ba6271b
  • Pointer size: 130 Bytes
  • Size of remote file: 21.6 kB
sbatch/sbatch_demo.sh ADDED
@@ -0,0 +1,38 @@
1
+ #!/bin/bash
2
+
3
+ # job name
4
+ #SBATCH -J demo_foundhand
5
+
6
+ # partition
7
+ #SBATCH --partition=ssrinath-gcondo --gres=gpu:1 --gres-flags=enforce-binding
8
+ #SBATCH --account=ssrinath-gcondo
9
+
10
+ # ensures all allocated cores are on the same node
11
+ #SBATCH -N 1
12
+
13
+ # cpu cores
14
+ #SBATCH --ntasks-per-node=4
15
+
16
+ # memory per node
17
+ #SBATCH --mem=32G
18
+
19
+ # runtime
20
+ #SBATCH -t 240:00:00
21
+
22
+ # output
23
+ #SBATCH -o out/demo.out
24
+
25
+ # error
26
+ #SBATCH -e err/demo.err
27
+
28
+ # email notifiaction
29
+ # SBATCH --mail-type=ALL
30
+
31
+ module load miniconda3/23.11.0s
32
+ source /oscar/runtime/software/external/miniconda3/23.11.0/etc/profile.d/conda.sh
33
+ conda activate handdiff
34
+
35
+ cd $HOME/hdd/FoundHand_demo
36
+ echo Directory is `pwd`
37
+
38
+ python -u app.py
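On a Slurm cluster with the referenced `ssrinath-gcondo` partition, this script would be submitted from the `sbatch/` directory with `sbatch sbatch_demo.sh`, since the `-o out/demo.out` and `-e err/demo.err` paths are relative to the submission directory; note that Slurm does not create log directories, so `out/` and `err/` must exist beforehand.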
vqvae.py CHANGED
@@ -20,7 +20,10 @@ from typing import List
20
  import torch
21
  import torch.nn.functional as F
22
  from torch import nn
23
- import spaces
24
 
25
 
26
  class Autoencoder(nn.Module):
 
20
  import torch
21
  import torch.nn.functional as F
22
  from torch import nn
23
+ try:
24
+     import spaces
25
+ except ImportError:
26
+     pass
27
 
28
 
29
  class Autoencoder(nn.Module):