Chaerin5 committed on
Commit
2b83923
·
1 Parent(s): 385c0f2

fix vae nan bug

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -276,7 +276,7 @@ def get_ref_anno(ref):
276
  None,
277
  )
278
  missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
279
-
280
  img = ref["composite"][..., :3]
281
  img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
282
  keypts = np.zeros((42, 2))
@@ -566,7 +566,7 @@ def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
566
  print(f"results[0].max(): {results[0].max()}")
567
  return results, results_pose
568
 
569
- @spaces.GPU(duration=120)
570
  def ready_sample(img_ori, inpaint_mask, keypts):
571
  img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
572
  sam_predictor.set_image(img)
@@ -638,7 +638,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
638
  inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
639
  ),
640
  dtype=torch.float,
641
- device=device,
642
  ).unsqueeze(0)[None, ...]
643
 
644
  def make_ref_cond(
@@ -656,7 +656,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
656
  Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
657
  ]
658
  )
659
- image = image_transform(img).to(device)
660
  kpts_valid = check_keypoints_validity(keypts, target_size)
661
  heatmaps = torch.tensor(
662
  keypoint_heatmap(
@@ -664,7 +664,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
664
  )
665
  * kpts_valid[:, None, None],
666
  dtype=torch.float,
667
- device=device,
668
  )[None, ...]
669
  mask = torch.tensor(
670
  cv2.resize(
@@ -673,7 +673,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
673
  interpolation=cv2.INTER_NEAREST,
674
  ),
675
  dtype=torch.float,
676
- device=device,
677
  ).unsqueeze(0)[None, ...]
678
  return image[None, ...], heatmaps, mask
679
 
@@ -744,7 +744,7 @@ def sample_inpaint(
744
  target_cond_N = target_cond.repeat(N, 1, 1, 1)
745
  ref_cond_N = ref_cond.repeat(N, 1, 1, 1)
746
  # novel view synthesis mode = off
747
- nvs = torch.zeros(N, dtype=torch.int, device=device)
748
  z = torch.cat([z, z], 0)
749
  model_kwargs = dict(
750
  target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
@@ -762,7 +762,7 @@ def sample_inpaint(
762
  clip_denoised=False,
763
  model_kwargs=model_kwargs,
764
  progress=True,
765
- device=device,
766
  jump_length=jump_length,
767
  jump_n_sample=jump_n_sample,
768
  ).chunk(2)
@@ -1078,7 +1078,7 @@ with gr.Blocks(css=custom_css) as demo:
1078
  )
1079
  run = gr.Button(value="Run", interactive=False)
1080
  gr.Markdown(
1081
- """<p style="text-align: center;">~20s per generation. <br>(For example, if you set Number of generations as 2, it would take around 40s)</p>"""
1082
  )
1083
  results = gr.Gallery(
1084
  type="numpy",
 
276
  None,
277
  )
278
  missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
279
+
280
  img = ref["composite"][..., :3]
281
  img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
282
  keypts = np.zeros((42, 2))
 
566
  print(f"results[0].max(): {results[0].max()}")
567
  return results, results_pose
568
 
569
+ # @spaces.GPU(duration=120)
570
  def ready_sample(img_ori, inpaint_mask, keypts):
571
  img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
572
  sam_predictor.set_image(img)
 
638
  inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
639
  ),
640
  dtype=torch.float,
641
+ # device=device,
642
  ).unsqueeze(0)[None, ...]
643
 
644
  def make_ref_cond(
 
656
  Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
657
  ]
658
  )
659
+ image = image_transform(img)
660
  kpts_valid = check_keypoints_validity(keypts, target_size)
661
  heatmaps = torch.tensor(
662
  keypoint_heatmap(
 
664
  )
665
  * kpts_valid[:, None, None],
666
  dtype=torch.float,
667
+ # device=device,
668
  )[None, ...]
669
  mask = torch.tensor(
670
  cv2.resize(
 
673
  interpolation=cv2.INTER_NEAREST,
674
  ),
675
  dtype=torch.float,
676
+ # device=device,
677
  ).unsqueeze(0)[None, ...]
678
  return image[None, ...], heatmaps, mask
679
 
 
744
  target_cond_N = target_cond.repeat(N, 1, 1, 1)
745
  ref_cond_N = ref_cond.repeat(N, 1, 1, 1)
746
  # novel view synthesis mode = off
747
+ nvs = torch.zeros(N, dtype=torch.int, device=z.device)
748
  z = torch.cat([z, z], 0)
749
  model_kwargs = dict(
750
  target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
 
762
  clip_denoised=False,
763
  model_kwargs=model_kwargs,
764
  progress=True,
765
+ device=z.device,
766
  jump_length=jump_length,
767
  jump_n_sample=jump_n_sample,
768
  ).chunk(2)
 
1078
  )
1079
  run = gr.Button(value="Run", interactive=False)
1080
  gr.Markdown(
1081
+ """<p style="text-align: center;">~20s per generation with RTX3090. ~50s with A100. <br>(For example, if you set Number of generations as 2, it would take around 40s)</p>"""
1082
  )
1083
  results = gr.Gallery(
1084
  type="numpy",