Spaces: Running on Zero
fix vae nan bug
app.py CHANGED
@@ -276,7 +276,7 @@ def get_ref_anno(ref):
         None,
     )
     missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
-
+
     img = ref["composite"][..., :3]
     img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
     keypts = np.zeros((42, 2))
@@ -566,7 +566,7 @@ def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
     print(f"results[0].max(): {results[0].max()}")
     return results, results_pose
 
-@spaces.GPU(duration=120)
+# @spaces.GPU(duration=120)
 def ready_sample(img_ori, inpaint_mask, keypts):
     img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
     sam_predictor.set_image(img)
@@ -638,7 +638,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
             inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
         ),
         dtype=torch.float,
-        device=device,
+        # device=device,
     ).unsqueeze(0)[None, ...]
 
     def make_ref_cond(
@@ -656,7 +656,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
                 Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
             ]
         )
-        image = image_transform(img)
+        image = image_transform(img)
         kpts_valid = check_keypoints_validity(keypts, target_size)
         heatmaps = torch.tensor(
             keypoint_heatmap(
@@ -664,7 +664,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
             )
             * kpts_valid[:, None, None],
             dtype=torch.float,
-            device=device,
+            # device=device,
         )[None, ...]
         mask = torch.tensor(
             cv2.resize(
@@ -673,7 +673,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
                 interpolation=cv2.INTER_NEAREST,
             ),
             dtype=torch.float,
-            device=device,
+            # device=device,
         ).unsqueeze(0)[None, ...]
         return image[None, ...], heatmaps, mask
 
@@ -744,7 +744,7 @@ def sample_inpaint(
     target_cond_N = target_cond.repeat(N, 1, 1, 1)
     ref_cond_N = ref_cond.repeat(N, 1, 1, 1)
     # novel view synthesis mode = off
-    nvs = torch.zeros(N, dtype=torch.int, device=device)
+    nvs = torch.zeros(N, dtype=torch.int, device=z.device)
     z = torch.cat([z, z], 0)
     model_kwargs = dict(
         target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
@@ -762,7 +762,7 @@ def sample_inpaint(
         clip_denoised=False,
         model_kwargs=model_kwargs,
         progress=True,
-        device=device,
+        device=z.device,
         jump_length=jump_length,
         jump_n_sample=jump_n_sample,
     ).chunk(2)
@@ -1078,7 +1078,7 @@ with gr.Blocks(css=custom_css) as demo:
             )
             run = gr.Button(value="Run", interactive=False)
             gr.Markdown(
-                """<p style="text-align: center;">~20s per generation. <br>(For example, if you set Number of generations as 2, it would take around 40s)</p>"""
+                """<p style="text-align: center;">~20s per generation with RTX3090. ~50s with A100. <br>(For example, if you set Number of generations as 2, it would take around 40s)</p>"""
            )
             results = gr.Gallery(
                 type="numpy",
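The hunks share one pattern: tensors are no longer pinned to the module-level `device` global at creation time. Where a correctly placed tensor is already at hand, the device is derived from it (`z.device`); elsewhere the `device=` argument is simply commented out so the tensor is created on CPU. The sketch below illustrates that idiom only; the helper names `build_nvs_flag` and `build_mask` are hypothetical stand-ins, not the Space's actual code:

import torch

def build_nvs_flag(z: torch.Tensor, n: int) -> torch.Tensor:
    # Follow the placement of the latents instead of a global `device`,
    # mirroring the change to nvs = torch.zeros(..., device=z.device).
    return torch.zeros(n, dtype=torch.int, device=z.device)

def build_mask(mask_np) -> torch.Tensor:
    # With the device= pin removed, torch.tensor() allocates on CPU;
    # the caller can .to(...) the result once a live device is known.
    return torch.tensor(mask_np, dtype=torch.float).unsqueeze(0)[None, ...]

A likely motivation: on ZeroGPU Spaces a GPU is attached only while a @spaces.GPU-decorated function runs, so a `device` resolved at import time (and, here, a decorator commented out on ready_sample) can leave hard-pinned tensors on a device that is not actually available; deferring placement until a live tensor exists sidesteps that.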