Spaces: Running on Zero
guangkaixu committed
Commit 10e02f0
1 Parent(s): c83d507
upload
Browse files
- app.py +240 -29
- empty_text_embed.npy +0 -3
- genpercept/__init__.py +13 -0
- genpercept/customized_modules/ddim.py +213 -0
- genpercept/genpercept_pipeline.py +519 -0
- genpercept/models/custom_unet.py +427 -0
- genpercept/models/dpt_head.py +593 -0
- {util → genpercept/util}/batchsize.py +26 -4
- genpercept/util/ensemble.py +205 -0
- {util → genpercept/util}/image_util.py +62 -104
- hf_configs/dpt-sd2.1-unet-after-upsample-general/config.json +48 -0
- hf_configs/dpt-sd2.1-unet-after-upsample-general/preprocessor_config.json +27 -0
- hf_configs/scheduler_beta_1.0_1.0/scheduler_config.json +20 -0
- pipeline_genpercept.py +0 -355
- requirements.txt +4 -0
- seg_images/seg_1.jpg +0 -0
- seg_images/seg_2.jpg +0 -0
- seg_images/seg_3.jpg +0 -0
- seg_images/seg_4.jpg +0 -0
- seg_images/seg_5.jpg +0 -0
- util/__init__.py +0 -0
- util/seed_all.py +0 -13
app.py
CHANGED
@@ -34,14 +34,17 @@ from PIL import Image
 
 from gradio_imageslider import ImageSlider
 from gradio_patches.examples import Examples
-from …
+from genpercept.genpercept_pipeline import GenPerceptPipeline
 
 from diffusers import (
     DiffusionPipeline,
-    UNet2DConditionModel,
+    # UNet2DConditionModel,
     AutoencoderKL,
 )
 
+from genpercept.models.custom_unet import CustomUNet2DConditionModel
+from genpercept.customized_modules.ddim import DDIMSchedulerCustomized
+
 warnings.filterwarnings(
     "ignore", message=".*LoginButton created outside of a Blocks context.*"
 )
@@ -194,11 +197,13 @@ def process_matting(
     )
 
 
-def run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting):
+def run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting, pipe_seg, pipe_disparity):
     process_pipe_depth = spaces.GPU(functools.partial(process_depth, pipe_depth))
     process_pipe_normal = spaces.GPU(functools.partial(process_normal, pipe_normal))
     process_pipe_dis = spaces.GPU(functools.partial(process_dis, pipe_dis))
     process_pipe_matting = spaces.GPU(functools.partial(process_matting, pipe_matting))
+    process_pipe_seg = spaces.GPU(functools.partial(process_matting, pipe_seg))
+    process_pipe_disparity = spaces.GPU(functools.partial(process_matting, pipe_disparity))
     gradio_theme = gr.themes.Default()
 
     with gr.Blocks(
@@ -485,7 +490,7 @@ def run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting):
            example_folder = os.path.join(os.path.dirname(__file__), "matting_images")
            # print(example_folder)
            Examples(
-               fn=…
+               fn=process_pipe_matting,
                examples=[
                    os.path.join(example_folder, name)
                    for name in filenames
@@ -496,6 +501,120 @@ def run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting):
                directory_name="images_cache",
                cache_examples=False,
            )
+
+        with gr.Tab("Seg"):
+            with gr.Row():
+                with gr.Column():
+                    seg_image_input = gr.Image(
+                        label="Input Image",
+                        type="filepath",
+                        # type="pil",
+                    )
+                    with gr.Row():
+                        seg_image_submit_btn = gr.Button(
+                            value="Estimate Segmentation", variant="primary"
+                        )
+                        seg_image_reset_btn = gr.Button(value="Reset")
+                    with gr.Accordion("Advanced options", open=False):
+                        image_processing_res = gr.Radio(
+                            [
+                                ("Native", 0),
+                                ("Recommended", 768),
+                            ],
+                            label="Processing resolution",
+                            value=default_image_processing_res,
+                        )
+                with gr.Column():
+                    seg_image_output_slider = ImageSlider(
+                        label="Predicted segmentation results",
+                        type="filepath",
+                        show_download_button=True,
+                        show_share_button=True,
+                        interactive=False,
+                        elem_classes="slider",
+                        position=0.25,
+                    )
+                    seg_image_output_files = gr.Files(
+                        label="Seg outputs",
+                        elem_id="download",
+                        interactive=False,
+                    )
+
+            filenames = []
+            filenames.extend(["seg_anime_%d.jpg" %(i+1) for i in range(7)])
+            filenames.extend(["seg_line_%d.jpg" %(i+1) for i in range(6)])
+            filenames.extend(["seg_real_%d.jpg" %(i+1) for i in range(24)])
+
+            example_folder = os.path.join(os.path.dirname(__file__), "seg_images")
+            Examples(
+                fn=process_pipe_seg,
+                examples=[
+                    os.path.join(example_folder, name)
+                    for name in filenames
+                ],
+                inputs=[seg_image_input],
+                outputs=[seg_image_output_slider, seg_image_output_files],
+                cache_examples=False,
+                # directory_name="examples_depth",
+                # cache_examples=False,
+            )
+
+        with gr.Tab("Disparity"):
+            with gr.Row():
+                with gr.Column():
+                    disparity_image_input = gr.Image(
+                        label="Input Image",
+                        type="filepath",
+                        # type="pil",
+                    )
+                    with gr.Row():
+                        disparity_image_submit_btn = gr.Button(
+                            value="Estimate Disparity", variant="primary"
+                        )
+                        disparity_image_reset_btn = gr.Button(value="Reset")
+                    with gr.Accordion("Advanced options", open=False):
+                        image_processing_res = gr.Radio(
+                            [
+                                ("Native", 0),
+                                ("Recommended", 768),
+                            ],
+                            label="Processing resolution",
+                            value=default_image_processing_res,
+                        )
+                with gr.Column():
+                    disparity_image_output_slider = ImageSlider(
+                        label="Predicted disparity results",
+                        type="filepath",
+                        show_download_button=True,
+                        show_share_button=True,
+                        interactive=False,
+                        elem_classes="slider",
+                        position=0.25,
+                    )
+                    disparity_image_output_files = gr.Files(
+                        label="Disparity outputs",
+                        elem_id="download",
+                        interactive=False,
+                    )
+
+            filenames = []
+            filenames.extend(["disparity_anime_%d.jpg" %(i+1) for i in range(7)])
+            filenames.extend(["disparity_line_%d.jpg" %(i+1) for i in range(6)])
+            filenames.extend(["disparity_real_%d.jpg" %(i+1) for i in range(24)])
+
+            example_folder = os.path.join(os.path.dirname(__file__), "depth_images")
+            Examples(
+                fn=process_pipe_disparity,
+                examples=[
+                    os.path.join(example_folder, name)
+                    for name in filenames
+                ],
+                inputs=[disparity_image_input],
+                outputs=[disparity_image_output_slider, disparity_image_output_files],
+                cache_examples=False,
+                # directory_name="examples_depth",
+                # cache_examples=False,
+            )
 
 
         ### Image tab
@@ -630,6 +749,72 @@ def run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting):
            ],
            queue=False,
        )
+
+        seg_image_submit_btn.click(
+            fn=process_image_check,
+            inputs=seg_image_input,
+            outputs=None,
+            preprocess=False,
+            queue=False,
+        ).success(
+            fn=process_pipe_seg,
+            inputs=[
+                seg_image_input,
+                image_processing_res,
+            ],
+            outputs=[seg_image_output_slider, seg_image_output_files],
+            concurrency_limit=1,
+        )
+
+        seg_image_reset_btn.click(
+            fn=lambda: (
+                None,
+                None,
+                None,
+                default_image_processing_res,
+            ),
+            inputs=[],
+            outputs=[
+                seg_image_input,
+                seg_image_output_slider,
+                seg_image_output_files,
+                image_processing_res,
+            ],
+            queue=False,
+        )
+
+        disparity_image_submit_btn.click(
+            fn=process_image_check,
+            inputs=disparity_image_input,
+            outputs=None,
+            preprocess=False,
+            queue=False,
+        ).success(
+            fn=process_pipe_disparity,
+            inputs=[
+                disparity_image_input,
+                image_processing_res,
+            ],
+            outputs=[disparity_image_output_slider, disparity_image_output_files],
+            concurrency_limit=1,
+        )
+
+        disparity_image_reset_btn.click(
+            fn=lambda: (
+                None,
+                None,
+                None,
+                default_image_processing_res,
+            ),
+            inputs=[],
+            outputs=[
+                disparity_image_input,
+                disparity_image_output_slider,
+                disparity_image_output_files,
+                image_processing_res,
+            ],
+            queue=False,
+        )
     ### Server launch
 
     demo.queue(
@@ -645,37 +830,61 @@ def main():
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    dtype = torch.float16
-
-
-
-
-        subfolder="unet",
-        use_safetensors=True).to(dtype)
-    unet_normal_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/GenPercept', subfolder="unet_normal_v1", use_safetensors=True).to(dtype)
-    unet_dis_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/GenPercept', subfolder="unet_dis_v1", use_safetensors=True).to(dtype)
-    unet_matting_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/genpercept-matting', subfolder="unet", use_safetensors=True).to(dtype)
-
-    empty_text_embed = torch.from_numpy(np.load("./empty_text_embed.npy")).to(device, dtype)[None]  # [1, 77, 1024]
+    # dtype = torch.float16
+    # variant = "fp16"
+
+    dtype = torch.float32
+    variant = None
 
-
-
-
-
-
-
-
-
-
-
-
-
+    unet_depth_v2 = CustomUNet2DConditionModel.from_pretrained('guangkaixu/GenPercept-models', subfolder="unet_depth_v2", use_safetensors=True).to(dtype)
+    unet_normal_v2 = CustomUNet2DConditionModel.from_pretrained('guangkaixu/GenPercept-models', subfolder="unet_normal_v2", use_safetensors=True).to(dtype)
+    unet_dis_v2 = CustomUNet2DConditionModel.from_pretrained('guangkaixu/GenPercept-models', subfolder="unet_dis_v2", use_safetensors=True).to(dtype)
+    unet_matting_v2 = CustomUNet2DConditionModel.from_pretrained('guangkaixu/GenPercept-models', subfolder="unet_matting_v2", use_safetensors=True).to(dtype)
+    unet_disparity_v2 = CustomUNet2DConditionModel.from_pretrained('guangkaixu/GenPercept-models', subfolder="unet_disparity_v2", use_safetensors=True).to(dtype)
+    unet_seg_v2 = CustomUNet2DConditionModel.from_pretrained('guangkaixu/GenPercept-models', subfolder="unet_seg_v2", use_safetensors=True).to(dtype)
+
+    scheduler = DDIMSchedulerCustomized.from_pretrained("hf_configs/scheduler_beta_1.0_1.0", subfolder='scheduler')
+    genpercept_pipeline = True
+
+    pre_loaded_dict = dict(
+        scheduler=scheduler,
+        genpercept_pipeline=genpercept_pipeline,
+        torch_dtype=dtype,
+        variant=variant,
+    )
+
+    pipe_depth = GenPerceptPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-2-1", unet=unet_depth_v2, **pre_loaded_dict,
+    )
+
+    pipe_normal = GenPerceptPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-2-1", unet=unet_normal_v2, **pre_loaded_dict,
+    )
+
+    pipe_dis = GenPerceptPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-2-1", unet=unet_dis_v2, **pre_loaded_dict,
+    )
+
+    pipe_matting = GenPerceptPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-2-1", unet=unet_matting_v2, **pre_loaded_dict,
+    )
+
+    pipe_seg = GenPerceptPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-2-1", unet=unet_seg_v2, **pre_loaded_dict,
+    )
+
+    pipe_disparity = GenPerceptPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-2-1", unet=unet_disparity_v2, **pre_loaded_dict,
+    )
+
     try:
         import xformers
         pipe_depth.enable_xformers_memory_efficient_attention()
         pipe_normal.enable_xformers_memory_efficient_attention()
         pipe_dis.enable_xformers_memory_efficient_attention()
         pipe_matting.enable_xformers_memory_efficient_attention()
+        pipe_seg.enable_xformers_memory_efficient_attention()
+        pipe_disparity.enable_xformers_memory_efficient_attention()
     except:
         pass  # run without xformers
 
@@ -683,8 +892,10 @@ def main():
     pipe_normal = pipe_normal.to(device)
     pipe_dis = pipe_dis.to(device)
     pipe_matting = pipe_matting.to(device)
+    pipe_seg = pipe_seg.to(device)
+    pipe_disparity = pipe_disparity.to(device)
 
-    run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting)
+    run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting, pipe_seg, pipe_disparity)
 
 
 if __name__ == "__main__":
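For reference, a minimal sketch of how one of the pipelines assembled in `main()` above can be exercised outside the Gradio demo. The repo id, scheduler config path, and keyword arguments mirror the diff; the input image path is a hypothetical placeholder.

```python
import torch
from PIL import Image

from genpercept.genpercept_pipeline import GenPerceptPipeline
from genpercept.models.custom_unet import CustomUNet2DConditionModel
from genpercept.customized_modules.ddim import DDIMSchedulerCustomized

# Assemble the depth pipeline the same way main() in app.py does.
unet = CustomUNet2DConditionModel.from_pretrained(
    "guangkaixu/GenPercept-models", subfolder="unet_depth_v2", use_safetensors=True
)
scheduler = DDIMSchedulerCustomized.from_pretrained(
    "hf_configs/scheduler_beta_1.0_1.0", subfolder="scheduler"
)
pipe = GenPerceptPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    unet=unet,
    scheduler=scheduler,
    genpercept_pipeline=True,
    torch_dtype=torch.float32,
)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# "input.jpg" is a placeholder path; mode names follow the pipeline's own assertion message.
out = pipe(Image.open("input.jpg"), processing_res=768, mode="depth")
print(out.pred_np.shape)                     # H x W prediction in [0, 1]
out.pred_colored.save("depth_colored.png")   # Spectral-colorized depth map
```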
empty_text_embed.npy
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:677e5e752b1d428a2e5f6f87a62c3a6c726343d264351dc1c433763ddc9b7182
-size 157824
genpercept/__init__.py
ADDED
@@ -0,0 +1,13 @@
# --------------------------------------------------------
# What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090)
# Github source: https://github.com/aim-uofa/GenPercept
# Copyright (c) 2024, Advanced Intelligent Machines (AIM)
# Licensed under The BSD 2-Clause License [see LICENSE for details]
# By Guangkai Xu
# Based on Marigold, diffusers codebases
# https://github.com/prs-eth/marigold
# https://github.com/huggingface/diffusers
# --------------------------------------------------------


from .genpercept_pipeline import GenPerceptPipeline, GenPerceptOutput
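A quick illustration of what this re-export enables (a trivial sketch, not part of the commit):

```python
# Callers can import the pipeline and its output class from the package root
# instead of spelling out the submodule path.
from genpercept import GenPerceptPipeline, GenPerceptOutput
```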
genpercept/customized_modules/ddim.py
ADDED
@@ -0,0 +1,213 @@
# --------------------------------------------------------
# What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090)
# Github source: https://github.com/aim-uofa/GenPercept
# Copyright (c) 2024, Advanced Intelligent Machines (AIM)
# Licensed under The BSD 2-Clause License [see LICENSE for details]
# By Guangkai Xu
# Based on Marigold, diffusers codebases
# https://github.com/prs-eth/marigold
# https://github.com/huggingface/diffusers
# --------------------------------------------------------


import torch
from typing import List, Optional, Tuple, Union
import numpy as np
from diffusers import DDIMScheduler, DDPMScheduler
from diffusers.configuration_utils import ConfigMixin, register_to_config


def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)


    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas


class DDPMSchedulerCustomized(DDPMScheduler):

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        variance_type: str = "fixed_small",
        clip_sample: bool = True,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        steps_offset: int = 0,
        rescale_betas_zero_snr: int = False,
        power_beta_curve = 1.0,
    ):

        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "scaled_linear_power":
            self.betas = torch.linspace(beta_start**(1/power_beta_curve), beta_end**(1/power_beta_curve), num_train_timesteps, dtype=torch.float32) ** power_beta_curve
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        elif beta_schedule == "sigmoid":
            # GeoDiff sigmoid schedule
            betas = torch.linspace(-6, 6, num_train_timesteps)
            self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.custom_timesteps = False
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

        self.variance_type = variance_type

    def get_velocity(
        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        # import pdb
        # pdb.set_trace()
        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

class DDIMSchedulerCustomized(DDIMScheduler):

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
        power_beta_curve = 1.0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "scaled_linear_power":
            self.betas = torch.linspace(beta_start**(1/power_beta_curve), beta_end**(1/power_beta_curve), num_train_timesteps, dtype=torch.float32) ** power_beta_curve
            self.power_beta_curve = power_beta_curve
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        # self.betas = self.betas.double()

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

        self.beta_schedule = beta_schedule

    def _get_variance(self, timestep, prev_timestep):
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        alpha_t_prev_to_t = self.alphas[(prev_timestep+1):(timestep+1)]
        alpha_t_prev_to_t = torch.prod(alpha_t_prev_to_t)

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_t_prev_to_t)

        return variance
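As a small self-contained check of what `rescale_zero_terminal_snr` does (the beta range below is the standard scaled-linear schedule and is an assumed example, not read from this repo's config): after rescaling, the terminal cumulative alpha is driven to zero while the first beta is preserved.

```python
import torch
from genpercept.customized_modules.ddim import rescale_zero_terminal_snr

# Example scaled-linear betas (assumed values, 1000 training timesteps).
betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2

rescaled = rescale_zero_terminal_snr(betas)
alphas_cumprod = torch.cumprod(1.0 - rescaled, dim=0)

print(float(alphas_cumprod[-1]))            # ~0.0 -> zero terminal SNR
print(float(rescaled[0]), float(betas[0]))  # first beta is unchanged
```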
genpercept/genpercept_pipeline.py
ADDED
@@ -0,0 +1,519 @@
# --------------------------------------------------------
# What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090)
# Github source: https://github.com/aim-uofa/GenPercept
# Copyright (c) 2024, Advanced Intelligent Machines (AIM)
# Licensed under The BSD 2-Clause License [see LICENSE for details]
# By Guangkai Xu
# Based on Marigold, diffusers codebases
# https://github.com/prs-eth/marigold
# https://github.com/huggingface/diffusers
# --------------------------------------------------------


import logging
from typing import Dict, Optional, Union

import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LCMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils import BaseOutput
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import pil_to_tensor, resize
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_depth
from .util.image_util import (
    chw2hwc,
    colorize_depth_maps,
    get_tv_resample_method,
    resize_max_res,
)

import matplotlib.pyplot as plt
from genpercept.models.dpt_head import DPTNeckHeadForUnetAfterUpsampleIdentity


class GenPerceptOutput(BaseOutput):
    """
    Output class for GenPercept general perception pipeline.

    Args:
        pred_np (`np.ndarray`):
            Predicted result, with values in the range of [0, 1].
        pred_colored (`PIL.Image.Image`):
            Colorized result, with the shape of [3, H, W] and values in [0, 1].
    """

    pred_np: np.ndarray
    pred_colored: Union[None, Image.Image]

class GenPerceptPipeline(DiffusionPipeline):
    """
    Pipeline for general perception using GenPercept: https://github.com/aim-uofa/GenPercept.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the perception latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and results
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
        default_denoising_steps (`int`, *optional*):
            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
            quality with the given model. This value must be set in the model config. When the pipeline is called
            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
        default_processing_resolution (`int`, *optional*):
            The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
            the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
            default value is used. This is required to ensure reasonable results with various model flavors trained
            with varying optimal processing resolution values.
    """

    latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: Union[DDIMScheduler, LCMScheduler],
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        default_denoising_steps: Optional[int] = 10,
        default_processing_resolution: Optional[int] = 768,
        rgb_blending = False,
        customized_head = None,
        genpercept_pipeline = True,
    ):
        super().__init__()

        self.genpercept_pipeline = genpercept_pipeline

        if self.genpercept_pipeline:
            default_denoising_steps = 1
            rgb_blending = True

        self.register_modules(
            unet=unet,
            customized_head=customized_head,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )
        self.register_to_config(
            default_denoising_steps=default_denoising_steps,
            default_processing_resolution=default_processing_resolution,
            rgb_blending=rgb_blending,
        )

        self.default_denoising_steps = default_denoising_steps
        self.default_processing_resolution = default_processing_resolution
        self.rgb_blending = rgb_blending

        self.text_embed = None

        self.customized_head = customized_head

        if self.customized_head:
            assert self.rgb_blending and self.scheduler.beta_start == 1 and self.scheduler.beta_end == 1
            assert self.genpercept_pipeline

    @torch.no_grad()
    def __call__(
        self,
        input_image: Union[Image.Image, torch.Tensor],
        denoising_steps: Optional[int] = None,
        ensemble_size: int = 1,
        processing_res: Optional[int] = None,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        batch_size: int = 0,
        generator: Union[torch.Generator, None] = None,
        color_map: str = "Spectral",
        show_progress_bar: bool = True,
        ensemble_kwargs: Dict = None,
        mode = None,
        fix_timesteps = None,
        prompt = "",
    ) -> GenPerceptOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (`Image`):
                Input RGB (or gray-scale) image.
            denoising_steps (`int`, *optional*, defaults to `None`):
                Number of denoising diffusion steps during inference. The default value `None` results in automatic
                selection.
            ensemble_size (`int`, *optional*, defaults to `10`):
                Number of predictions to be ensembled.
            processing_res (`int`, *optional*, defaults to `None`):
                Effective processing resolution. When set to `0`, processes at the original image resolution. This
                produces crisper predictions, but may also lead to the overall loss of global context. The default
                value `None` resolves to the optimal value from the model config.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize perception result to match input resolution.
                Only valid if `processing_res` > 0.
            resample_method: (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and perception results. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `num_ensemble`.
                If set to 0, the script will automatically decide the proper batch size.
            generator (`torch.Generator`, *optional*, defaults to `None`)
                Random generator for initial noise generation.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized result generation):
                Colormap used to colorize the result.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
        Returns:
            `GenPerceptOutput`: Output class for GenPercept general perception pipeline, including:
            - **pred_np** (`np.ndarray`) Predicted result, with values in the range of [0, 1]
            - **pred_colored** (`PIL.Image.Image`) Colorized result, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
        """
        assert mode is not None, "mode of GenPerceptPipeline can be chosen from ['depth', 'normal', 'seg', 'matting', 'dis']."
        self.mode = mode

        # Model-specific optimal default values leading to fast and reasonable results.
        if denoising_steps is None:
            denoising_steps = self.default_denoising_steps
        if processing_res is None:
            processing_res = self.default_processing_resolution

        assert processing_res >= 0
        assert ensemble_size >= 1

        if self.genpercept_pipeline:
            assert ensemble_size == 1
            assert denoising_steps == 1
        else:
            # Check if denoising step is reasonable
            self._check_inference_step(denoising_steps)

        resample_method: InterpolationMode = get_tv_resample_method(resample_method)

        # ----------------- Image Preprocess -----------------
        # Convert to torch tensor
        if isinstance(input_image, Image.Image):
            input_image = input_image.convert("RGB")
            # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
            rgb = pil_to_tensor(input_image)
            rgb = rgb.unsqueeze(0)  # [1, rgb, H, W]
        elif isinstance(input_image, torch.Tensor):
            rgb = input_image
        else:
            raise TypeError(f"Unknown input type: {type(input_image) = }")
        input_size = rgb.shape
        assert (
            4 == rgb.dim() and 3 == input_size[-3]
        ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"

        # Resize image
        if processing_res > 0:
            rgb = resize_max_res(
                rgb,
                max_edge_resolution=processing_res,
                resample_method=resample_method,
            )

        # Normalize rgb values
        rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = rgb_norm.to(self.dtype)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Perception Inference -----------------
        # Batch repeated input image
        duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            _bs = find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # Predict results (batched)
        pipe_pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader
        for batch in iterable:
            (batched_img,) = batch
            pipe_pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
                generator=generator,
                fix_timesteps=fix_timesteps,
                prompt=prompt,
            )
            pipe_pred_ls.append(pipe_pred_raw.detach())
        pipe_preds = torch.concat(pipe_pred_ls, dim=0)
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------
        if ensemble_size > 1:
            pipe_pred, _ = ensemble_depth(
                pipe_preds,
                scale_invariant=True,
                shift_invariant=True,
                max_res=50,
                **(ensemble_kwargs or {}),
            )
        else:
            pipe_pred = pipe_preds

        # Resize back to original resolution
        if match_input_res:
            pipe_pred = resize(
                pipe_pred,
                input_size[-2:],
                interpolation=resample_method,
                antialias=True,
            )

        # Convert to numpy
        pipe_pred = pipe_pred.squeeze()
        pipe_pred = pipe_pred.cpu().numpy()

        # Clip output range
        pipe_pred = pipe_pred.clip(0, 1)

        # Colorize
        if color_map is not None:
            assert self.mode == 'depth'
            pred_colored = colorize_depth_maps(
                pipe_pred, 0, 1, cmap=color_map
            ).squeeze()  # [3, H, W], value in (0, 1)
            pred_colored = (pred_colored * 255).astype(np.uint8)
            pred_colored_hwc = chw2hwc(pred_colored)
            pred_colored_img = Image.fromarray(pred_colored_hwc)
        else:
            pred_colored_img = None

        if len(pipe_pred.shape) == 3 and pipe_pred.shape[0] == 3:
            pipe_pred = np.transpose(pipe_pred, (1, 2, 0))

        return GenPerceptOutput(
            pred_np=pipe_pred,
            pred_colored=pred_colored_img,
        )

    def _check_inference_step(self, n_step: int) -> None:
        """
        Check if denoising step is reasonable
        Args:
            n_step (`int`): denoising steps
        """
        assert n_step >= 1

        if isinstance(self.scheduler, DDIMScheduler):
            if n_step < 10:
                logging.warning(
                    f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
                )
        elif isinstance(self.scheduler, LCMScheduler):
            if not 1 <= n_step <= 4:
                logging.warning(
                    f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps."
                )
        else:
            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")

    def encode_text(self, prompt):
        """
        Encode text embedding for empty prompt
        """
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    @torch.no_grad()
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        generator: Union[torch.Generator, None],
        show_pbar: bool,
        fix_timesteps = None,
        prompt = "",
    ) -> torch.Tensor:
        """
        Perform an individual perception result without ensembling.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image.
            num_inference_steps (`int`):
                Number of diffusion denoisign steps (DDIM) during inference.
            show_pbar (`bool`):
                Display a progress bar of diffusion denoising.
            generator (`torch.Generator`)
                Random generator for initial noise generation.
        Returns:
            `torch.Tensor`: Predicted result.
        """
        device = self.device
        rgb_in = rgb_in.to(device)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)

        if fix_timesteps:
            timesteps = torch.tensor([fix_timesteps]).long().repeat(self.scheduler.timesteps.shape[0]).to(device)
        else:
            timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)

        if not (self.rgb_blending or self.genpercept_pipeline):
            # Initial result (noise)
            pred_latent = torch.randn(
                rgb_latent.shape,
                device=device,
                dtype=self.dtype,
                generator=generator,
            )  # [B, 4, h, w]
        else:
            pred_latent = rgb_latent

        # Batched empty text embedding
        if self.text_embed is None:
            self.encode_text(prompt)
        batch_text_embed = self.text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        ).to(device)  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        if not self.customized_head:
            for i, t in iterable:
                if self.genpercept_pipeline and i > 0:
                    assert ValueError, "GenPercept only forward once."

                if not (self.rgb_blending or self.genpercept_pipeline):
                    unet_input = torch.cat(
                        [rgb_latent, pred_latent], dim=1
                    )  # this order is important
                else:
                    unet_input = pred_latent

                # predict the noise residual
                noise_pred = self.unet(
                    unet_input, t, encoder_hidden_states=batch_text_embed
                ).sample  # [B, 4, h, w]

                # compute the previous noisy sample x_t -> x_t-1
                step_output = self.scheduler.step(
                    noise_pred, t, pred_latent, generator=generator
                )
                pred_latent = step_output.prev_sample

            pred_latent = step_output.pred_original_sample  # NOTE: for GenPercept, it is equivalent to "pred_latent = - noise_pred"

            pred = self.decode_pred(pred_latent)

            # clip prediction
            pred = torch.clip(pred, -1.0, 1.0)
            # shift to [0, 1]
            pred = (pred + 1.0) / 2.0

        elif isinstance(self.customized_head, DPTNeckHeadForUnetAfterUpsampleIdentity):
            unet_input = pred_latent
            model_pred_output = self.unet(
                unet_input, timesteps, encoder_hidden_states=batch_text_embed, return_feature=True
            )  # [B, 4, h, w]
            unet_features = model_pred_output.multi_level_feats[::-1]
            pred = self.customized_head(hidden_states=unet_features).prediction[:, None]
            # shift to [0, 1]
            pred = (pred - pred.min()) / (pred.max() - pred.min())
        else:
            raise ValueError

        return pred

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.latent_scale_factor
        return rgb_latent

    def decode_pred(self, pred_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode pred latent into result.

        Args:
            pred_latent (`torch.Tensor`):
                pred latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded result.
        """
        # scale latent
        pred_latent = pred_latent / self.latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(pred_latent)
        stacked = self.vae.decoder(z)
        if self.mode in ['depth', 'matting', 'dis']:
            # mean of output channels
            stacked = stacked.mean(dim=1, keepdim=True)
        return stacked
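For context, a hypothetical helper in the spirit of the demo's `process_*` functions, showing how a `GenPerceptOutput` returned by the pipeline above is typically consumed. The save paths, file naming, and 16-bit PNG convention are assumptions for illustration, not the demo's exact code; `pipe` is any GenPerceptPipeline assembled as in app.py.

```python
import os
import numpy as np
from PIL import Image


def run_and_save(pipe, image_path: str, out_dir: str, mode: str = "depth"):
    """Run a GenPerceptPipeline on one image and save the raw and colorized outputs."""
    os.makedirs(out_dir, exist_ok=True)
    image = Image.open(image_path)

    out = pipe(
        image,
        processing_res=768,
        match_input_res=True,
        mode=mode,
        # The pipeline only allows colorization for depth, so disable it otherwise.
        color_map="Spectral" if mode == "depth" else None,
        show_progress_bar=False,
    )

    name = os.path.splitext(os.path.basename(image_path))[0]

    if out.pred_np.ndim == 2:
        # Single-channel prediction in [0, 1], stored as 16-bit PNG (assumed convention).
        pred_16bit = (out.pred_np * 65535.0).astype(np.uint16)
        Image.fromarray(pred_16bit).save(os.path.join(out_dir, f"{name}_{mode}_16bit.png"))
    else:
        # Multi-channel prediction (e.g. normals or seg), stored as 8-bit RGB.
        pred_8bit = (out.pred_np * 255.0).astype(np.uint8)
        Image.fromarray(pred_8bit).save(os.path.join(out_dir, f"{name}_{mode}.png"))

    # Colorized map is only produced when a color map was requested (depth mode).
    if out.pred_colored is not None:
        out.pred_colored.save(os.path.join(out_dir, f"{name}_{mode}_colored.png"))
    return out
```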
genpercept/models/custom_unet.py
ADDED
@@ -0,0 +1,427 @@
# --------------------------------------------------------
# What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090)
# Github source: https://github.com/aim-uofa/GenPercept
# Copyright (c) 2024, Advanced Intelligent Machines (AIM)
# Licensed under The BSD 2-Clause License [see LICENSE for details]
# By Guangkai Xu
# Based on diffusers codebases
# https://github.com/huggingface/diffusers
# --------------------------------------------------------

from diffusers import UNet2DConditionModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from dataclasses import dataclass
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers

@dataclass
class CustomUNet2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor = None
    multi_level_feats: [torch.FloatTensor] = None

class CustomUNet2DConditionModel(UNet2DConditionModel):

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_feature: bool = False,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`UNet2DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
                through the `self.time_embedding` layer to obtain the timestep embeddings.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks.
            down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
                A tuple of tensors that if specified are added to the residuals of down unet blocks.
            mid_block_additional_residual: (`torch.Tensor`, *optional*):
                A tensor that if specified is added to the residual of the middle unet block.
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks.
            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
                additional residuals to be added to UNet long skip connections from down blocks to up blocks for
                example from ControlNet side model(s)
            mid_block_additional_residual (`torch.Tensor`, *optional*):
                additional residual to be added to UNet mid block output, for example from ControlNet side model
            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        for dim in sample.shape[-2:]:
            if dim % default_overall_up_factor != 0:
                # Forward upsample size to force interpolation output size.
                forward_upsample_size = True
                break

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
# (keep = +0, discard = -10000.0)
|
134 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
135 |
+
attention_mask = attention_mask.unsqueeze(1)
|
136 |
+
|
137 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
138 |
+
if encoder_attention_mask is not None:
|
139 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
140 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
141 |
+
|
142 |
+
# 0. center input if necessary
|
143 |
+
if self.config.center_input_sample:
|
144 |
+
sample = 2 * sample - 1.0
|
145 |
+
|
146 |
+
# 1. time
|
147 |
+
timesteps = timestep
|
148 |
+
if not torch.is_tensor(timesteps):
|
149 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
150 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
151 |
+
is_mps = sample.device.type == "mps"
|
152 |
+
if isinstance(timestep, float):
|
153 |
+
dtype = torch.float32 if is_mps else torch.float64
|
154 |
+
else:
|
155 |
+
dtype = torch.int32 if is_mps else torch.int64
|
156 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
157 |
+
elif len(timesteps.shape) == 0:
|
158 |
+
timesteps = timesteps[None].to(sample.device)
|
159 |
+
|
160 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
161 |
+
timesteps = timesteps.expand(sample.shape[0])
|
162 |
+
|
163 |
+
t_emb = self.time_proj(timesteps)
|
164 |
+
|
165 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
166 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
167 |
+
# there might be better ways to encapsulate this.
|
168 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
169 |
+
|
170 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
171 |
+
aug_emb = None
|
172 |
+
|
173 |
+
if self.class_embedding is not None:
|
174 |
+
if class_labels is None:
|
175 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
176 |
+
|
177 |
+
if self.config.class_embed_type == "timestep":
|
178 |
+
class_labels = self.time_proj(class_labels)
|
179 |
+
|
180 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
181 |
+
# there might be better ways to encapsulate this.
|
182 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
183 |
+
|
184 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
185 |
+
|
186 |
+
if self.config.class_embeddings_concat:
|
187 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
188 |
+
else:
|
189 |
+
emb = emb + class_emb
|
190 |
+
|
191 |
+
if self.config.addition_embed_type == "text":
|
192 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
193 |
+
elif self.config.addition_embed_type == "text_image":
|
194 |
+
# Kandinsky 2.1 - style
|
195 |
+
if "image_embeds" not in added_cond_kwargs:
|
196 |
+
raise ValueError(
|
197 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
198 |
+
)
|
199 |
+
|
200 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
201 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
202 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
203 |
+
elif self.config.addition_embed_type == "text_time":
|
204 |
+
# SDXL - style
|
205 |
+
if "text_embeds" not in added_cond_kwargs:
|
206 |
+
raise ValueError(
|
207 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
208 |
+
)
|
209 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
210 |
+
if "time_ids" not in added_cond_kwargs:
|
211 |
+
raise ValueError(
|
212 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
213 |
+
)
|
214 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
215 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
216 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
217 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
218 |
+
add_embeds = add_embeds.to(emb.dtype)
|
219 |
+
aug_emb = self.add_embedding(add_embeds)
|
220 |
+
elif self.config.addition_embed_type == "image":
|
221 |
+
# Kandinsky 2.2 - style
|
222 |
+
if "image_embeds" not in added_cond_kwargs:
|
223 |
+
raise ValueError(
|
224 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
225 |
+
)
|
226 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
227 |
+
aug_emb = self.add_embedding(image_embs)
|
228 |
+
elif self.config.addition_embed_type == "image_hint":
|
229 |
+
# Kandinsky 2.2 - style
|
230 |
+
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
231 |
+
raise ValueError(
|
232 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
233 |
+
)
|
234 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
235 |
+
hint = added_cond_kwargs.get("hint")
|
236 |
+
aug_emb, hint = self.add_embedding(image_embs, hint)
|
237 |
+
sample = torch.cat([sample, hint], dim=1)
|
238 |
+
|
239 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
240 |
+
|
241 |
+
if self.time_embed_act is not None:
|
242 |
+
emb = self.time_embed_act(emb)
|
243 |
+
|
244 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
245 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
246 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
247 |
+
# Kadinsky 2.1 - style
|
248 |
+
if "image_embeds" not in added_cond_kwargs:
|
249 |
+
raise ValueError(
|
250 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
251 |
+
)
|
252 |
+
|
253 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
254 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
255 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
256 |
+
# Kandinsky 2.2 - style
|
257 |
+
if "image_embeds" not in added_cond_kwargs:
|
258 |
+
raise ValueError(
|
259 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
260 |
+
)
|
261 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
262 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
263 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
|
264 |
+
if "image_embeds" not in added_cond_kwargs:
|
265 |
+
raise ValueError(
|
266 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
267 |
+
)
|
268 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
269 |
+
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
|
270 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
|
271 |
+
|
272 |
+
# 2. pre-process
|
273 |
+
sample = self.conv_in(sample)
|
274 |
+
|
275 |
+
# 2.5 GLIGEN position net
|
276 |
+
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
277 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
278 |
+
gligen_args = cross_attention_kwargs.pop("gligen")
|
279 |
+
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
280 |
+
|
281 |
+
# 3. down
|
282 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
283 |
+
if USE_PEFT_BACKEND:
|
284 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
285 |
+
scale_lora_layers(self, lora_scale)
|
286 |
+
|
287 |
+
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
288 |
+
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
289 |
+
is_adapter = down_intrablock_additional_residuals is not None
|
290 |
+
# maintain backward compatibility for legacy usage, where
|
291 |
+
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
292 |
+
# but can only use one or the other
|
293 |
+
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
|
294 |
+
deprecate(
|
295 |
+
"T2I should not use down_block_additional_residuals",
|
296 |
+
"1.3.0",
|
297 |
+
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
298 |
+
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
299 |
+
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
300 |
+
standard_warn=False,
|
301 |
+
)
|
302 |
+
down_intrablock_additional_residuals = down_block_additional_residuals
|
303 |
+
is_adapter = True
|
304 |
+
|
305 |
+
down_block_res_samples = (sample,)
|
306 |
+
for downsample_block in self.down_blocks:
|
307 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
308 |
+
# For t2i-adapter CrossAttnDownBlock2D
|
309 |
+
additional_residuals = {}
|
310 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
311 |
+
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
|
312 |
+
|
313 |
+
sample, res_samples = downsample_block(
|
314 |
+
hidden_states=sample,
|
315 |
+
temb=emb,
|
316 |
+
encoder_hidden_states=encoder_hidden_states,
|
317 |
+
attention_mask=attention_mask,
|
318 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
319 |
+
encoder_attention_mask=encoder_attention_mask,
|
320 |
+
**additional_residuals,
|
321 |
+
)
|
322 |
+
else:
|
323 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
|
324 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
325 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
326 |
+
|
327 |
+
down_block_res_samples += res_samples
|
328 |
+
|
329 |
+
if is_controlnet:
|
330 |
+
new_down_block_res_samples = ()
|
331 |
+
|
332 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
333 |
+
down_block_res_samples, down_block_additional_residuals
|
334 |
+
):
|
335 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
336 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
337 |
+
|
338 |
+
down_block_res_samples = new_down_block_res_samples
|
339 |
+
|
340 |
+
# 4. mid
|
341 |
+
if self.mid_block is not None:
|
342 |
+
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
343 |
+
sample = self.mid_block(
|
344 |
+
sample,
|
345 |
+
emb,
|
346 |
+
encoder_hidden_states=encoder_hidden_states,
|
347 |
+
attention_mask=attention_mask,
|
348 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
349 |
+
encoder_attention_mask=encoder_attention_mask,
|
350 |
+
)
|
351 |
+
else:
|
352 |
+
sample = self.mid_block(sample, emb)
|
353 |
+
|
354 |
+
# To support T2I-Adapter-XL
|
355 |
+
if (
|
356 |
+
is_adapter
|
357 |
+
and len(down_intrablock_additional_residuals) > 0
|
358 |
+
and sample.shape == down_intrablock_additional_residuals[0].shape
|
359 |
+
):
|
360 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
361 |
+
|
362 |
+
if is_controlnet:
|
363 |
+
sample = sample + mid_block_additional_residual
|
364 |
+
|
365 |
+
multi_level_feats = []
|
366 |
+
# 1, 1280, 24, 24
|
367 |
+
# multi_level_feats.append(sample) # 1/64
|
368 |
+
# 5. up
|
369 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
370 |
+
is_final_block = i == len(self.up_blocks) - 1
|
371 |
+
|
372 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
373 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
374 |
+
|
375 |
+
# if we have not reached the final block and need to forward the
|
376 |
+
# upsample size, we do it here
|
377 |
+
if not is_final_block and forward_upsample_size:
|
378 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
379 |
+
|
380 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
381 |
+
sample = upsample_block(
|
382 |
+
hidden_states=sample,
|
383 |
+
temb=emb,
|
384 |
+
res_hidden_states_tuple=res_samples,
|
385 |
+
encoder_hidden_states=encoder_hidden_states,
|
386 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
387 |
+
upsample_size=upsample_size,
|
388 |
+
attention_mask=attention_mask,
|
389 |
+
encoder_attention_mask=encoder_attention_mask,
|
390 |
+
)
|
391 |
+
else:
|
392 |
+
sample = upsample_block(
|
393 |
+
hidden_states=sample,
|
394 |
+
temb=emb,
|
395 |
+
res_hidden_states_tuple=res_samples,
|
396 |
+
upsample_size=upsample_size,
|
397 |
+
scale=lora_scale,
|
398 |
+
)
|
399 |
+
# if not is_final_block:
|
400 |
+
multi_level_feats.append(sample)
|
401 |
+
|
402 |
+
if return_feature:
|
403 |
+
if USE_PEFT_BACKEND:
|
404 |
+
# remove `lora_scale` from each PEFT layer
|
405 |
+
unscale_lora_layers(self, lora_scale)
|
406 |
+
return CustomUNet2DConditionOutput(
|
407 |
+
multi_level_feats=multi_level_feats,
|
408 |
+
)
|
409 |
+
|
410 |
+
# 6. post-process
|
411 |
+
if self.conv_norm_out:
|
412 |
+
sample = self.conv_norm_out(sample)
|
413 |
+
sample = self.conv_act(sample)
|
414 |
+
|
415 |
+
sample = self.conv_out(sample)
|
416 |
+
|
417 |
+
if USE_PEFT_BACKEND:
|
418 |
+
# remove `lora_scale` from each PEFT layer
|
419 |
+
unscale_lora_layers(self, lora_scale)
|
420 |
+
|
421 |
+
if not return_dict:
|
422 |
+
return (sample,)
|
423 |
+
|
424 |
+
return CustomUNet2DConditionOutput(
|
425 |
+
sample=sample,
|
426 |
+
multi_level_feats=multi_level_feats,
|
427 |
+
)
|
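A minimal usage sketch (not part of the commit): it assumes the module path `genpercept/models/custom_unet.py` from this commit is importable, and the checkpoint name, latent size, and zero text embedding below are illustrative placeholders.

```python
# Hypothetical sketch: pull multi-level decoder features out of the customized UNet.
# Checkpoint name, latent shape, and the empty-text embedding are assumptions for illustration.
import torch
from genpercept.models.custom_unet import CustomUNet2DConditionModel

unet = CustomUNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet"
).eval()

latent = torch.randn(1, 4, 96, 96)     # VAE latent of a 768x768 image (assumed size)
text_embed = torch.zeros(1, 77, 1024)  # placeholder for the empty-text CLIP embedding

with torch.no_grad():
    out = unet(latent, timestep=1, encoder_hidden_states=text_embed, return_feature=True)

for feat in out.multi_level_feats:     # one tensor per up block, coarse to fine
    print(feat.shape)
```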
genpercept/models/dpt_head.py
ADDED
@@ -0,0 +1,593 @@
# --------------------------------------------------------
# What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090)
# Github source: https://github.com/aim-uofa/GenPercept
# Copyright (c) 2024, Advanced Intelligent Machines (AIM)
# Licensed under The BSD 2-Clause License [see LICENSE for details]
# By Guangkai Xu
# Based on diffusers codebases
# https://github.com/huggingface/diffusers
# --------------------------------------------------------

import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
from transformers import DPTPreTrainedModel

from transformers.utils import ModelOutput
from transformers.file_utils import replace_return_docstrings, add_start_docstrings_to_model_forward
from transformers.models.dpt.modeling_dpt import DPTReassembleStage

from diffusers.models.lora import LoRACompatibleConv
from diffusers.utils import USE_PEFT_BACKEND
import torch.nn.functional as F


class DepthEstimatorOutput(ModelOutput):
    """
    Base class for outputs of depth estimation models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        prediction (`torch.FloatTensor` of shape `(batch_size, height, width)`):
            Predicted depth for each pixel.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class DPTDepthEstimationHead(nn.Module):
    """
    Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
    the predictions to the input resolution after the first convolutional layer (details can be found in the paper's
    supplementary material).
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.projection = None
        features = config.fusion_hidden_size
        if config.add_projection:
            self.projection = nn.Conv2d(features, features, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

        self.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
        )

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        # use last features
        hidden_states = hidden_states[self.config.head_in_index]

        if self.projection is not None:
            hidden_states = self.projection(hidden_states)
            hidden_states = nn.ReLU()(hidden_states)

        predicted_depth = self.head(hidden_states)

        predicted_depth = predicted_depth.squeeze(dim=1)

        return predicted_depth


class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size: Optional[int] = None,
        padding=1,
        norm_type=None,
        eps=None,
        elementwise_affine=None,
        bias=True,
        interpolate=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        self.interpolate = interpolate
        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        if norm_type == "ln_norm":
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            # self.norm = RMSNorm(channels, eps, elementwise_affine)
            raise NotImplementedError
        elif norm_type is None:
            self.norm = None
        else:
            raise ValueError(f"unknown norm_type: {norm_type}")

        conv = None
        if use_conv_transpose:
            if kernel_size is None:
                kernel_size = 4
            conv = nn.ConvTranspose2d(
                channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
            )
        elif use_conv:
            if kernel_size is None:
                kernel_size = 3
            conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_size: Optional[int] = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.norm is not None:
            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if self.interpolate:
            if output_size is None:
                hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
            else:
                hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
                    hidden_states = self.conv(hidden_states, scale)
                else:
                    hidden_states = self.conv(hidden_states)
            else:
                if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
                    hidden_states = self.Conv2d_0(hidden_states, scale)
                else:
                    hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states


class DPTPreActResidualLayer(nn.Module):
    """
    ResidualConvUnit, pre-activate residual unit.

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
    """

    def __init__(self, config):
        super().__init__()

        self.use_batch_norm = config.use_batch_norm_in_fusion_residual
        use_bias_in_fusion_residual = (
            config.use_bias_in_fusion_residual
            if config.use_bias_in_fusion_residual is not None
            else not self.use_batch_norm
        )

        self.activation1 = nn.ReLU()
        self.convolution1 = nn.Conv2d(
            config.fusion_hidden_size,
            config.fusion_hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias_in_fusion_residual,
        )

        self.activation2 = nn.ReLU()
        self.convolution2 = nn.Conv2d(
            config.fusion_hidden_size,
            config.fusion_hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias_in_fusion_residual,
        )

        if self.use_batch_norm:
            self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
            self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        residual = hidden_state.clone()
        hidden_state = self.activation1(hidden_state)

        hidden_state = self.convolution1(hidden_state)

        if self.use_batch_norm:
            hidden_state = self.batch_norm1(hidden_state)

        hidden_state = self.activation2(hidden_state)
        hidden_state = self.convolution2(hidden_state)

        if self.use_batch_norm:
            hidden_state = self.batch_norm2(hidden_state)

        return hidden_state + residual


class DPTFeatureFusionLayer(nn.Module):
    """Feature fusion layer, merges feature maps from different stages.

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
        align_corners (`bool`, *optional*, defaults to `True`):
            The align_corner setting for bilinear upsample.
    """

    def __init__(self, config, align_corners=True, with_residual_1=True):
        super().__init__()

        self.align_corners = align_corners

        self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)

        if with_residual_1:
            self.residual_layer1 = DPTPreActResidualLayer(config)
        self.residual_layer2 = DPTPreActResidualLayer(config)

    def forward(self, hidden_state, residual=None):
        if residual is not None:
            if hidden_state.shape != residual.shape:
                residual = nn.functional.interpolate(
                    residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False
                )
            hidden_state = hidden_state + self.residual_layer1(residual)

        hidden_state = self.residual_layer2(hidden_state)
        hidden_state = nn.functional.interpolate(
            hidden_state, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )
        hidden_state = self.projection(hidden_state)

        return hidden_state


class DPTFeatureFusionStage(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(len(config.neck_hidden_sizes)):
            if i == 0:
                self.layers.append(DPTFeatureFusionLayer(config, with_residual_1=False))
            else:
                self.layers.append(DPTFeatureFusionLayer(config))

    def forward(self, hidden_states):
        # reversing the hidden_states, we start from the last
        hidden_states = hidden_states[::-1]

        fused_hidden_states = []
        # first layer only uses the last hidden_state
        fused_hidden_state = self.layers[0](hidden_states[0])
        fused_hidden_states.append(fused_hidden_state)
        # looping from the last layer to the second
        for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]):
            fused_hidden_state = layer(fused_hidden_state, hidden_state)
            fused_hidden_states.append(fused_hidden_state)

        return fused_hidden_states


class DPTNeck(nn.Module):
    """
    DPTNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
    input and produces another list of tensors as output. For DPT, it includes 2 stages:

    * DPTReassembleStage
    * DPTFeatureFusionStage.

    Args:
        config (dict): config dict.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # postprocessing: only required in case of a non-hierarchical backbone (e.g. ViT, BEiT)
        if config.backbone_config is not None and config.backbone_config.model_type in ["swinv2"]:
            self.reassemble_stage = None
        else:
            self.reassemble_stage = DPTReassembleStage(config)

        self.convs = nn.ModuleList()
        for channel in config.neck_hidden_sizes:
            self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))

        # fusion
        self.fusion_stage = DPTFeatureFusionStage(config)

    def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
        """
        Args:
            hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
                List of hidden states from the backbone.
        """
        if not isinstance(hidden_states, (tuple, list)):
            raise TypeError("hidden_states should be a tuple or list of tensors")

        if len(hidden_states) != len(self.config.neck_hidden_sizes):
            raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")

        # postprocess hidden states
        if self.reassemble_stage is not None:
            hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)

        features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]

        # fusion blocks
        output = self.fusion_stage(features)

        return output


DPT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
            for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""

_CONFIG_FOR_DOC = "DPTConfig"


class DPTNeckHeadForUnetAfterUpsample(DPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # self.backbone = None
        # if config.backbone_config is not None and config.is_hybrid is False:
        #     self.backbone = load_backbone(config)
        # else:
        #     self.dpt = DPTModel(config, add_pooling_layer=False)

        self.feature_upsample_0 = Upsample2D(channels=config.neck_hidden_sizes[0], use_conv=True)
        # self.feature_upsample_1 = Upsample2D(channels=config.neck_hidden_sizes[1], use_conv=True)
        # self.feature_upsample_2 = Upsample2D(channels=config.neck_hidden_sizes[2], use_conv=True)
        # self.feature_upsample_3 = Upsample2D(channels=config.neck_hidden_sizes[3], use_conv=True)

        # Neck
        self.neck = DPTNeck(config)
        self.neck.reassemble_stage = None

        # Depth estimation head
        self.head = DPTDepthEstimationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        hidden_states,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_depth_only: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth depth estimation maps for computing the loss.

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, DPTForDepthEstimation
        >>> import torch
        >>> import numpy as np
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
        >>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     predicted_depth = outputs.predicted_depth

        >>> # interpolate to original size
        >>> prediction = torch.nn.functional.interpolate(
        ...     predicted_depth.unsqueeze(1),
        ...     size=image.size[::-1],
        ...     mode="bicubic",
        ...     align_corners=False,
        ... )

        >>> # visualize the prediction
        >>> output = prediction.squeeze().cpu().numpy()
        >>> formatted = (output * 255 / np.max(output)).astype("uint8")
        >>> depth = Image.fromarray(formatted)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        # if self.backbone is not None:
        #     outputs = self.backbone.forward_with_filtered_kwargs(
        #         pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
        #     )
        #     hidden_states = outputs.feature_maps
        # else:
        #     outputs = self.dpt(
        #         pixel_values,
        #         head_mask=head_mask,
        #         output_attentions=output_attentions,
        #         output_hidden_states=True,  # we need the intermediate hidden states
        #         return_dict=return_dict,
        #     )
        #     hidden_states = outputs.hidden_states if return_dict else outputs[1]
        #     # only keep certain features based on config.backbone_out_indices
        #     # note that the hidden_states also include the initial embeddings
        #     if not self.config.is_hybrid:
        #         hidden_states = [
        #             feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
        #         ]
        #     else:
        #         backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
        #         backbone_hidden_states.extend(
        #             feature
        #             for idx, feature in enumerate(hidden_states[1:])
        #             if idx in self.config.backbone_out_indices[2:]
        #         )
        #         hidden_states = backbone_hidden_states

        assert len(hidden_states) == 4

        # upsample hidden_states for unet
        # hidden_states = [getattr(self, "feature_upsample_%s" % i)(hidden_states[i]) for i in range(len(hidden_states))]
        hidden_states[0] = self.feature_upsample_0(hidden_states[0])

        patch_height, patch_width = None, None
        if self.config.backbone_config is not None and self.config.is_hybrid is False:
            _, _, height, width = hidden_states[3].shape
            height *= 8
            width *= 8
            patch_size = self.config.backbone_config.patch_size
            patch_height = height // patch_size
            patch_width = width // patch_size

        hidden_states = self.neck(hidden_states, patch_height, patch_width)

        predicted_depth = self.head(hidden_states)

        loss = None
        if labels is not None:
            raise NotImplementedError("Training is not implemented yet")

        if return_depth_only:
            return predicted_depth

        return DepthEstimatorOutput(
            loss=loss,
            prediction=predicted_depth,
            hidden_states=None,
            attentions=None,
        )


class DPTDepthEstimationHeadIdentity(DPTDepthEstimationHead):
    """
    Output head consisting of 3 convolutional layers. Identical to `DPTDepthEstimationHead`, except that the final
    activation is replaced with an identity so the prediction is not clamped to non-negative values.
    """

    def __init__(self, config):
        super().__init__(config)

        features = config.fusion_hidden_size
        self.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.Identity(),
        )


class DPTNeckHeadForUnetAfterUpsampleIdentity(DPTNeckHeadForUnetAfterUpsample):
    def __init__(self, config):
        super().__init__(config)

        # Depth estimation head
        self.head = DPTDepthEstimationHeadIdentity(config)

        # Initialize weights and apply final processing
        self.post_init()
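A minimal sketch of driving the neck and head with four decoder feature maps. The channel counts, spatial sizes, and config values below are illustrative assumptions, not taken from the commit; in the app the features come from `CustomUNet2DConditionModel(..., return_feature=True)` and the config is loaded from the repository.

```python
# Hypothetical sketch: decode four UNet up-block features with the DPT neck + head.
# All shapes and DPTConfig values here are assumptions chosen only to make the sketch run.
import torch
from transformers import DPTConfig
from genpercept.models.dpt_head import DPTNeckHeadForUnetAfterUpsample

config = DPTConfig(
    neck_hidden_sizes=[1280, 1280, 640, 320],  # assumed channel counts of the up-block outputs
    fusion_hidden_size=256,
    add_projection=False,
)
head = DPTNeckHeadForUnetAfterUpsample(config).eval()

# Stand-in multi-level features, coarse to fine (small sizes to keep the sketch cheap).
feats = [
    torch.randn(1, 1280, 6, 6),
    torch.randn(1, 1280, 12, 12),
    torch.randn(1, 640, 24, 24),
    torch.randn(1, 320, 24, 24),
]

with torch.no_grad():
    depth = head(feats, return_depth_only=True)  # (1, H, W) prediction
print(depth.shape)
```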
{util → genpercept/util}/batchsize.py
RENAMED
@@ -1,3 +1,23 @@
+# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# --------------------------------------------------------------------------
+# If you find this code useful, we kindly ask you to cite our paper in your work.
+# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
+# More information about the method can be found at https://marigoldmonodepth.github.io
+# --------------------------------------------------------------------------
+
+
 import torch
 import math
 
@@ -33,11 +53,13 @@ def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
     Automatically search for suitable operating batch size.
 
     Args:
-        ensemble_size (int):
-
+        ensemble_size (`int`):
+            Number of predictions to be ensembled.
+        input_res (`int`):
+            Operating resolution of the input image.
 
     Returns:
-        int
+        `int`: Operating batch size.
     """
     if not torch.cuda.is_available():
         return 1
@@ -56,4 +78,4 @@ def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
     bs = math.ceil(ensemble_size / 2)
     return bs
 
-    return 1
+    return 1
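A small usage sketch of the renamed helper, assuming the `genpercept.util` package layout in this commit is importable; the ensemble size, resolution, and dtype are arbitrary example values.

```python
# Hypothetical sketch: pick a safe inference batch size for ensembled predictions.
import torch
from genpercept.util.batchsize import find_batch_size

bs = find_batch_size(ensemble_size=10, input_res=768, dtype=torch.float16)
print(bs)  # falls back to 1 when no CUDA device is available
```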
genpercept/util/ensemble.py
ADDED
@@ -0,0 +1,205 @@
1 |
+
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# --------------------------------------------------------------------------
|
15 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
16 |
+
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
17 |
+
# More information about the method can be found at https://marigoldmonodepth.github.io
|
18 |
+
# --------------------------------------------------------------------------
|
19 |
+
|
20 |
+
|
21 |
+
from functools import partial
|
22 |
+
from typing import Optional, Tuple
|
23 |
+
|
24 |
+
import numpy as np
|
25 |
+
import torch
|
26 |
+
|
27 |
+
from .image_util import get_tv_resample_method, resize_max_res
|
28 |
+
|
29 |
+
|
30 |
+
def inter_distances(tensors: torch.Tensor):
|
31 |
+
"""
|
32 |
+
To calculate the distance between each two depth maps.
|
33 |
+
"""
|
34 |
+
distances = []
|
35 |
+
for i, j in torch.combinations(torch.arange(tensors.shape[0])):
|
36 |
+
arr1 = tensors[i : i + 1]
|
37 |
+
arr2 = tensors[j : j + 1]
|
38 |
+
distances.append(arr1 - arr2)
|
39 |
+
dist = torch.concatenate(distances, dim=0)
|
40 |
+
return dist
|
41 |
+
|
42 |
+
|
43 |
+
def ensemble_depth(
|
44 |
+
depth: torch.Tensor,
|
45 |
+
scale_invariant: bool = True,
|
46 |
+
shift_invariant: bool = True,
|
47 |
+
output_uncertainty: bool = False,
|
48 |
+
reduction: str = "median",
|
49 |
+
regularizer_strength: float = 0.02,
|
50 |
+
max_iter: int = 2,
|
51 |
+
tol: float = 1e-3,
|
52 |
+
max_res: int = 1024,
|
53 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
54 |
+
"""
|
55 |
+
Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the
|
56 |
+
number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for
|
57 |
+
depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The
|
58 |
+
alignment happens when the predictions have one or more degrees of freedom, that is when they are either
|
59 |
+
affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only
|
60 |
+
`scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`)
|
61 |
+
alignment is skipped and only ensembling is performed.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
depth (`torch.Tensor`):
|
65 |
+
Input ensemble depth maps.
|
66 |
+
scale_invariant (`bool`, *optional*, defaults to `True`):
|
67 |
+
Whether to treat predictions as scale-invariant.
|
68 |
+
shift_invariant (`bool`, *optional*, defaults to `True`):
|
69 |
+
Whether to treat predictions as shift-invariant.
|
70 |
+
output_uncertainty (`bool`, *optional*, defaults to `False`):
|
71 |
+
Whether to output uncertainty map.
|
72 |
+
reduction (`str`, *optional*, defaults to `"median"`):
|
73 |
+
Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and
|
74 |
+
`"median"`.
|
75 |
+
regularizer_strength (`float`, *optional*, defaults to `0.02`):
|
76 |
+
Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1.
|
77 |
+
max_iter (`int`, *optional*, defaults to `2`):
|
78 |
+
Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options`
|
79 |
+
argument.
|
80 |
+
tol (`float`, *optional*, defaults to `1e-3`):
|
81 |
+
Alignment solver tolerance. The solver stops when the tolerance is reached.
|
82 |
+
max_res (`int`, *optional*, defaults to `1024`):
|
83 |
+
Resolution at which the alignment is performed; `None` matches the `processing_resolution`.
|
84 |
+
Returns:
|
85 |
+
A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape:
|
86 |
+
`(1, 1, H, W)`.
|
87 |
+
"""
|
88 |
+
    if depth.dim() != 4 or depth.shape[1] != 1:
        raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.")
    if reduction not in ("mean", "median"):
        raise ValueError(f"Unrecognized reduction method: {reduction}.")
    if not scale_invariant and shift_invariant:
        raise ValueError("Pure shift-invariant ensembling is not supported.")

    def init_param(depth: torch.Tensor):
        init_min = depth.reshape(ensemble_size, -1).min(dim=1).values
        init_max = depth.reshape(ensemble_size, -1).max(dim=1).values

        if scale_invariant and shift_invariant:
            init_s = 1.0 / (init_max - init_min).clamp(min=1e-6)
            init_t = -init_s * init_min
            param = torch.cat((init_s, init_t)).cpu().numpy()
        elif scale_invariant:
            init_s = 1.0 / init_max.clamp(min=1e-6)
            param = init_s.cpu().numpy()
        else:
            raise ValueError("Unrecognized alignment.")

        return param

    def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor:
        if scale_invariant and shift_invariant:
            s, t = np.split(param, 2)
            s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1)
            t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1)
            out = depth * s + t
        elif scale_invariant:
            s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1)
            out = depth * s
        else:
            raise ValueError("Unrecognized alignment.")
        return out

    def ensemble(
        depth_aligned: torch.Tensor, return_uncertainty: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        uncertainty = None
        if reduction == "mean":
            prediction = torch.mean(depth_aligned, dim=0, keepdim=True)
            if return_uncertainty:
                uncertainty = torch.std(depth_aligned, dim=0, keepdim=True)
        elif reduction == "median":
            prediction = torch.median(depth_aligned, dim=0, keepdim=True).values
            if return_uncertainty:
                uncertainty = torch.median(
                    torch.abs(depth_aligned - prediction), dim=0, keepdim=True
                ).values
        else:
            raise ValueError(f"Unrecognized reduction method: {reduction}.")
        return prediction, uncertainty

    def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:
        cost = 0.0
        depth_aligned = align(depth, param)

        for i, j in torch.combinations(torch.arange(ensemble_size)):
            diff = depth_aligned[i] - depth_aligned[j]
            cost += (diff**2).mean().sqrt().item()

        if regularizer_strength > 0:
            prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
            err_near = (0.0 - prediction.min()).abs().item()
            err_far = (1.0 - prediction.max()).abs().item()
            cost += (err_near + err_far) * regularizer_strength

        return cost

    def compute_param(depth: torch.Tensor):
        import scipy

        depth_to_align = depth.to(torch.float32)
        if max_res is not None and max(depth_to_align.shape[2:]) > max_res:
            try:
                depth_to_align = resize_max_res(
                    depth_to_align, max_res, get_tv_resample_method("nearest-exact")
                )
            except:
                depth_to_align = resize_max_res(
                    depth_to_align, max_res, get_tv_resample_method("bilinear")
                )

        param = init_param(depth_to_align)

        res = scipy.optimize.minimize(
            partial(cost_fn, depth=depth_to_align),
            param,
            method="BFGS",
            tol=tol,
            options={"maxiter": max_iter, "disp": False},
        )

        return res.x

    requires_aligning = scale_invariant or shift_invariant
    ensemble_size = depth.shape[0]

    if requires_aligning:
        param = compute_param(depth)
        depth = align(depth, param)

    depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty)

    depth_max = depth.max()
    if scale_invariant and shift_invariant:
        depth_min = depth.min()
    elif scale_invariant:
        depth_min = 0
    else:
        raise ValueError("Unrecognized alignment.")
    depth_range = (depth_max - depth_min).clamp(min=1e-6)
    depth = (depth - depth_min) / depth_range
    if output_uncertainty:
        uncertainty /= depth_range

    return depth, uncertainty  # [1,1,H,W], [1,1,H,W]
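The block above closes the depth-ensembling helper: each single-pass prediction is aligned with per-sample scale/shift parameters fitted by BFGS on pairwise RMS differences, the aligned stack is reduced by mean or median, and the result is renormalized to [0, 1]. A minimal usage sketch follows; the function name `ensemble_depth` and its keyword defaults are assumptions inferred from this hunk and the upstream Marigold utility, not shown verbatim in this commit.

import torch
# Hypothetical import: assumes the enclosing function is exposed as
# genpercept.util.ensemble.ensemble_depth (the name is not visible in this hunk).
from genpercept.util.ensemble import ensemble_depth

# Five single-inference affine-invariant depth maps, stacked to [N, 1, H, W].
preds = torch.stack([torch.rand(1, 480, 640) for _ in range(5)], dim=0)

depth, uncertainty = ensemble_depth(
    preds,
    scale_invariant=True,      # fit a per-sample scale s_i ...
    shift_invariant=True,      # ... and shift t_i, so alignment is s_i * d_i + t_i
    output_uncertainty=True,   # median absolute deviation of the aligned stack
    reduction="median",
)
print(depth.shape, uncertainty.shape)  # torch.Size([1, 1, 480, 640]) for both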
{util → genpercept/util}/image_util.py
RENAMED
@@ -1,15 +1,30 @@
+# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
+# Last modified: 2024-05-24
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# --------------------------------------------------------------------------
+# If you find this code useful, we kindly ask you to cite our paper in your work.
+# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
+# More information about the method can be found at https://marigoldmonodepth.github.io
+# --------------------------------------------------------------------------
+
 import matplotlib
 import numpy as np
 import torch
-from PIL import Image
-from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import resize
 
-def norm_to_rgb(norm):
-    # norm: (3, H, W), range from [-1, 1]
-    norm_rgb = ((norm + 1) * 0.5) * 255
-    norm_rgb = np.clip(norm_rgb, a_min=0, a_max=255)
-    norm_rgb = norm_rgb.astype(np.uint8)
-    return norm_rgb
 
 def colorize_depth_maps(
     depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None

@@ -20,9 +35,9 @@ def colorize_depth_maps(
     assert len(depth_map.shape) >= 2, "Invalid dimension"
 
     if isinstance(depth_map, torch.Tensor):
+        depth = depth_map.detach().squeeze().numpy()
     elif isinstance(depth_map, np.ndarray):
+        depth = depth_map.copy().squeeze()
     # reshape to [ (B,) H, W ]
     if depth.ndim < 3:
         depth = depth[np.newaxis, :, :]

@@ -36,7 +51,7 @@ def colorize_depth_maps(
     if valid_mask is not None:
         if isinstance(depth_map, torch.Tensor):
             valid_mask = valid_mask.detach().numpy()
+        valid_mask = valid_mask.squeeze()  # [H, W] or [B, H, W]
         if valid_mask.ndim < 3:
             valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
         else:

@@ -61,18 +76,28 @@ def chw2hwc(chw):
     return hwc
 
 
-def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
+def resize_max_res(
+    img: torch.Tensor,
+    max_edge_resolution: int,
+    resample_method: InterpolationMode = InterpolationMode.BILINEAR,
+) -> torch.Tensor:
     """
-    Resize image to limit maximum edge length while keeping aspect ratio
+    Resize image to limit maximum edge length while keeping aspect ratio.
 
     Args:
-        img (Image.Image): Image to be resized
-        max_edge_resolution (int): Maximum edge length (px).
+        img (`torch.Tensor`):
+            Image tensor to be resized. Expected shape: [B, C, H, W]
+        max_edge_resolution (`int`):
+            Maximum edge length (pixel).
+        resample_method (`PIL.Image.Resampling`):
+            Resampling method used to resize images.
 
     Returns:
-        Image.Image: Resized image.
+        `torch.Tensor`: Resized image.
     """
-    original_width, original_height = img.size
+    assert 4 == img.dim(), f"Invalid input shape {img.shape}"
+
+    original_height, original_width = img.shape[-2:]
     downscale_factor = min(
         max_edge_resolution / original_width, max_edge_resolution / original_height
     )

@@ -80,93 +105,26 @@ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
     new_width = int(original_width * downscale_factor)
     new_height = int(original_height * downscale_factor)
 
-    resized_img = img.resize((new_width, new_height))
-    return resized_img
-
-def resize_max_res_integer_16(img: Image.Image, max_edge_resolution: int) -> Image.Image:
-    """
-    Resize image to limit maximum edge length while keeping aspect ratio
-
-    Args:
-        img (Image.Image): Image to be resized
-        max_edge_resolution (int): Maximum edge length (px).
-
-    Returns:
-        Image.Image: Resized image.
-    """
-    original_width, original_height = img.size
-    downscale_factor = min(
-        max_edge_resolution / original_width, max_edge_resolution / original_height
-    )
-
-    new_width = int(original_width * downscale_factor) // 16 * 16  # make sure it is integer multiples of 16, used for pixart
-    new_height = int(original_height * downscale_factor) // 16 * 16  # make sure it is integer multiples of 16, used for pixart
-
-    resized_img = img.resize((new_width, new_height))
-    return resized_img
-
-def resize_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
-    """
-    Resize image to limit maximum edge length while keeping aspect ratio
-
-    Args:
-        img (Image.Image): Image to be resized
-        max_edge_resolution (int): Maximum edge length (px).
-
-    Returns:
-        Image.Image: Resized image.
-    """
-
-    resized_img = img.resize((max_edge_resolution, max_edge_resolution))
+    resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
     return resized_img
 
-class ResizeLongestEdge:
-    def __init__(self, max_size, interpolation=transforms.InterpolationMode.BILINEAR):
-        self.max_size = max_size
-        self.interpolation = interpolation
-
-    def __call__(self, img):
-
-        scale = self.max_size / max(img.width, img.height)
-        new_size = (int(img.height * scale), int(img.width * scale))
-
-        return transforms.functional.resize(img, new_size, self.interpolation)
-
-class ResizeShortestEdge:
-    def __init__(self, min_size, interpolation=transforms.InterpolationMode.BILINEAR):
-        self.min_size = min_size
-        self.interpolation = interpolation
-
-    def __call__(self, img):
-
-        scale = self.min_size / min(img.width, img.height)
-        new_size = (int(img.height * scale), int(img.width * scale))
-
-        return transforms.functional.resize(img, new_size, self.interpolation)
-
-class ResizeHard:
-    def __init__(self, size, interpolation=transforms.InterpolationMode.BILINEAR):
-        self.size = size
-        self.interpolation = interpolation
-
-    def __call__(self, img):
-
-        new_size = (int(self.size), int(self.size))
-
-        return transforms.functional.resize(img, new_size, self.interpolation)
-
-
-class ResizeLongestEdgeInteger:
-    def __init__(self, max_size, interpolation=transforms.InterpolationMode.BILINEAR, integer=16):
-        self.max_size = max_size
-        self.interpolation = interpolation
-        self.integer = integer
-
-    def __call__(self, img):
-
-        scale = self.max_size / max(img.width, img.height)
-        new_size_h = int(img.height * scale) // self.integer * self.integer
-        new_size_w = int(img.width * scale) // self.integer * self.integer
-        new_size = (new_size_h, new_size_w)
-
-        return transforms.functional.resize(img, new_size, self.interpolation)
+
+def get_tv_resample_method(method_str: str) -> InterpolationMode:
+    try:
+        resample_method_dict = {
+            "bilinear": InterpolationMode.BILINEAR,
+            "bicubic": InterpolationMode.BICUBIC,
+            "nearest": InterpolationMode.NEAREST_EXACT,
+            "nearest-exact": InterpolationMode.NEAREST_EXACT,
+        }
+    except:
+        resample_method_dict = {
+            "bilinear": InterpolationMode.BILINEAR,
+            "bicubic": InterpolationMode.BICUBIC,
+            "nearest": InterpolationMode.NEAREST,
+        }
+    resample_method = resample_method_dict.get(method_str, None)
+    if resample_method is None:
+        raise ValueError(f"Unknown resampling method: {resample_method}")
+    else:
+        return resample_method
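The rename switches `resize_max_res` from PIL images to [B, C, H, W] tensors resized with torchvision, and adds `get_tv_resample_method` to map method names onto `InterpolationMode`. A short usage sketch, with the import path taken from the rename above and the colorize call mirroring how the pipelines use it:

import numpy as np
import torch
from genpercept.util.image_util import (
    chw2hwc,
    colorize_depth_maps,
    get_tv_resample_method,
    resize_max_res,
)

# Cap the longer edge of a [B, C, H, W] tensor at 640 px, keeping the aspect ratio.
rgb = torch.rand(1, 3, 768, 1024)
resized = resize_max_res(rgb, 640, get_tv_resample_method("bilinear"))
print(resized.shape)  # torch.Size([1, 3, 480, 640])

# Colorize a depth map in [0, 1] with the Spectral colormap.
depth = np.random.rand(480, 640).astype(np.float32)
colored_chw = colorize_depth_maps(depth, 0, 1, cmap="Spectral").squeeze()  # [3, H, W], values in (0, 1)
colored_hwc = chw2hwc((colored_chw * 255).astype(np.uint8))                # [H, W, 3] uint8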
hf_configs/dpt-sd2.1-unet-after-upsample-general/config.json
ADDED
@@ -0,0 +1,48 @@
{
  "_commit_hash": null,
  "add_projection": true,
  "architectures": [
    "DPTForDepthEstimation"
  ],
  "attention_probs_dropout_prob": null,
  "auxiliary_loss_weight": 0.4,
  "backbone_featmap_shape": null,
  "backbone_out_indices": null,
  "fusion_hidden_size": 256,
  "head_in_index": -1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": null,
  "hidden_size": 768,
  "image_size": null,
  "initializer_range": 0.02,
  "intermediate_size": null,
  "is_hybrid": false,
  "layer_norm_eps": null,
  "model_type": "dpt",
  "neck_hidden_sizes": [
    320,
    640,
    1280,
    1280
  ],
  "neck_ignore_stages": [],
  "num_attention_heads": null,
  "num_channels": null,
  "num_hidden_layers": null,
  "patch_size": null,
  "qkv_bias": null,
  "readout_type": "project",
  "reassemble_factors": [
    4,
    2,
    1,
    0.5
  ],
  "semantic_classifier_dropout": 0.1,
  "semantic_loss_ignore_index": 255,
  "torch_dtype": "float32",
  "transformers_version": null,
  "use_auxiliary_head": true,
  "use_batch_norm_in_fusion_residual": false,
  "use_bias_in_fusion_residual": false
}
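This config describes the DPT-style head attached to the SD2.1 UNet: the four `neck_hidden_sizes` (320/640/1280/1280) match the channel widths of the UNet feature maps it consumes, while `fusion_hidden_size` and `reassemble_factors` control the decoder. A quick way to inspect the file is through transformers' generic `DPTConfig`; whether GenPercept's customized head consumes it exactly this way is an assumption here.

# Sketch: reading the head config with transformers' DPTConfig. Extra JSON keys that
# are not standard DPTConfig arguments simply become attributes on the loaded config.
from transformers import DPTConfig

cfg = DPTConfig.from_pretrained("hf_configs/dpt-sd2.1-unet-after-upsample-general")
print(cfg.neck_hidden_sizes)   # [320, 640, 1280, 1280]
print(cfg.fusion_hidden_size)  # 256
print(cfg.reassemble_factors)  # [4, 2, 1, 0.5]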
hf_configs/dpt-sd2.1-unet-after-upsample-general/preprocessor_config.json
ADDED
@@ -0,0 +1,27 @@
{
  "do_normalize": true,
  "do_pad": true,
  "do_rescale": false,
  "do_resize": true,
  "ensure_multiple_of": 1,
  "image_mean": [
    123.675,
    116.28,
    103.53
  ],
  "image_processor_type": "DPTImageProcessor",
  "image_std": [
    58.395,
    57.12,
    57.375
  ],
  "keep_aspect_ratio": false,
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 392,
    "width": 392
  },
  "size_divisor": 14
}
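The companion preprocessor file specifies a `DPTImageProcessor` with ImageNet mean/std on 0-255 inputs, a fixed 392x392 resize, and padding to a multiple of 14. Whether the Space instantiates this processor at runtime is not shown in this commit; the sketch below only illustrates what the JSON describes.

# Hypothetical: instantiating the processor described by the JSON above.
from PIL import Image
from transformers import DPTImageProcessor

processor = DPTImageProcessor.from_pretrained(
    "hf_configs/dpt-sd2.1-unet-after-upsample-general"
)
inputs = processor(images=Image.new("RGB", (640, 480)), return_tensors="pt")
print(inputs["pixel_values"].shape)  # expected torch.Size([1, 3, 392, 392])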
hf_configs/scheduler_beta_1.0_1.0/scheduler_config.json
ADDED
@@ -0,0 +1,20 @@
{
  "_class_name": "DDIMScheduler",
  "_diffusers_version": "0.29.2",
  "beta_end": 1.0,
  "beta_schedule": "scaled_linear",
  "beta_start": 1.0,
  "clip_sample": false,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 1000,
  "prediction_type": "v_prediction",
  "rescale_betas_zero_snr": false,
  "sample_max_value": 1.0,
  "set_alpha_to_one": false,
  "skip_prk_steps": true,
  "steps_offset": 1,
  "thresholding": false,
  "timestep_spacing": "leading",
  "trained_betas": null
}
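This schedule is degenerate on purpose: with beta_start = beta_end = 1.0 every alpha_t is 0, so alphas_cumprod is 0 at all timesteps and the v-prediction target sqrt(alpha_bar_t)*eps - sqrt(1 - alpha_bar_t)*x0 collapses to -x0. A v-prediction UNet under this schedule therefore outputs the negative of the target latent in a single deterministic pass, which is consistent with the `pred_latent = - unet_pred` line in the pipeline removed below. A small sketch:

# Sketch: loading the custom schedule shipped with this commit.
from diffusers import DDIMScheduler

scheduler = DDIMScheduler.from_pretrained("hf_configs/scheduler_beta_1.0_1.0")
print(scheduler.config.prediction_type)    # "v_prediction"
print(float(scheduler.alphas_cumprod[0]))  # 0.0 -> the v target reduces to -x0 at every t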
pipeline_genpercept.py
DELETED
@@ -1,355 +0,0 @@
# --------------------------------------------------------
# Diffusion Models Trained with Large Data Are Transferable Visual Models (https://arxiv.org/abs/2403.06090)
# Github source: https://github.com/aim-uofa/GenPercept
# Copyright (c) 2024 Zhejiang University
# Licensed under The CC0 1.0 License [see LICENSE for details]
# By Guangkai Xu
# Based on Marigold, diffusers codebases
# https://github.com/prs-eth/marigold
# https://github.com/huggingface/diffusers
# --------------------------------------------------------

import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from PIL import Image
from typing import List, Dict, Union
from torch.utils.data import DataLoader, TensorDataset

from diffusers import (
    DiffusionPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
)
from diffusers.utils import BaseOutput

from util.image_util import chw2hwc, colorize_depth_maps, resize_max_res, norm_to_rgb, resize_res
from util.batchsize import find_batch_size


class GenPerceptOutput(BaseOutput):

    pred_np: np.ndarray
    pred_colored: Image.Image


class GenPerceptPipeline(DiffusionPipeline):

    vae_scale_factor = 0.18215
    task_infos = {
        'depth': dict(task_channel_num=1, interpolate='bilinear', ),
        'seg': dict(task_channel_num=3, interpolate='nearest', ),
        'sr': dict(task_channel_num=3, interpolate='nearest', ),
        'normal': dict(task_channel_num=3, interpolate='bilinear', ),
    }

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        customized_head=None,
        empty_text_embed=None,
    ):
        super().__init__()

        self.empty_text_embed = empty_text_embed

        # register
        register_dict = dict(
            unet=unet,
            vae=vae,
            customized_head=customized_head,
        )
        self.register_modules(**register_dict)

    @torch.no_grad()
    def __call__(
        self,
        input_image: Union[Image.Image, torch.Tensor],
        mode: str = 'depth',
        resize_hard = False,
        processing_res: int = 768,
        match_input_res: bool = False,
        batch_size: int = 0,
        color_map: str = "Spectral",
        show_progress_bar: bool = True,
    ) -> GenPerceptOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (Image):
                Input RGB (or gray-scale) image.
            processing_res (int, optional):
                Maximum resolution of processing.
                If set to 0: will not resize at all.
                Defaults to 768.
            match_input_res (bool, optional):
                Resize depth prediction to match input resolution.
                Only valid if `limit_input_res` is not None.
                Defaults to True.
            batch_size (int, optional):
                Inference batch size.
                If set to 0, the script will automatically decide the proper batch size.
                Defaults to 0.
            show_progress_bar (bool, optional):
                Display a progress bar of diffusion denoising.
                Defaults to True.
            color_map (str, optional):
                Colormap used to colorize the depth map.
                Defaults to "Spectral".
        Returns:
            `GenPerceptOutput`
        """

        device = self.device

        task_channel_num = self.task_infos[mode]['task_channel_num']

        if not match_input_res:
            assert (
                processing_res is not None
            ), "Value error: `resize_output_back` is only valid with "
        assert processing_res >= 0

        # ----------------- Image Preprocess -----------------

        if type(input_image) == torch.Tensor:  # [B, 3, H, W]
            rgb_norm = input_image.to(device)
            input_size = input_image.shape[2:]
            bs_imgs = rgb_norm.shape[0]
            assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
            rgb_norm = rgb_norm.to(self.dtype)
        else:
            # if len(rgb_paths) > 0 and 'kitti' in rgb_paths[0]:
            #     # kb crop
            #     height = input_image.size[1]
            #     width = input_image.size[0]
            #     top_margin = int(height - 352)
            #     left_margin = int((width - 1216) / 2)
            #     input_image = input_image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))

            # TODO: check the kitti evaluation resolution here.
            input_size = (input_image.size[1], input_image.size[0])
            # Resize image
            if processing_res > 0:
                if resize_hard:
                    input_image = resize_res(
                        input_image, max_edge_resolution=processing_res
                    )
                else:
                    input_image = resize_max_res(
                        input_image, max_edge_resolution=processing_res
                    )
            input_image = input_image.convert("RGB")
            image = np.asarray(input_image)

            # Normalize rgb values
            rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
            rgb_norm = rgb / 255.0 * 2.0 - 1.0
            rgb_norm = torch.from_numpy(rgb_norm).to(self.unet.dtype)
            rgb_norm = rgb_norm[None].to(device)
            assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
            bs_imgs = 1

        # ----------------- Predicting depth -----------------

        single_rgb_dataset = TensorDataset(rgb_norm)
        if batch_size > 0:
            _bs = batch_size
        else:
            _bs = find_batch_size(
                ensemble_size=1,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # Predict depth maps (batched)
        pred_list = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader

        for batch in iterable:
            (batched_img, ) = batch
            pred = self.single_infer(
                rgb_in=batched_img,
                mode=mode,
            )
            pred_list.append(pred.detach().clone())
        preds = torch.concat(pred_list, axis=0).squeeze()  # [bs_imgs, task_channel_num, H, W]
        preds = preds.view(bs_imgs, task_channel_num, preds.shape[-2], preds.shape[-1])

        if match_input_res:
            preds = F.interpolate(preds, input_size, mode=self.task_infos[mode]['interpolate'])

        # ----------------- Post processing -----------------
        if mode == 'depth':
            if len(preds.shape) == 4:
                preds = preds[:, 0]  # [bs_imgs, H, W]
            # Scale prediction to [0, 1]
            min_d = preds.view(bs_imgs, -1).min(dim=1)[0]
            max_d = preds.view(bs_imgs, -1).max(dim=1)[0]
            preds = (preds - min_d[:, None, None]) / (max_d[:, None, None] - min_d[:, None, None])
            preds = preds.cpu().numpy().astype(np.float32)
            # Colorize
            pred_colored_img_list = []
            for i in range(bs_imgs):
                pred_colored_chw = colorize_depth_maps(
                    preds[i], 0, 1, cmap=color_map
                ).squeeze()  # [3, H, W], value in (0, 1)
                pred_colored_chw = (pred_colored_chw * 255).astype(np.uint8)
                pred_colored_hwc = chw2hwc(pred_colored_chw)
                pred_colored_img = Image.fromarray(pred_colored_hwc)
                pred_colored_img_list.append(pred_colored_img)

            return GenPerceptOutput(
                pred_np=np.squeeze(preds),
                pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
            )

        elif mode == 'seg' or mode == 'sr':
            if not self.customized_head:
                # shift to [0, 1]
                preds = (preds + 1.0) / 2.0
                # shift to [0, 255]
                preds = preds * 255
                # Clip output range
                preds = preds.clip(0, 255).cpu().numpy().astype(np.uint8)
            else:
                raise NotImplementedError

            pred_colored_img_list = []
            for i in range(preds.shape[0]):
                pred_colored_hwc = chw2hwc(preds[i])
                pred_colored_img = Image.fromarray(pred_colored_hwc)
                pred_colored_img_list.append(pred_colored_img)

            return GenPerceptOutput(
                pred_np=np.squeeze(preds),
                pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
            )

        elif mode == 'normal':
            if not self.customized_head:
                preds = preds.clip(-1, 1).cpu().numpy()  # [-1, 1]
            else:
                raise NotImplementedError

            pred_colored_img_list = []
            for i in range(preds.shape[0]):
                pred_colored_chw = norm_to_rgb(preds[i])
                pred_colored_hwc = chw2hwc(pred_colored_chw)
                normal_colored_img_i = Image.fromarray(pred_colored_hwc)
                pred_colored_img_list.append(normal_colored_img_i)

            return GenPerceptOutput(
                pred_np=np.squeeze(preds),
                pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
            )

        else:
            raise NotImplementedError

    @torch.no_grad()
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        mode: str = 'depth',
    ) -> torch.Tensor:
        """
        Perform an individual depth prediction without ensembling.

        Args:
            rgb_in (torch.Tensor):
                Input RGB image.
            num_inference_steps (int):
                Number of diffusion denoising steps (DDIM) during inference.
            show_pbar (bool):
                Display a progress bar of diffusion denoising.

        Returns:
            torch.Tensor: Predicted depth map.
        """
        device = rgb_in.device
        bs_imgs = rgb_in.shape[0]
        timesteps = torch.tensor([1]).long().repeat(bs_imgs).to(device)

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)

        batch_embed = self.empty_text_embed
        batch_embed = batch_embed.repeat((rgb_latent.shape[0], 1, 1)).to(device)  # [bs_imgs, 77, 1024]

        # Forward!
        if self.customized_head:
            unet_features = self.unet(rgb_latent, timesteps, encoder_hidden_states=batch_embed, return_feature_only=True)[0][::-1]
            pred = self.customized_head(unet_features)
        else:
            unet_output = self.unet(
                rgb_latent, timesteps, encoder_hidden_states=batch_embed
            )  # [bs_imgs, 4, h, w]
            unet_pred = unet_output.sample
            pred_latent = - unet_pred
            pred_latent.to(device)
            pred = self.decode_pred(pred_latent)
            if mode == 'depth':
                # mean of output channels
                pred = pred.mean(dim=1, keepdim=True)
            # clip prediction
            pred = torch.clip(pred, -1.0, 1.0)
        return pred

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (torch.Tensor):
                Input RGB image to be encoded.

        Returns:
            torch.Tensor: Image latent
        """
        try:
            # encode
            h_temp = self.vae.encoder(rgb_in)
            moments = self.vae.quant_conv(h_temp)
        except:
            # encode
            h_temp = self.vae.encoder(rgb_in.float())
            moments = self.vae.quant_conv(h_temp.float())

        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.vae_scale_factor
        return rgb_latent

    def decode_pred(self, pred_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode pred latent into pred label.

        Args:
            pred_latent (torch.Tensor):
                prediction latent to be decoded.

        Returns:
            torch.Tensor: Decoded prediction label.
        """
        # scale latent
        pred_latent = pred_latent / self.vae_scale_factor
        # decode
        z = self.vae.post_quant_conv(pred_latent)
        pred = self.vae.decoder(z)

        return pred
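For reference, a hypothetical invocation of the pipeline deleted above, as it worked before this commit together with the old util/ helpers. It is based only on the signatures visible in this hunk; all checkpoint paths and the embedding file are placeholders.

import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL, UNet2DConditionModel
from pipeline_genpercept import GenPerceptPipeline  # the module removed by this commit

# Placeholder paths: point these at a GenPercept checkpoint laid out in diffusers style.
unet = UNet2DConditionModel.from_pretrained("path/to/checkpoint", subfolder="unet")
vae = AutoencoderKL.from_pretrained("path/to/checkpoint", subfolder="vae")
# Precomputed empty-prompt text embedding, expected shape [1, 77, 1024] (see single_infer).
empty_text_embed = torch.from_numpy(np.load("path/to/empty_text_embed.npy"))

pipe = GenPerceptPipeline(unet=unet, vae=vae, empty_text_embed=empty_text_embed).to("cuda")
out = pipe(Image.open("input.jpg"), mode="depth", match_input_res=True)
out.pred_colored.save("depth_colored.png")  # GenPerceptOutput.pred_colored is a PIL image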
requirements.txt
CHANGED
@@ -21,3 +21,7 @@ spaces
 gradio>=4.32.2
 gradio_client>=0.17.0
 gradio_imageslider>=0.0.20
+omegaconf
+tabulate
+wandb
+pandas
seg_images/seg_1.jpg
ADDED
seg_images/seg_2.jpg
ADDED
seg_images/seg_3.jpg
ADDED
seg_images/seg_4.jpg
ADDED
seg_images/seg_5.jpg
ADDED
util/__init__.py
DELETED
File without changes
util/seed_all.py
DELETED
@@ -1,13 +0,0 @@
import numpy as np
import random
import torch


def seed_all(seed: int = 0):
    """
    Set random seeds of all components.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)