guangkaixu committed
Commit 10e02f0 · 1 Parent(s): c83d507
app.py CHANGED
@@ -34,14 +34,17 @@ from PIL import Image
34
 
35
  from gradio_imageslider import ImageSlider
36
  from gradio_patches.examples import Examples
37
- from pipeline_genpercept import GenPerceptPipeline
38
 
39
  from diffusers import (
40
  DiffusionPipeline,
41
- UNet2DConditionModel,
42
  AutoencoderKL,
43
  )
44
 
45
  warnings.filterwarnings(
46
  "ignore", message=".*LoginButton created outside of a Blocks context.*"
47
  )
@@ -194,11 +197,13 @@ def process_matting(
194
  )
195
 
196
 
197
- def run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting):
198
  process_pipe_depth = spaces.GPU(functools.partial(process_depth, pipe_depth))
199
  process_pipe_normal = spaces.GPU(functools.partial(process_normal, pipe_normal))
200
  process_pipe_dis = spaces.GPU(functools.partial(process_dis, pipe_dis))
201
  process_pipe_matting = spaces.GPU(functools.partial(process_matting, pipe_matting))
202
  gradio_theme = gr.themes.Default()
203
 
204
  with gr.Blocks(
@@ -485,7 +490,7 @@ def run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting):
485
  example_folder = os.path.join(os.path.dirname(__file__), "matting_images")
486
  # print(example_folder)
487
  Examples(
488
- fn=process_pipe_dis,
489
  examples=[
490
  os.path.join(example_folder, name)
491
  for name in filenames
@@ -496,6 +501,120 @@ def run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting):
496
  directory_name="images_cache",
497
  cache_examples=False,
498
  )
499
 
500
 
501
  ### Image tab
@@ -630,6 +749,72 @@ def run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting):
630
  ],
631
  queue=False,
632
  )
633
  ### Server launch
634
 
635
  demo.queue(
@@ -645,37 +830,61 @@ def main():
645
 
646
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
647
 
648
- dtype = torch.float16
649
-
650
- vae = AutoencoderKL.from_pretrained("guangkaixu/GenPercept", subfolder='vae').to(dtype)
651
- unet_depth_v1 = UNet2DConditionModel.from_pretrained(
652
- 'guangkaixu/genpercept-depth',
653
- subfolder="unet",
654
- use_safetensors=True).to(dtype)
655
- unet_normal_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/GenPercept', subfolder="unet_normal_v1", use_safetensors=True).to(dtype)
656
- unet_dis_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/GenPercept', subfolder="unet_dis_v1", use_safetensors=True).to(dtype)
657
- unet_matting_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/genpercept-matting', subfolder="unet", use_safetensors=True).to(dtype)
658
-
659
- empty_text_embed = torch.from_numpy(np.load("./empty_text_embed.npy")).to(device, dtype)[None] # [1, 77, 1024]
660
 
661
- pipe_depth = GenPerceptPipeline(vae=vae,
662
- unet=unet_depth_v1,
663
- empty_text_embed=empty_text_embed)
664
- pipe_normal = GenPerceptPipeline(vae=vae,
665
- unet=unet_normal_v1,
666
- empty_text_embed=empty_text_embed)
667
- pipe_dis = GenPerceptPipeline(vae=vae,
668
- unet=unet_dis_v1,
669
- empty_text_embed=empty_text_embed)
670
- pipe_matting = GenPerceptPipeline(vae=vae,
671
- unet=unet_matting_v1,
672
- empty_text_embed=empty_text_embed)
673
  try:
674
  import xformers
675
  pipe_depth.enable_xformers_memory_efficient_attention()
676
  pipe_normal.enable_xformers_memory_efficient_attention()
677
  pipe_dis.enable_xformers_memory_efficient_attention()
678
  pipe_matting.enable_xformers_memory_efficient_attention()
679
  except:
680
  pass # run without xformers
681
 
@@ -683,8 +892,10 @@ def main():
683
  pipe_normal = pipe_normal.to(device)
684
  pipe_dis = pipe_dis.to(device)
685
  pipe_matting = pipe_matting.to(device)
 
 
686
 
687
- run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting)
688
 
689
 
690
  if __name__ == "__main__":
 
34
 
35
  from gradio_imageslider import ImageSlider
36
  from gradio_patches.examples import Examples
37
+ from genpercept.genpercept_pipeline import GenPerceptPipeline
38
 
39
  from diffusers import (
40
  DiffusionPipeline,
41
+ # UNet2DConditionModel,
42
  AutoencoderKL,
43
  )
44
 
45
+ from genpercept.models.custom_unet import CustomUNet2DConditionModel
46
+ from genpercept.customized_modules.ddim import DDIMSchedulerCustomized
47
+
48
  warnings.filterwarnings(
49
  "ignore", message=".*LoginButton created outside of a Blocks context.*"
50
  )
 
197
  )
198
 
199
 
200
+ def run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting, pipe_seg, pipe_disparity):
201
  process_pipe_depth = spaces.GPU(functools.partial(process_depth, pipe_depth))
202
  process_pipe_normal = spaces.GPU(functools.partial(process_normal, pipe_normal))
203
  process_pipe_dis = spaces.GPU(functools.partial(process_dis, pipe_dis))
204
  process_pipe_matting = spaces.GPU(functools.partial(process_matting, pipe_matting))
205
+ process_pipe_seg = spaces.GPU(functools.partial(process_matting, pipe_seg))
206
+ process_pipe_disparity = spaces.GPU(functools.partial(process_matting, pipe_disparity))
207
  gradio_theme = gr.themes.Default()
208
 
209
  with gr.Blocks(
 
490
  example_folder = os.path.join(os.path.dirname(__file__), "matting_images")
491
  # print(example_folder)
492
  Examples(
493
+ fn=process_pipe_matting,
494
  examples=[
495
  os.path.join(example_folder, name)
496
  for name in filenames
 
501
  directory_name="images_cache",
502
  cache_examples=False,
503
  )
504
+
505
+ with gr.Tab("Seg"):
506
+ with gr.Row():
507
+ with gr.Column():
508
+ seg_image_input = gr.Image(
509
+ label="Input Image",
510
+ type="filepath",
511
+ # type="pil",
512
+ )
513
+ with gr.Row():
514
+ seg_image_submit_btn = gr.Button(
515
+ value="Estimate Segmentation", variant="primary"
516
+ )
517
+ seg_image_reset_btn = gr.Button(value="Reset")
518
+ with gr.Accordion("Advanced options", open=False):
519
+ image_processing_res = gr.Radio(
520
+ [
521
+ ("Native", 0),
522
+ ("Recommended", 768),
523
+ ],
524
+ label="Processing resolution",
525
+ value=default_image_processing_res,
526
+ )
527
+ with gr.Column():
528
+ seg_image_output_slider = ImageSlider(
529
+ label="Predicted segmentation results",
530
+ type="filepath",
531
+ show_download_button=True,
532
+ show_share_button=True,
533
+ interactive=False,
534
+ elem_classes="slider",
535
+ position=0.25,
536
+ )
537
+ seg_image_output_files = gr.Files(
538
+ label="Seg outputs",
539
+ elem_id="download",
540
+ interactive=False,
541
+ )
542
+
543
+ filenames = []
544
+ filenames.extend(["seg_anime_%d.jpg" %(i+1) for i in range(7)])
545
+ filenames.extend(["seg_line_%d.jpg" %(i+1) for i in range(6)])
546
+ filenames.extend(["seg_real_%d.jpg" %(i+1) for i in range(24)])
547
+
548
+ example_folder = os.path.join(os.path.dirname(__file__), "seg_images")
549
+ Examples(
550
+ fn=process_pipe_seg,
551
+ examples=[
552
+ os.path.join(example_folder, name)
553
+ for name in filenames
554
+ ],
555
+ inputs=[seg_image_input],
556
+ outputs=[seg_image_output_slider, seg_image_output_files],
557
+ cache_examples=False,
558
+ # directory_name="examples_depth",
559
+ # cache_examples=False,
560
+ )
561
+
562
+ with gr.Tab("Disparity"):
563
+ with gr.Row():
564
+ with gr.Column():
565
+ disparity_image_input = gr.Image(
566
+ label="Input Image",
567
+ type="filepath",
568
+ # type="pil",
569
+ )
570
+ with gr.Row():
571
+ disparity_image_submit_btn = gr.Button(
572
+ value="Estimate Disparity", variant="primary"
573
+ )
574
+ disparity_image_reset_btn = gr.Button(value="Reset")
575
+ with gr.Accordion("Advanced options", open=False):
576
+ image_processing_res = gr.Radio(
577
+ [
578
+ ("Native", 0),
579
+ ("Recommended", 768),
580
+ ],
581
+ label="Processing resolution",
582
+ value=default_image_processing_res,
583
+ )
584
+ with gr.Column():
585
+ disparity_image_output_slider = ImageSlider(
586
+ label="Predicted disparity results",
587
+ type="filepath",
588
+ show_download_button=True,
589
+ show_share_button=True,
590
+ interactive=False,
591
+ elem_classes="slider",
592
+ position=0.25,
593
+ )
594
+ disparity_image_output_files = gr.Files(
595
+ label="Disparity outputs",
596
+ elem_id="download",
597
+ interactive=False,
598
+ )
599
+
600
+ filenames = []
601
+ filenames.extend(["disparity_anime_%d.jpg" %(i+1) for i in range(7)])
602
+ filenames.extend(["disparity_line_%d.jpg" %(i+1) for i in range(6)])
603
+ filenames.extend(["disparity_real_%d.jpg" %(i+1) for i in range(24)])
604
+
605
+ example_folder = os.path.join(os.path.dirname(__file__), "depth_images")
606
+ Examples(
607
+ fn=process_pipe_disparity,
608
+ examples=[
609
+ os.path.join(example_folder, name)
610
+ for name in filenames
611
+ ],
612
+ inputs=[disparity_image_input],
613
+ outputs=[disparity_image_output_slider, disparity_image_output_files],
614
+ cache_examples=False,
615
+ # directory_name="examples_depth",
616
+ # cache_examples=False,
617
+ )
618
 
619
 
620
  ### Image tab
 
749
  ],
750
  queue=False,
751
  )
752
+
753
+ seg_image_submit_btn.click(
754
+ fn=process_image_check,
755
+ inputs=seg_image_input,
756
+ outputs=None,
757
+ preprocess=False,
758
+ queue=False,
759
+ ).success(
760
+ fn=process_pipe_seg,
761
+ inputs=[
762
+ seg_image_input,
763
+ image_processing_res,
764
+ ],
765
+ outputs=[seg_image_output_slider, seg_image_output_files],
766
+ concurrency_limit=1,
767
+ )
768
+
769
+ seg_image_reset_btn.click(
770
+ fn=lambda: (
771
+ None,
772
+ None,
773
+ None,
774
+ default_image_processing_res,
775
+ ),
776
+ inputs=[],
777
+ outputs=[
778
+ seg_image_input,
779
+ seg_image_output_slider,
780
+ seg_image_output_files,
781
+ image_processing_res,
782
+ ],
783
+ queue=False,
784
+ )
785
+
786
+ disparity_image_submit_btn.click(
787
+ fn=process_image_check,
788
+ inputs=disparity_image_input,
789
+ outputs=None,
790
+ preprocess=False,
791
+ queue=False,
792
+ ).success(
793
+ fn=process_pipe_disparity,
794
+ inputs=[
795
+ disparity_image_input,
796
+ image_processing_res,
797
+ ],
798
+ outputs=[disparity_image_output_slider, disparity_image_output_files],
799
+ concurrency_limit=1,
800
+ )
801
+
802
+ disparity_image_reset_btn.click(
803
+ fn=lambda: (
804
+ None,
805
+ None,
806
+ None,
807
+ default_image_processing_res,
808
+ ),
809
+ inputs=[],
810
+ outputs=[
811
+ disparity_image_input,
812
+ disparity_image_output_slider,
813
+ disparity_image_output_files,
814
+ image_processing_res,
815
+ ],
816
+ queue=False,
817
+ )
818
  ### Server launch
819
 
820
  demo.queue(
 
830
 
831
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
832
 
833
+ # dtype = torch.float16
834
+ # variant = "fp16"
835
+
836
+ dtype = torch.float32
837
+ variant = None
838
 
839
+ unet_depth_v2 = CustomUNet2DConditionModel.from_pretrained('guangkaixu/GenPercept-models', subfolder="unet_depth_v2", use_safetensors=True).to(dtype)
840
+ unet_normal_v2 = CustomUNet2DConditionModel.from_pretrained('guangkaixu/GenPercept-models', subfolder="unet_normal_v2", use_safetensors=True).to(dtype)
841
+ unet_dis_v2 = CustomUNet2DConditionModel.from_pretrained('guangkaixu/GenPercept-models', subfolder="unet_dis_v2", use_safetensors=True).to(dtype)
842
+ unet_matting_v2 = CustomUNet2DConditionModel.from_pretrained('guangkaixu/GenPercept-models', subfolder="unet_matting_v2", use_safetensors=True).to(dtype)
843
+ unet_disparity_v2 = CustomUNet2DConditionModel.from_pretrained('guangkaixu/GenPercept-models', subfolder="unet_disparity_v2", use_safetensors=True).to(dtype)
844
+ unet_seg_v2 = CustomUNet2DConditionModel.from_pretrained('guangkaixu/GenPercept-models', subfolder="unet_seg_v2", use_safetensors=True).to(dtype)
845
+
846
+ scheduler = DDIMSchedulerCustomized.from_pretrained("hf_configs/scheduler_beta_1.0_1.0", subfolder='scheduler')
847
+ genpercept_pipeline = True
848
+
849
+ pre_loaded_dict = dict(
850
+ scheduler=scheduler,
851
+ genpercept_pipeline=genpercept_pipeline,
852
+ torch_dtype=dtype,
853
+ variant=variant,
854
+ )
855
+
856
+ pipe_depth = GenPerceptPipeline.from_pretrained(
857
+ "stabilityai/stable-diffusion-2-1", unet=unet_depth_v2, **pre_loaded_dict,
858
+ )
859
+
860
+ pipe_normal = GenPerceptPipeline.from_pretrained(
861
+ "stabilityai/stable-diffusion-2-1", unet=unet_normal_v2, **pre_loaded_dict,
862
+ )
863
+
864
+ pipe_dis = GenPerceptPipeline.from_pretrained(
865
+ "stabilityai/stable-diffusion-2-1", unet=unet_dis_v2, **pre_loaded_dict,
866
+ )
867
+
868
+ pipe_matting = GenPerceptPipeline.from_pretrained(
869
+ "stabilityai/stable-diffusion-2-1", unet=unet_matting_v2, **pre_loaded_dict,
870
+ )
871
+
872
+ pipe_seg = GenPerceptPipeline.from_pretrained(
873
+ "stabilityai/stable-diffusion-2-1", unet=unet_seg_v2, **pre_loaded_dict,
874
+ )
875
+
876
+ pipe_disparity = GenPerceptPipeline.from_pretrained(
877
+ "stabilityai/stable-diffusion-2-1", unet=unet_disparity_v2, **pre_loaded_dict,
878
+ )
879
+
880
  try:
881
  import xformers
882
  pipe_depth.enable_xformers_memory_efficient_attention()
883
  pipe_normal.enable_xformers_memory_efficient_attention()
884
  pipe_dis.enable_xformers_memory_efficient_attention()
885
  pipe_matting.enable_xformers_memory_efficient_attention()
886
+ pipe_seg.enable_xformers_memory_efficient_attention()
887
+ pipe_disparity.enable_xformers_memory_efficient_attention()
888
  except:
889
  pass # run without xformers
890
 
 
892
  pipe_normal = pipe_normal.to(device)
893
  pipe_dis = pipe_dis.to(device)
894
  pipe_matting = pipe_matting.to(device)
895
+ pipe_seg = pipe_seg.to(device)
896
+ pipe_disparity = pipe_disparity.to(device)
897
 
898
+ run_demo_server(pipe_depth, pipe_normal, pipe_dis, pipe_matting, pipe_seg, pipe_disparity)
899
 
900
 
901
  if __name__ == "__main__":
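
For reference, a condensed sketch of the loading path that the updated main() uses: a task-specific CustomUNet2DConditionModel is dropped into a GenPerceptPipeline built on top of stabilityai/stable-diffusion-2-1, together with the customized one-step DDIM schedule. Repo IDs and the scheduler config path are taken from the diff above; the full script builds six such pipelines (depth, normal, dis, matting, seg, disparity) in the same way.

import torch
from genpercept.genpercept_pipeline import GenPerceptPipeline
from genpercept.models.custom_unet import CustomUNet2DConditionModel
from genpercept.customized_modules.ddim import DDIMSchedulerCustomized

dtype = torch.float32  # the demo switches from fp16 to fp32 in this commit

# Task-specific UNet weights (here: depth), as referenced in main() above
unet = CustomUNet2DConditionModel.from_pretrained(
    "guangkaixu/GenPercept-models", subfolder="unet_depth_v2", use_safetensors=True
).to(dtype)

# Customized DDIM schedule with beta_start = beta_end = 1.0 (single-step inference)
scheduler = DDIMSchedulerCustomized.from_pretrained(
    "hf_configs/scheduler_beta_1.0_1.0", subfolder="scheduler"
)

pipe_depth = GenPerceptPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    unet=unet,
    scheduler=scheduler,
    genpercept_pipeline=True,  # forces one-step, RGB-latent-initialized inference
    torch_dtype=dtype,
)
pipe_depth = pipe_depth.to("cuda" if torch.cuda.is_available() else "cpu")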
empty_text_embed.npy DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:677e5e752b1d428a2e5f6f87a62c3a6c726343d264351dc1c433763ddc9b7182
- size 157824
genpercept/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # --------------------------------------------------------
+ # What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090)
+ # Github source: https://github.com/aim-uofa/GenPercept
+ # Copyright (c) 2024, Advanced Intelligent Machines (AIM)
+ # Licensed under The BSD 2-Clause License [see LICENSE for details]
+ # By Guangkai Xu
+ # Based on Marigold, diffusers codebases
+ # https://github.com/prs-eth/marigold
+ # https://github.com/huggingface/diffusers
+ # --------------------------------------------------------
+
+
+ from .genpercept_pipeline import GenPerceptPipeline, GenPerceptOutput
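
With this package init in place, the pipeline and its output class can also be imported at package level (the demo above imports the module path directly; both work):

from genpercept import GenPerceptPipeline, GenPerceptOutput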
genpercept/customized_modules/ddim.py ADDED
@@ -0,0 +1,213 @@
1
+ # --------------------------------------------------------
2
+ # What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090)
3
+ # Github source: https://github.com/aim-uofa/GenPercept
4
+ # Copyright (c) 2024, Advanced Intelligent Machines (AIM)
5
+ # Licensed under The BSD 2-Clause License [see LICENSE for details]
6
+ # By Guangkai Xu
7
+ # Based on Marigold, diffusers codebases
8
+ # https://github.com/prs-eth/marigold
9
+ # https://github.com/huggingface/diffusers
10
+ # --------------------------------------------------------
11
+
12
+
13
+ import torch
14
+ from typing import List, Optional, Tuple, Union
15
+ import numpy as np
16
+ from diffusers import DDIMScheduler, DDPMScheduler
+ from diffusers.schedulers.scheduling_ddpm import betas_for_alpha_bar
17
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
18
+
19
+
20
+ def rescale_zero_terminal_snr(betas):
21
+ """
22
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
23
+
24
+
25
+ Args:
26
+ betas (`torch.FloatTensor`):
27
+ the betas that the scheduler is being initialized with.
28
+
29
+ Returns:
30
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
31
+ """
32
+ # Convert betas to alphas_bar_sqrt
33
+ alphas = 1.0 - betas
34
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
35
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
36
+
37
+ # Store old values.
38
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
39
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
40
+
41
+ # Shift so the last timestep is zero.
42
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
43
+
44
+ # Scale so the first timestep is back to the old value.
45
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
46
+
47
+ # Convert alphas_bar_sqrt to betas
48
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
49
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
50
+ alphas = torch.cat([alphas_bar[0:1], alphas])
51
+ betas = 1 - alphas
52
+
53
+ return betas
54
+
55
+
56
+ class DDPMSchedulerCustomized(DDPMScheduler):
57
+
58
+ @register_to_config
59
+ def __init__(
60
+ self,
61
+ num_train_timesteps: int = 1000,
62
+ beta_start: float = 0.0001,
63
+ beta_end: float = 0.02,
64
+ beta_schedule: str = "linear",
65
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
66
+ variance_type: str = "fixed_small",
67
+ clip_sample: bool = True,
68
+ prediction_type: str = "epsilon",
69
+ thresholding: bool = False,
70
+ dynamic_thresholding_ratio: float = 0.995,
71
+ clip_sample_range: float = 1.0,
72
+ sample_max_value: float = 1.0,
73
+ timestep_spacing: str = "leading",
74
+ steps_offset: int = 0,
75
+ rescale_betas_zero_snr: int = False,
76
+ power_beta_curve = 1.0,
77
+ ):
78
+
79
+ if trained_betas is not None:
80
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
81
+ elif beta_schedule == "linear":
82
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
83
+ elif beta_schedule == "scaled_linear":
84
+ # this schedule is very specific to the latent diffusion model.
85
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
86
+ elif beta_schedule == "scaled_linear_power":
87
+ self.betas = torch.linspace(beta_start**(1/power_beta_curve), beta_end**(1/power_beta_curve), num_train_timesteps, dtype=torch.float32) ** power_beta_curve
88
+ elif beta_schedule == "squaredcos_cap_v2":
89
+ # Glide cosine schedule
90
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
91
+ elif beta_schedule == "sigmoid":
92
+ # GeoDiff sigmoid schedule
93
+ betas = torch.linspace(-6, 6, num_train_timesteps)
94
+ self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
95
+ else:
96
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
97
+
98
+ # Rescale for zero SNR
99
+ if rescale_betas_zero_snr:
100
+ self.betas = rescale_zero_terminal_snr(self.betas)
101
+
102
+ self.alphas = 1.0 - self.betas
103
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
104
+ self.one = torch.tensor(1.0)
105
+
106
+ # standard deviation of the initial noise distribution
107
+ self.init_noise_sigma = 1.0
108
+
109
+ # setable values
110
+ self.custom_timesteps = False
111
+ self.num_inference_steps = None
112
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
113
+
114
+ self.variance_type = variance_type
115
+
116
+ def get_velocity(
117
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
118
+ ) -> torch.FloatTensor:
119
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
120
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
121
+ alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
122
+ timesteps = timesteps.to(sample.device)
123
+
124
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
125
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
126
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
127
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
128
+
129
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
130
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
131
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
132
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
133
+
134
+ # import pdb
135
+ # pdb.set_trace()
136
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
137
+ return velocity
138
+
139
+ class DDIMSchedulerCustomized(DDIMScheduler):
140
+
141
+ @register_to_config
142
+ def __init__(
143
+ self,
144
+ num_train_timesteps: int = 1000,
145
+ beta_start: float = 0.0001,
146
+ beta_end: float = 0.02,
147
+ beta_schedule: str = "linear",
148
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
149
+ clip_sample: bool = True,
150
+ set_alpha_to_one: bool = True,
151
+ steps_offset: int = 0,
152
+ prediction_type: str = "epsilon",
153
+ thresholding: bool = False,
154
+ dynamic_thresholding_ratio: float = 0.995,
155
+ clip_sample_range: float = 1.0,
156
+ sample_max_value: float = 1.0,
157
+ timestep_spacing: str = "leading",
158
+ rescale_betas_zero_snr: bool = False,
159
+ power_beta_curve = 1.0,
160
+ ):
161
+ if trained_betas is not None:
162
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
163
+ elif beta_schedule == "linear":
164
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
165
+ elif beta_schedule == "scaled_linear":
166
+ # this schedule is very specific to the latent diffusion model.
167
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
168
+ elif beta_schedule == "scaled_linear_power":
169
+ self.betas = torch.linspace(beta_start**(1/power_beta_curve), beta_end**(1/power_beta_curve), num_train_timesteps, dtype=torch.float32) ** power_beta_curve
170
+ self.power_beta_curve = power_beta_curve
171
+ elif beta_schedule == "squaredcos_cap_v2":
172
+ # Glide cosine schedule
173
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
174
+ else:
175
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
176
+
177
+ # Rescale for zero SNR
178
+ if rescale_betas_zero_snr:
179
+ self.betas = rescale_zero_terminal_snr(self.betas)
180
+
181
+ # self.betas = self.betas.double()
182
+
183
+ self.alphas = 1.0 - self.betas
184
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
185
+
186
+ # At every step in ddim, we are looking into the previous alphas_cumprod
187
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
188
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
189
+ # whether we use the final alpha of the "non-previous" one.
190
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
191
+
192
+ # standard deviation of the initial noise distribution
193
+ self.init_noise_sigma = 1.0
194
+
195
+ # setable values
196
+ self.num_inference_steps = None
197
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
198
+
199
+ self.beta_schedule = beta_schedule
200
+
201
+ def _get_variance(self, timestep, prev_timestep):
202
+ alpha_prod_t = self.alphas_cumprod[timestep]
203
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
204
+ beta_prod_t = 1 - alpha_prod_t
205
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
206
+
207
+ alpha_t_prev_to_t = self.alphas[(prev_timestep+1):(timestep+1)]
208
+ alpha_t_prev_to_t = torch.prod(alpha_t_prev_to_t)
209
+
210
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_t_prev_to_t)
211
+
212
+ return variance
213
+
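
A quick way to see what the demo's "scheduler_beta_1.0_1.0" config does with this class: with beta_start = beta_end = 1.0 on the linear schedule every beta equals 1, so alphas_cumprod collapses to zero and a single DDIM step already produces the final prediction. The prediction_type="v_prediction" below is an assumption (the SD 2.1 convention); under it, pred_original_sample reduces to -model_output, which matches the NOTE in GenPerceptPipeline.single_infer.

from genpercept.customized_modules.ddim import DDIMSchedulerCustomized

scheduler = DDIMSchedulerCustomized(
    num_train_timesteps=1000,
    beta_start=1.0,
    beta_end=1.0,
    beta_schedule="linear",
    prediction_type="v_prediction",  # assumption; not pinned down by this file alone
)
print(scheduler.betas.unique())           # tensor([1.])
print(scheduler.alphas_cumprod.unique())  # tensor([0.]) -> one step fully "denoises"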
genpercept/genpercept_pipeline.py ADDED
@@ -0,0 +1,519 @@
1
+ # --------------------------------------------------------
2
+ # What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090)
3
+ # Github source: https://github.com/aim-uofa/GenPercept
4
+ # Copyright (c) 2024, Advanced Intelligent Machines (AIM)
5
+ # Licensed under The BSD 2-Clause License [see LICENSE for details]
6
+ # By Guangkai Xu
7
+ # Based on Marigold, diffusers codebases
8
+ # https://github.com/prs-eth/marigold
9
+ # https://github.com/huggingface/diffusers
10
+ # --------------------------------------------------------
11
+
12
+
13
+ import logging
14
+ from typing import Dict, Optional, Union
15
+
16
+ import numpy as np
17
+ import torch
18
+ from diffusers import (
19
+ AutoencoderKL,
20
+ DDIMScheduler,
21
+ DiffusionPipeline,
22
+ LCMScheduler,
23
+ UNet2DConditionModel,
24
+ )
25
+ from diffusers.utils import BaseOutput
26
+ from PIL import Image
27
+ from torch.utils.data import DataLoader, TensorDataset
28
+ from torchvision.transforms import InterpolationMode
29
+ from torchvision.transforms.functional import pil_to_tensor, resize
30
+ from tqdm.auto import tqdm
31
+ from transformers import CLIPTextModel, CLIPTokenizer
32
+
33
+ from .util.batchsize import find_batch_size
34
+ from .util.ensemble import ensemble_depth
35
+ from .util.image_util import (
36
+ chw2hwc,
37
+ colorize_depth_maps,
38
+ get_tv_resample_method,
39
+ resize_max_res,
40
+ )
41
+
42
+ import matplotlib.pyplot as plt
43
+ from genpercept.models.dpt_head import DPTNeckHeadForUnetAfterUpsampleIdentity
44
+
45
+
46
+ class GenPerceptOutput(BaseOutput):
47
+ """
48
+ Output class for GenPercept general perception pipeline.
49
+
50
+ Args:
51
+ pred_np (`np.ndarray`):
52
+ Predicted result, with values in the range of [0, 1].
53
+ pred_colored (`PIL.Image.Image`):
54
+ Colorized result, with the shape of [3, H, W] and values in [0, 1].
55
+ """
56
+
57
+ pred_np: np.ndarray
58
+ pred_colored: Union[None, Image.Image]
59
+
60
+ class GenPerceptPipeline(DiffusionPipeline):
61
+ """
62
+ Pipeline for general perception using GenPercept: https://github.com/aim-uofa/GenPercept.
63
+
64
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
65
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
66
+
67
+ Args:
68
+ unet (`UNet2DConditionModel`):
69
+ Conditional U-Net to denoise the perception latent, conditioned on image latent.
70
+ vae (`AutoencoderKL`):
71
+ Variational Auto-Encoder (VAE) Model to encode and decode images and results
72
+ to and from latent representations.
73
+ scheduler (`DDIMScheduler`):
74
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
75
+ text_encoder (`CLIPTextModel`):
76
+ Text-encoder, for empty text embedding.
77
+ tokenizer (`CLIPTokenizer`):
78
+ CLIP tokenizer.
79
+ default_denoising_steps (`int`, *optional*):
80
+ The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
81
+ quality with the given model. This value must be set in the model config. When the pipeline is called
82
+ without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
83
+ reasonable results with various model flavors compatible with the pipeline, such as those relying on very
84
+ short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
85
+ default_processing_resolution (`int`, *optional*):
86
+ The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
87
+ the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
88
+ default value is used. This is required to ensure reasonable results with various model flavors trained
89
+ with varying optimal processing resolution values.
90
+ """
91
+
92
+ latent_scale_factor = 0.18215
93
+
94
+ def __init__(
95
+ self,
96
+ unet: UNet2DConditionModel,
97
+ vae: AutoencoderKL,
98
+ scheduler: Union[DDIMScheduler, LCMScheduler],
99
+ text_encoder: CLIPTextModel,
100
+ tokenizer: CLIPTokenizer,
101
+ default_denoising_steps: Optional[int] = 10,
102
+ default_processing_resolution: Optional[int] = 768,
103
+ rgb_blending = False,
104
+ customized_head = None,
105
+ genpercept_pipeline = True,
106
+ ):
107
+ super().__init__()
108
+
109
+ self.genpercept_pipeline = genpercept_pipeline
110
+
111
+ if self.genpercept_pipeline:
112
+ default_denoising_steps = 1
113
+ rgb_blending = True
114
+
115
+ self.register_modules(
116
+ unet=unet,
117
+ customized_head=customized_head,
118
+ vae=vae,
119
+ scheduler=scheduler,
120
+ text_encoder=text_encoder,
121
+ tokenizer=tokenizer,
122
+ )
123
+ self.register_to_config(
124
+ default_denoising_steps=default_denoising_steps,
125
+ default_processing_resolution=default_processing_resolution,
126
+ rgb_blending=rgb_blending,
127
+ )
128
+
129
+ self.default_denoising_steps = default_denoising_steps
130
+ self.default_processing_resolution = default_processing_resolution
131
+ self.rgb_blending = rgb_blending
132
+
133
+ self.text_embed = None
134
+
135
+ self.customized_head = customized_head
136
+
137
+ if self.customized_head:
138
+ assert self.rgb_blending and self.scheduler.beta_start == 1 and self.scheduler.beta_end == 1
139
+ assert self.genpercept_pipeline
140
+
141
+ @torch.no_grad()
142
+ def __call__(
143
+ self,
144
+ input_image: Union[Image.Image, torch.Tensor],
145
+ denoising_steps: Optional[int] = None,
146
+ ensemble_size: int = 1,
147
+ processing_res: Optional[int] = None,
148
+ match_input_res: bool = True,
149
+ resample_method: str = "bilinear",
150
+ batch_size: int = 0,
151
+ generator: Union[torch.Generator, None] = None,
152
+ color_map: str = "Spectral",
153
+ show_progress_bar: bool = True,
154
+ ensemble_kwargs: Dict = None,
155
+ mode = None,
156
+ fix_timesteps = None,
157
+ prompt = "",
158
+ ) -> GenPerceptOutput:
159
+ """
160
+ Function invoked when calling the pipeline.
161
+
162
+ Args:
163
+ input_image (`Image`):
164
+ Input RGB (or gray-scale) image.
165
+ denoising_steps (`int`, *optional*, defaults to `None`):
166
+ Number of denoising diffusion steps during inference. The default value `None` results in automatic
167
+ selection.
168
+ ensemble_size (`int`, *optional*, defaults to `1`):
169
+ Number of predictions to be ensembled.
170
+ processing_res (`int`, *optional*, defaults to `None`):
171
+ Effective processing resolution. When set to `0`, processes at the original image resolution. This
172
+ produces crisper predictions, but may also lead to the overall loss of global context. The default
173
+ value `None` resolves to the optimal value from the model config.
174
+ match_input_res (`bool`, *optional*, defaults to `True`):
175
+ Resize perception result to match input resolution.
176
+ Only valid if `processing_res` > 0.
177
+ resample_method: (`str`, *optional*, defaults to `bilinear`):
178
+ Resampling method used to resize images and perception results. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
179
+ batch_size (`int`, *optional*, defaults to `0`):
180
+ Inference batch size, no bigger than `num_ensemble`.
181
+ If set to 0, the script will automatically decide the proper batch size.
182
+ generator (`torch.Generator`, *optional*, defaults to `None`)
183
+ Random generator for initial noise generation.
184
+ show_progress_bar (`bool`, *optional*, defaults to `True`):
185
+ Display a progress bar of diffusion denoising.
186
+ color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized result generation):
187
+ Colormap used to colorize the result.
188
+ ensemble_kwargs (`dict`, *optional*, defaults to `None`):
189
+ Arguments for detailed ensembling settings.
190
+ Returns:
191
+ `GenPerceptOutput`: Output class for GenPercept general perception pipeline, including:
192
+ - **pred_np** (`np.ndarray`) Predicted result, with values in the range of [0, 1]
193
+ - **pred_colored** (`PIL.Image.Image`) Colorized result, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
194
+ """
195
+ assert mode is not None, "mode of GenPerceptPipeline can be chosen from ['depth', 'normal', 'seg', 'matting', 'dis']."
196
+ self.mode = mode
197
+
198
+ # Model-specific optimal default values leading to fast and reasonable results.
199
+ if denoising_steps is None:
200
+ denoising_steps = self.default_denoising_steps
201
+ if processing_res is None:
202
+ processing_res = self.default_processing_resolution
203
+
204
+ assert processing_res >= 0
205
+ assert ensemble_size >= 1
206
+
207
+ if self.genpercept_pipeline:
208
+ assert ensemble_size == 1
209
+ assert denoising_steps == 1
210
+ else:
211
+ # Check if denoising step is reasonable
212
+ self._check_inference_step(denoising_steps)
213
+
214
+ resample_method: InterpolationMode = get_tv_resample_method(resample_method)
215
+
216
+ # ----------------- Image Preprocess -----------------
217
+ # Convert to torch tensor
218
+ if isinstance(input_image, Image.Image):
219
+ input_image = input_image.convert("RGB")
220
+ # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
221
+ rgb = pil_to_tensor(input_image)
222
+ rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
223
+ elif isinstance(input_image, torch.Tensor):
224
+ rgb = input_image
225
+ else:
226
+ raise TypeError(f"Unknown input type: {type(input_image) = }")
227
+ input_size = rgb.shape
228
+ assert (
229
+ 4 == rgb.dim() and 3 == input_size[-3]
230
+ ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
231
+
232
+ # Resize image
233
+ if processing_res > 0:
234
+ rgb = resize_max_res(
235
+ rgb,
236
+ max_edge_resolution=processing_res,
237
+ resample_method=resample_method,
238
+ )
239
+
240
+ # Normalize rgb values
241
+ rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
242
+ rgb_norm = rgb_norm.to(self.dtype)
243
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
244
+
245
+ # ----------------- Perception Inference -----------------
246
+ # Batch repeated input image
247
+ duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
248
+ single_rgb_dataset = TensorDataset(duplicated_rgb)
249
+ if batch_size > 0:
250
+ _bs = batch_size
251
+ else:
252
+ _bs = find_batch_size(
253
+ ensemble_size=ensemble_size,
254
+ input_res=max(rgb_norm.shape[1:]),
255
+ dtype=self.dtype,
256
+ )
257
+
258
+ single_rgb_loader = DataLoader(
259
+ single_rgb_dataset, batch_size=_bs, shuffle=False
260
+ )
261
+
262
+ # Predict results (batched)
263
+ pipe_pred_ls = []
264
+ if show_progress_bar:
265
+ iterable = tqdm(
266
+ single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
267
+ )
268
+ else:
269
+ iterable = single_rgb_loader
270
+ for batch in iterable:
271
+ (batched_img,) = batch
272
+ pipe_pred_raw = self.single_infer(
273
+ rgb_in=batched_img,
274
+ num_inference_steps=denoising_steps,
275
+ show_pbar=show_progress_bar,
276
+ generator=generator,
277
+ fix_timesteps=fix_timesteps,
278
+ prompt=prompt,
279
+ )
280
+ pipe_pred_ls.append(pipe_pred_raw.detach())
281
+ pipe_preds = torch.concat(pipe_pred_ls, dim=0)
282
+ torch.cuda.empty_cache() # clear vram cache for ensembling
283
+
284
+ # ----------------- Test-time ensembling -----------------
285
+ if ensemble_size > 1:
286
+ pipe_pred, _ = ensemble_depth(
287
+ pipe_preds,
288
+ scale_invariant=True,
289
+ shift_invariant=True,
290
+ max_res=50,
291
+ **(ensemble_kwargs or {}),
292
+ )
293
+ else:
294
+ pipe_pred = pipe_preds
295
+
296
+ # Resize back to original resolution
297
+ if match_input_res:
298
+ pipe_pred = resize(
299
+ pipe_pred,
300
+ input_size[-2:],
301
+ interpolation=resample_method,
302
+ antialias=True,
303
+ )
304
+
305
+ # Convert to numpy
306
+ pipe_pred = pipe_pred.squeeze()
307
+ pipe_pred = pipe_pred.cpu().numpy()
308
+
309
+ # Clip output range
310
+ pipe_pred = pipe_pred.clip(0, 1)
311
+
312
+ # Colorize
313
+ if color_map is not None:
314
+ assert self.mode == 'depth'
315
+ pred_colored = colorize_depth_maps(
316
+ pipe_pred, 0, 1, cmap=color_map
317
+ ).squeeze() # [3, H, W], value in (0, 1)
318
+ pred_colored = (pred_colored * 255).astype(np.uint8)
319
+ pred_colored_hwc = chw2hwc(pred_colored)
320
+ pred_colored_img = Image.fromarray(pred_colored_hwc)
321
+ else:
322
+ pred_colored_img = None
323
+
324
+ if len(pipe_pred.shape) == 3 and pipe_pred.shape[0] == 3:
325
+ pipe_pred = np.transpose(pipe_pred, (1, 2, 0))
326
+
327
+ return GenPerceptOutput(
328
+ pred_np=pipe_pred,
329
+ pred_colored=pred_colored_img,
330
+ )
331
+
332
+ def _check_inference_step(self, n_step: int) -> None:
333
+ """
334
+ Check if denoising step is reasonable
335
+ Args:
336
+ n_step (`int`): denoising steps
337
+ """
338
+ assert n_step >= 1
339
+
340
+ if isinstance(self.scheduler, DDIMScheduler):
341
+ if n_step < 10:
342
+ logging.warning(
343
+ f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
344
+ )
345
+ elif isinstance(self.scheduler, LCMScheduler):
346
+ if not 1 <= n_step <= 4:
347
+ logging.warning(
348
+ f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps."
349
+ )
350
+ else:
351
+ raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
352
+
353
+ def encode_text(self, prompt):
354
+ """
355
+ Encode text embedding for empty prompt
356
+ """
357
+ text_inputs = self.tokenizer(
358
+ prompt,
359
+ padding="do_not_pad",
360
+ max_length=self.tokenizer.model_max_length,
361
+ truncation=True,
362
+ return_tensors="pt",
363
+ )
364
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
365
+ self.text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
366
+
367
+ @torch.no_grad()
368
+ def single_infer(
369
+ self,
370
+ rgb_in: torch.Tensor,
371
+ num_inference_steps: int,
372
+ generator: Union[torch.Generator, None],
373
+ show_pbar: bool,
374
+ fix_timesteps = None,
375
+ prompt = "",
376
+ ) -> torch.Tensor:
377
+ """
378
+ Perform a single perception inference without ensembling.
379
+
380
+ Args:
381
+ rgb_in (`torch.Tensor`):
382
+ Input RGB image.
383
+ num_inference_steps (`int`):
384
+ Number of diffusion denoising steps (DDIM) during inference.
385
+ show_pbar (`bool`):
386
+ Display a progress bar of diffusion denoising.
387
+ generator (`torch.Generator`)
388
+ Random generator for initial noise generation.
389
+ Returns:
390
+ `torch.Tensor`: Predicted result.
391
+ """
392
+ device = self.device
393
+ rgb_in = rgb_in.to(device)
394
+
395
+ # Set timesteps
396
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
397
+
398
+ if fix_timesteps:
399
+ timesteps = torch.tensor([fix_timesteps]).long().repeat(self.scheduler.timesteps.shape[0]).to(device)
400
+ else:
401
+ timesteps = self.scheduler.timesteps # [T]
402
+
403
+ # Encode image
404
+ rgb_latent = self.encode_rgb(rgb_in)
405
+
406
+ if not (self.rgb_blending or self.genpercept_pipeline):
407
+ # Initial result (noise)
408
+ pred_latent = torch.randn(
409
+ rgb_latent.shape,
410
+ device=device,
411
+ dtype=self.dtype,
412
+ generator=generator,
413
+ ) # [B, 4, h, w]
414
+ else:
415
+ pred_latent = rgb_latent
416
+
417
+ # Batched empty text embedding
418
+ if self.text_embed is None:
419
+ self.encode_text(prompt)
420
+ batch_text_embed = self.text_embed.repeat(
421
+ (rgb_latent.shape[0], 1, 1)
422
+ ).to(device) # [B, 2, 1024]
423
+
424
+ # Denoising loop
425
+ if show_pbar:
426
+ iterable = tqdm(
427
+ enumerate(timesteps),
428
+ total=len(timesteps),
429
+ leave=False,
430
+ desc=" " * 4 + "Diffusion denoising",
431
+ )
432
+ else:
433
+ iterable = enumerate(timesteps)
434
+
435
+ if not self.customized_head:
436
+ for i, t in iterable:
437
+ if self.genpercept_pipeline and i > 0:
438
+ raise ValueError("GenPercept only forwards once.")
439
+
440
+ if not (self.rgb_blending or self.genpercept_pipeline):
441
+ unet_input = torch.cat(
442
+ [rgb_latent, pred_latent], dim=1
443
+ ) # this order is important
444
+ else:
445
+ unet_input = pred_latent
446
+
447
+ # predict the noise residual
448
+ noise_pred = self.unet(
449
+ unet_input, t, encoder_hidden_states=batch_text_embed
450
+ ).sample # [B, 4, h, w]
451
+
452
+ # compute the previous noisy sample x_t -> x_t-1
453
+ step_output = self.scheduler.step(
454
+ noise_pred, t, pred_latent, generator=generator
455
+ )
456
+ pred_latent = step_output.prev_sample
457
+
458
+ pred_latent = step_output.pred_original_sample # NOTE: for GenPercept, it is equivalent to "pred_latent = - noise_pred"
459
+
460
+ pred = self.decode_pred(pred_latent)
461
+
462
+ # clip prediction
463
+ pred = torch.clip(pred, -1.0, 1.0)
464
+ # shift to [0, 1]
465
+ pred = (pred + 1.0) / 2.0
466
+
467
+ elif isinstance(self.customized_head, DPTNeckHeadForUnetAfterUpsampleIdentity):
468
+ unet_input = pred_latent
469
+ model_pred_output = self.unet(
470
+ unet_input, timesteps, encoder_hidden_states=batch_text_embed, return_feature=True
471
+ ) # [B, 4, h, w]
472
+ unet_features = model_pred_output.multi_level_feats[::-1]
473
+ pred = self.customized_head(hidden_states=unet_features).prediction[:, None]
474
+ # shift to [0, 1]
475
+ pred = (pred - pred.min()) / (pred.max() - pred.min())
476
+ else:
477
+ raise ValueError
478
+
479
+ return pred
480
+
481
+ def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
482
+ """
483
+ Encode RGB image into latent.
484
+
485
+ Args:
486
+ rgb_in (`torch.Tensor`):
487
+ Input RGB image to be encoded.
488
+
489
+ Returns:
490
+ `torch.Tensor`: Image latent.
491
+ """
492
+ # encode
493
+ h = self.vae.encoder(rgb_in)
494
+ moments = self.vae.quant_conv(h)
495
+ mean, logvar = torch.chunk(moments, 2, dim=1)
496
+ # scale latent
497
+ rgb_latent = mean * self.latent_scale_factor
498
+ return rgb_latent
499
+
500
+ def decode_pred(self, pred_latent: torch.Tensor) -> torch.Tensor:
501
+ """
502
+ Decode pred latent into result.
503
+
504
+ Args:
505
+ pred_latent (`torch.Tensor`):
506
+ pred latent to be decoded.
507
+
508
+ Returns:
509
+ `torch.Tensor`: Decoded result.
510
+ """
511
+ # scale latent
512
+ pred_latent = pred_latent / self.latent_scale_factor
513
+ # decode
514
+ z = self.vae.post_quant_conv(pred_latent)
515
+ stacked = self.vae.decoder(z)
516
+ if self.mode in ['depth', 'matting', 'dis']:
517
+ # mean of output channels
518
+ stacked = stacked.mean(dim=1, keepdim=True)
519
+ return stacked
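
A hedged usage sketch of the pipeline defined above. Here pipe_depth stands for a pipeline built as in app.py's main(), and "input.jpg" is a placeholder; argument names and defaults follow the __call__ docstring, and colorization is only produced for the depth mode.

from PIL import Image

image = Image.open("input.jpg").convert("RGB")
out = pipe_depth(
    image,
    mode="depth",          # required: 'depth', 'normal', 'seg', 'matting' or 'dis'
    processing_res=768,    # 0 keeps the native resolution
    match_input_res=True,
    color_map="Spectral",  # must be None for non-depth modes
    show_progress_bar=False,
)
pred = out.pred_np            # np.ndarray with values in [0, 1]
colored = out.pred_colored    # PIL.Image for depth, None when color_map is None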
genpercept/models/custom_unet.py ADDED
@@ -0,0 +1,427 @@
1
+ # --------------------------------------------------------
2
+ # What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090)
3
+ # Github source: https://github.com/aim-uofa/GenPercept
4
+ # Copyright (c) 2024, Advanced Intelligent Machines (AIM)
5
+ # Licensed under The BSD 2-Clause License [see LICENSE for details]
6
+ # By Guangkai Xu
7
+ # Based on diffusers codebases
8
+ # https://github.com/huggingface/diffusers
9
+ # --------------------------------------------------------
10
+
11
+ from diffusers import UNet2DConditionModel
12
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
13
+ from typing import Any, Dict, List, Optional, Tuple, Union
14
+ import torch
15
+ import torch.utils.checkpoint
16
+ from dataclasses import dataclass
17
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
18
+
19
+ @dataclass
20
+ class CustomUNet2DConditionOutput(BaseOutput):
21
+ """
22
+ The output of [`UNet2DConditionModel`].
23
+
24
+ Args:
25
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
26
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
27
+ """
28
+
29
+ sample: torch.FloatTensor = None
30
+ multi_level_feats: Optional[List[torch.FloatTensor]] = None
31
+
32
+ class CustomUNet2DConditionModel(UNet2DConditionModel):
33
+
34
+ def forward(
35
+ self,
36
+ sample: torch.FloatTensor,
37
+ timestep: Union[torch.Tensor, float, int],
38
+ encoder_hidden_states: torch.Tensor,
39
+ class_labels: Optional[torch.Tensor] = None,
40
+ timestep_cond: Optional[torch.Tensor] = None,
41
+ attention_mask: Optional[torch.Tensor] = None,
42
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
43
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
44
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
45
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
46
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
47
+ encoder_attention_mask: Optional[torch.Tensor] = None,
48
+ return_feature: bool = False,
49
+ return_dict: bool = True,
50
+ ) -> Union[UNet2DConditionOutput, Tuple]:
51
+ r"""
52
+ The [`UNet2DConditionModel`] forward method.
53
+
54
+ Args:
55
+ sample (`torch.FloatTensor`):
56
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
57
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
58
+ encoder_hidden_states (`torch.FloatTensor`):
59
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
60
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
61
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
62
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
63
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
64
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
65
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
66
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
67
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
68
+ negative values to the attention scores corresponding to "discard" tokens.
69
+ cross_attention_kwargs (`dict`, *optional*):
70
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
71
+ `self.processor` in
72
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
73
+ added_cond_kwargs: (`dict`, *optional*):
74
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
75
+ are passed along to the UNet blocks.
76
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
77
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
78
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
79
+ A tensor that if specified is added to the residual of the middle unet block.
80
+ encoder_attention_mask (`torch.Tensor`):
81
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
82
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
83
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
84
+ return_dict (`bool`, *optional*, defaults to `True`):
85
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
86
+ tuple.
87
+ cross_attention_kwargs (`dict`, *optional*):
88
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
89
+ added_cond_kwargs: (`dict`, *optional*):
90
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
91
+ are passed along to the UNet blocks.
92
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
93
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
94
+ example from ControlNet side model(s)
95
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
96
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
97
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
98
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
99
+
100
+ Returns:
101
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
102
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
103
+ a `tuple` is returned where the first element is the sample tensor.
104
+ """
105
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
106
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
107
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
108
+ # on the fly if necessary.
109
+ default_overall_up_factor = 2**self.num_upsamplers
110
+
111
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
112
+ forward_upsample_size = False
113
+ upsample_size = None
114
+
115
+ for dim in sample.shape[-2:]:
116
+ if dim % default_overall_up_factor != 0:
117
+ # Forward upsample size to force interpolation output size.
118
+ forward_upsample_size = True
119
+ break
120
+
121
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
122
+ # expects mask of shape:
123
+ # [batch, key_tokens]
124
+ # adds singleton query_tokens dimension:
125
+ # [batch, 1, key_tokens]
126
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
127
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
128
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
129
+ if attention_mask is not None:
130
+ # assume that mask is expressed as:
131
+ # (1 = keep, 0 = discard)
132
+ # convert mask into a bias that can be added to attention scores:
133
+ # (keep = +0, discard = -10000.0)
134
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
135
+ attention_mask = attention_mask.unsqueeze(1)
136
+
137
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
138
+ if encoder_attention_mask is not None:
139
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
140
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
141
+
142
+ # 0. center input if necessary
143
+ if self.config.center_input_sample:
144
+ sample = 2 * sample - 1.0
145
+
146
+ # 1. time
147
+ timesteps = timestep
148
+ if not torch.is_tensor(timesteps):
149
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
150
+ # This would be a good case for the `match` statement (Python 3.10+)
151
+ is_mps = sample.device.type == "mps"
152
+ if isinstance(timestep, float):
153
+ dtype = torch.float32 if is_mps else torch.float64
154
+ else:
155
+ dtype = torch.int32 if is_mps else torch.int64
156
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
157
+ elif len(timesteps.shape) == 0:
158
+ timesteps = timesteps[None].to(sample.device)
159
+
160
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
161
+ timesteps = timesteps.expand(sample.shape[0])
162
+
163
+ t_emb = self.time_proj(timesteps)
164
+
165
+ # `Timesteps` does not contain any weights and will always return f32 tensors
166
+ # but time_embedding might actually be running in fp16. so we need to cast here.
167
+ # there might be better ways to encapsulate this.
168
+ t_emb = t_emb.to(dtype=sample.dtype)
169
+
170
+ emb = self.time_embedding(t_emb, timestep_cond)
171
+ aug_emb = None
172
+
173
+ if self.class_embedding is not None:
174
+ if class_labels is None:
175
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
176
+
177
+ if self.config.class_embed_type == "timestep":
178
+ class_labels = self.time_proj(class_labels)
179
+
180
+ # `Timesteps` does not contain any weights and will always return f32 tensors
181
+ # there might be better ways to encapsulate this.
182
+ class_labels = class_labels.to(dtype=sample.dtype)
183
+
184
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
185
+
186
+ if self.config.class_embeddings_concat:
187
+ emb = torch.cat([emb, class_emb], dim=-1)
188
+ else:
189
+ emb = emb + class_emb
190
+
191
+ if self.config.addition_embed_type == "text":
192
+ aug_emb = self.add_embedding(encoder_hidden_states)
193
+ elif self.config.addition_embed_type == "text_image":
194
+ # Kandinsky 2.1 - style
195
+ if "image_embeds" not in added_cond_kwargs:
196
+ raise ValueError(
197
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
198
+ )
199
+
200
+ image_embs = added_cond_kwargs.get("image_embeds")
201
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
202
+ aug_emb = self.add_embedding(text_embs, image_embs)
203
+ elif self.config.addition_embed_type == "text_time":
204
+ # SDXL - style
205
+ if "text_embeds" not in added_cond_kwargs:
206
+ raise ValueError(
207
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
208
+ )
209
+ text_embeds = added_cond_kwargs.get("text_embeds")
210
+ if "time_ids" not in added_cond_kwargs:
211
+ raise ValueError(
212
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
213
+ )
214
+ time_ids = added_cond_kwargs.get("time_ids")
215
+ time_embeds = self.add_time_proj(time_ids.flatten())
216
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
217
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
218
+ add_embeds = add_embeds.to(emb.dtype)
219
+ aug_emb = self.add_embedding(add_embeds)
220
+ elif self.config.addition_embed_type == "image":
221
+ # Kandinsky 2.2 - style
222
+ if "image_embeds" not in added_cond_kwargs:
223
+ raise ValueError(
224
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
225
+ )
226
+ image_embs = added_cond_kwargs.get("image_embeds")
227
+ aug_emb = self.add_embedding(image_embs)
228
+ elif self.config.addition_embed_type == "image_hint":
229
+ # Kandinsky 2.2 - style
230
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
231
+ raise ValueError(
232
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
233
+ )
234
+ image_embs = added_cond_kwargs.get("image_embeds")
235
+ hint = added_cond_kwargs.get("hint")
236
+ aug_emb, hint = self.add_embedding(image_embs, hint)
237
+ sample = torch.cat([sample, hint], dim=1)
238
+
239
+ emb = emb + aug_emb if aug_emb is not None else emb
240
+
241
+ if self.time_embed_act is not None:
242
+ emb = self.time_embed_act(emb)
243
+
244
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
245
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
246
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
247
+ # Kandinsky 2.1 - style
248
+ if "image_embeds" not in added_cond_kwargs:
249
+ raise ValueError(
250
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
251
+ )
252
+
253
+ image_embeds = added_cond_kwargs.get("image_embeds")
254
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
255
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
256
+ # Kandinsky 2.2 - style
257
+ if "image_embeds" not in added_cond_kwargs:
258
+ raise ValueError(
259
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
260
+ )
261
+ image_embeds = added_cond_kwargs.get("image_embeds")
262
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
263
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
264
+ if "image_embeds" not in added_cond_kwargs:
265
+ raise ValueError(
266
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
267
+ )
268
+ image_embeds = added_cond_kwargs.get("image_embeds")
269
+ image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
270
+ encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
271
+
272
+ # 2. pre-process
273
+ sample = self.conv_in(sample)
274
+
275
+ # 2.5 GLIGEN position net
276
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
277
+ cross_attention_kwargs = cross_attention_kwargs.copy()
278
+ gligen_args = cross_attention_kwargs.pop("gligen")
279
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
280
+
281
+ # 3. down
282
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
283
+ if USE_PEFT_BACKEND:
284
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
285
+ scale_lora_layers(self, lora_scale)
286
+
287
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
288
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
289
+ is_adapter = down_intrablock_additional_residuals is not None
290
+ # maintain backward compatibility for legacy usage, where
291
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
292
+ # but can only use one or the other
293
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
294
+ deprecate(
295
+ "T2I should not use down_block_additional_residuals",
296
+ "1.3.0",
297
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
298
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
299
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
300
+ standard_warn=False,
301
+ )
302
+ down_intrablock_additional_residuals = down_block_additional_residuals
303
+ is_adapter = True
304
+
305
+ down_block_res_samples = (sample,)
306
+ for downsample_block in self.down_blocks:
307
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
308
+ # For t2i-adapter CrossAttnDownBlock2D
309
+ additional_residuals = {}
310
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
311
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
312
+
313
+ sample, res_samples = downsample_block(
314
+ hidden_states=sample,
315
+ temb=emb,
316
+ encoder_hidden_states=encoder_hidden_states,
317
+ attention_mask=attention_mask,
318
+ cross_attention_kwargs=cross_attention_kwargs,
319
+ encoder_attention_mask=encoder_attention_mask,
320
+ **additional_residuals,
321
+ )
322
+ else:
323
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
324
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
325
+ sample += down_intrablock_additional_residuals.pop(0)
326
+
327
+ down_block_res_samples += res_samples
328
+
329
+ if is_controlnet:
330
+ new_down_block_res_samples = ()
331
+
332
+ for down_block_res_sample, down_block_additional_residual in zip(
333
+ down_block_res_samples, down_block_additional_residuals
334
+ ):
335
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
336
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
337
+
338
+ down_block_res_samples = new_down_block_res_samples
339
+
340
+ # 4. mid
341
+ if self.mid_block is not None:
342
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
343
+ sample = self.mid_block(
344
+ sample,
345
+ emb,
346
+ encoder_hidden_states=encoder_hidden_states,
347
+ attention_mask=attention_mask,
348
+ cross_attention_kwargs=cross_attention_kwargs,
349
+ encoder_attention_mask=encoder_attention_mask,
350
+ )
351
+ else:
352
+ sample = self.mid_block(sample, emb)
353
+
354
+ # To support T2I-Adapter-XL
355
+ if (
356
+ is_adapter
357
+ and len(down_intrablock_additional_residuals) > 0
358
+ and sample.shape == down_intrablock_additional_residuals[0].shape
359
+ ):
360
+ sample += down_intrablock_additional_residuals.pop(0)
361
+
362
+ if is_controlnet:
363
+ sample = sample + mid_block_additional_residual
364
+
365
+ multi_level_feats = []
366
+ # 1, 1280, 24, 24
367
+ # multi_level_feats.append(sample) # 1/64
368
+ # 5. up
369
+ for i, upsample_block in enumerate(self.up_blocks):
370
+ is_final_block = i == len(self.up_blocks) - 1
371
+
372
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
373
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
374
+
375
+ # if we have not reached the final block and need to forward the
376
+ # upsample size, we do it here
377
+ if not is_final_block and forward_upsample_size:
378
+ upsample_size = down_block_res_samples[-1].shape[2:]
379
+
380
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
381
+ sample = upsample_block(
382
+ hidden_states=sample,
383
+ temb=emb,
384
+ res_hidden_states_tuple=res_samples,
385
+ encoder_hidden_states=encoder_hidden_states,
386
+ cross_attention_kwargs=cross_attention_kwargs,
387
+ upsample_size=upsample_size,
388
+ attention_mask=attention_mask,
389
+ encoder_attention_mask=encoder_attention_mask,
390
+ )
391
+ else:
392
+ sample = upsample_block(
393
+ hidden_states=sample,
394
+ temb=emb,
395
+ res_hidden_states_tuple=res_samples,
396
+ upsample_size=upsample_size,
397
+ scale=lora_scale,
398
+ )
399
+ # if not is_final_block:
400
+ multi_level_feats.append(sample)
401
+
402
+ if return_feature:
403
+ if USE_PEFT_BACKEND:
404
+ # remove `lora_scale` from each PEFT layer
405
+ unscale_lora_layers(self, lora_scale)
406
+ return CustomUNet2DConditionOutput(
407
+ multi_level_feats=multi_level_feats,
408
+ )
409
+
410
+ # 6. post-process
411
+ if self.conv_norm_out:
412
+ sample = self.conv_norm_out(sample)
413
+ sample = self.conv_act(sample)
414
+
415
+ sample = self.conv_out(sample)
416
+
417
+ if USE_PEFT_BACKEND:
418
+ # remove `lora_scale` from each PEFT layer
419
+ unscale_lora_layers(self, lora_scale)
420
+
421
+ if not return_dict:
422
+ return (sample,)
423
+
424
+ return CustomUNet2DConditionOutput(
425
+ sample=sample,
426
+ multi_level_feats=multi_level_feats,
427
+ )
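For orientation, the sketch below shows how this modified forward pass could be driven to collect the decoder features. It is a minimal, hypothetical usage sketch: the helper name and tensor shapes are assumptions; only the `return_feature` flag and the `multi_level_feats` output field come from the code above.

```python
import torch


def extract_unet_features(unet, rgb_latent, empty_text_embed, timestep=1):
    """Hypothetical helper: run the customized UNet once and collect its up-block features.

    Assumptions: `unet` is an instance of the modified UNet defined in this file,
    `rgb_latent` is a VAE-encoded image of shape [B, 4, h, w], and `empty_text_embed`
    is the cached empty-prompt embedding of shape [B, 77, 1024].
    """
    t = torch.tensor([timestep], device=rgb_latent.device)
    with torch.no_grad():
        out = unet(
            rgb_latent,
            t,
            encoder_hidden_states=empty_text_embed,
            return_feature=True,  # skip conv_out and return the decoder feature pyramid
        )
    # One feature map per up block, appended in the order of the loop above.
    return out.multi_level_feats
```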
genpercept/models/dpt_head.py ADDED
@@ -0,0 +1,593 @@
1
+ # --------------------------------------------------------
2
+ # What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090)
3
+ # Github source: https://github.com/aim-uofa/GenPercept
4
+ # Copyright (c) 2024, Advanced Intelligent Machines (AIM)
5
+ # Licensed under The BSD 2-Clause License [see LICENSE for details]
6
+ # By Guangkai Xu
7
+ # Based on diffusers codebases
8
+ # https://github.com/huggingface/diffusers
9
+ # --------------------------------------------------------
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from typing import List, Optional, Tuple, Union
14
+ from transformers import DPTPreTrainedModel
15
+
16
+ from transformers.utils import ModelOutput
17
+ from transformers.file_utils import replace_return_docstrings, add_start_docstrings_to_model_forward
18
+ from transformers.models.dpt.modeling_dpt import DPTReassembleStage
19
+
20
+ from diffusers.models.lora import LoRACompatibleConv
21
+ from diffusers.utils import USE_PEFT_BACKEND
22
+ import torch.nn.functional as F
23
+
24
+ class DepthEstimatorOutput(ModelOutput):
25
+ """
26
+ Base class for outputs of depth estimation models.
27
+
28
+ Args:
29
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
30
+ Classification (or regression if config.num_labels==1) loss.
31
+ prediction (`torch.FloatTensor` of shape `(batch_size, height, width)`):
32
+ Predicted depth for each pixel.
33
+
34
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
35
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
36
+ one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
37
+
38
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
39
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
40
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
41
+ sequence_length)`.
42
+
43
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
44
+ heads.
45
+ """
46
+
47
+ loss: Optional[torch.FloatTensor] = None
48
+ prediction: torch.FloatTensor = None
49
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
50
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
51
+
52
+ class DPTDepthEstimationHead(nn.Module):
53
+ """
54
+ Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
55
+ the predictions to the input resolution after the first convolutional layer (details can be found in the paper's
56
+ supplementary material).
57
+ """
58
+
59
+ def __init__(self, config):
60
+ super().__init__()
61
+
62
+ self.config = config
63
+
64
+ self.projection = None
65
+ features = config.fusion_hidden_size
66
+ if config.add_projection:
67
+ self.projection = nn.Conv2d(features, features, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
68
+
69
+ self.head = nn.Sequential(
70
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
71
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
72
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
73
+ nn.ReLU(),
74
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
75
+ nn.ReLU(),
76
+ )
77
+
78
+ def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
79
+ # use last features
80
+ hidden_states = hidden_states[self.config.head_in_index]
81
+
82
+ if self.projection is not None:
83
+ hidden_states = self.projection(hidden_states)
84
+ hidden_states = nn.ReLU()(hidden_states)
85
+
86
+ predicted_depth = self.head(hidden_states)
87
+
88
+ predicted_depth = predicted_depth.squeeze(dim=1)
89
+
90
+ return predicted_depth
91
+
92
+ class Upsample2D(nn.Module):
93
+ """A 2D upsampling layer with an optional convolution.
94
+
95
+ Parameters:
96
+ channels (`int`):
97
+ number of channels in the inputs and outputs.
98
+ use_conv (`bool`, default `False`):
99
+ option to use a convolution.
100
+ use_conv_transpose (`bool`, default `False`):
101
+ option to use a convolution transpose.
102
+ out_channels (`int`, optional):
103
+ number of output channels. Defaults to `channels`.
104
+ name (`str`, default `conv`):
105
+ name of the upsampling 2D layer.
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ channels: int,
111
+ use_conv: bool = False,
112
+ use_conv_transpose: bool = False,
113
+ out_channels: Optional[int] = None,
114
+ name: str = "conv",
115
+ kernel_size: Optional[int] = None,
116
+ padding=1,
117
+ norm_type=None,
118
+ eps=None,
119
+ elementwise_affine=None,
120
+ bias=True,
121
+ interpolate=True,
122
+ ):
123
+ super().__init__()
124
+ self.channels = channels
125
+ self.out_channels = out_channels or channels
126
+ self.use_conv = use_conv
127
+ self.use_conv_transpose = use_conv_transpose
128
+ self.name = name
129
+ self.interpolate = interpolate
130
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
131
+
132
+ if norm_type == "ln_norm":
133
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
134
+ elif norm_type == "rms_norm":
135
+ # self.norm = RMSNorm(channels, eps, elementwise_affine)
136
+ raise NotImplementedError
137
+ elif norm_type is None:
138
+ self.norm = None
139
+ else:
140
+ raise ValueError(f"unknown norm_type: {norm_type}")
141
+
142
+ conv = None
143
+ if use_conv_transpose:
144
+ if kernel_size is None:
145
+ kernel_size = 4
146
+ conv = nn.ConvTranspose2d(
147
+ channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
148
+ )
149
+ elif use_conv:
150
+ if kernel_size is None:
151
+ kernel_size = 3
152
+ conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
153
+
154
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
155
+ if name == "conv":
156
+ self.conv = conv
157
+ else:
158
+ self.Conv2d_0 = conv
159
+
160
+ def forward(
161
+ self,
162
+ hidden_states: torch.FloatTensor,
163
+ output_size: Optional[int] = None,
164
+ scale: float = 1.0,
165
+ ) -> torch.FloatTensor:
166
+ assert hidden_states.shape[1] == self.channels
167
+
168
+ if self.norm is not None:
169
+ hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
170
+
171
+ if self.use_conv_transpose:
172
+ return self.conv(hidden_states)
173
+
174
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
175
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
176
+ # https://github.com/pytorch/pytorch/issues/86679
177
+ dtype = hidden_states.dtype
178
+ if dtype == torch.bfloat16:
179
+ hidden_states = hidden_states.to(torch.float32)
180
+
181
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
182
+ if hidden_states.shape[0] >= 64:
183
+ hidden_states = hidden_states.contiguous()
184
+
185
+ # if `output_size` is passed we force the interpolation output
186
+ # size and do not make use of `scale_factor=2`
187
+ if self.interpolate:
188
+ if output_size is None:
189
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
190
+ else:
191
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
192
+
193
+ # If the input is bfloat16, we cast back to bfloat16
194
+ if dtype == torch.bfloat16:
195
+ hidden_states = hidden_states.to(dtype)
196
+
197
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
198
+ if self.use_conv:
199
+ if self.name == "conv":
200
+ if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
201
+ hidden_states = self.conv(hidden_states, scale)
202
+ else:
203
+ hidden_states = self.conv(hidden_states)
204
+ else:
205
+ if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
206
+ hidden_states = self.Conv2d_0(hidden_states, scale)
207
+ else:
208
+ hidden_states = self.Conv2d_0(hidden_states)
209
+
210
+ return hidden_states
211
+
212
+
213
+ class DPTPreActResidualLayer(nn.Module):
214
+ """
215
+ ResidualConvUnit, pre-activate residual unit.
216
+
217
+ Args:
218
+ config (`[DPTConfig]`):
219
+ Model configuration class defining the model architecture.
220
+ """
221
+
222
+ def __init__(self, config):
223
+ super().__init__()
224
+
225
+ self.use_batch_norm = config.use_batch_norm_in_fusion_residual
226
+ use_bias_in_fusion_residual = (
227
+ config.use_bias_in_fusion_residual
228
+ if config.use_bias_in_fusion_residual is not None
229
+ else not self.use_batch_norm
230
+ )
231
+
232
+ self.activation1 = nn.ReLU()
233
+ self.convolution1 = nn.Conv2d(
234
+ config.fusion_hidden_size,
235
+ config.fusion_hidden_size,
236
+ kernel_size=3,
237
+ stride=1,
238
+ padding=1,
239
+ bias=use_bias_in_fusion_residual,
240
+ )
241
+
242
+ self.activation2 = nn.ReLU()
243
+ self.convolution2 = nn.Conv2d(
244
+ config.fusion_hidden_size,
245
+ config.fusion_hidden_size,
246
+ kernel_size=3,
247
+ stride=1,
248
+ padding=1,
249
+ bias=use_bias_in_fusion_residual,
250
+ )
251
+
252
+ if self.use_batch_norm:
253
+ self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
254
+ self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)
255
+
256
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
257
+ residual = hidden_state.clone()
258
+ hidden_state = self.activation1(hidden_state)
259
+
260
+ hidden_state = self.convolution1(hidden_state)
261
+
262
+ if self.use_batch_norm:
263
+ hidden_state = self.batch_norm1(hidden_state)
264
+
265
+ hidden_state = self.activation2(hidden_state)
266
+ hidden_state = self.convolution2(hidden_state)
267
+
268
+ if self.use_batch_norm:
269
+ hidden_state = self.batch_norm2(hidden_state)
270
+
271
+ return hidden_state + residual
272
+
273
+
274
+ class DPTFeatureFusionLayer(nn.Module):
275
+ """Feature fusion layer, merges feature maps from different stages.
276
+
277
+ Args:
278
+ config (`[DPTConfig]`):
279
+ Model configuration class defining the model architecture.
280
+ align_corners (`bool`, *optional*, defaults to `True`):
281
+ The align_corners setting for bilinear upsampling.
282
+ """
283
+
284
+ def __init__(self, config, align_corners=True, with_residual_1=True):
285
+ super().__init__()
286
+
287
+ self.align_corners = align_corners
288
+
289
+ self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
290
+
291
+ if with_residual_1:
292
+ self.residual_layer1 = DPTPreActResidualLayer(config)
293
+ self.residual_layer2 = DPTPreActResidualLayer(config)
294
+
295
+ def forward(self, hidden_state, residual=None):
296
+ if residual is not None:
297
+ if hidden_state.shape != residual.shape:
298
+ residual = nn.functional.interpolate(
299
+ residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False
300
+ )
301
+ hidden_state = hidden_state + self.residual_layer1(residual)
302
+
303
+ hidden_state = self.residual_layer2(hidden_state)
304
+ hidden_state = nn.functional.interpolate(
305
+ hidden_state, scale_factor=2, mode="bilinear", align_corners=self.align_corners
306
+ )
307
+ hidden_state = self.projection(hidden_state)
308
+
309
+ return hidden_state
310
+
311
+
312
+ class DPTFeatureFusionStage(nn.Module):
313
+ def __init__(self, config):
314
+ super().__init__()
315
+ self.layers = nn.ModuleList()
316
+ for i in range(len(config.neck_hidden_sizes)):
317
+ if i == 0:
318
+ self.layers.append(DPTFeatureFusionLayer(config, with_residual_1=False))
319
+ else:
320
+ self.layers.append(DPTFeatureFusionLayer(config))
321
+
322
+ def forward(self, hidden_states):
323
+ # reversing the hidden_states, we start from the last
324
+ hidden_states = hidden_states[::-1]
325
+
326
+ fused_hidden_states = []
327
+ # first layer only uses the last hidden_state
328
+ fused_hidden_state = self.layers[0](hidden_states[0])
329
+ fused_hidden_states.append(fused_hidden_state)
330
+ # looping from the last layer to the second
331
+ for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]):
332
+ fused_hidden_state = layer(fused_hidden_state, hidden_state)
333
+ fused_hidden_states.append(fused_hidden_state)
334
+
335
+ return fused_hidden_states
336
+
337
+
338
+ class DPTNeck(nn.Module):
339
+ """
340
+ DPTNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
341
+ input and produces another list of tensors as output. For DPT, it includes 2 stages:
342
+
343
+ * DPTReassembleStage
344
+ * DPTFeatureFusionStage.
345
+
346
+ Args:
347
+ config (dict): config dict.
348
+ """
349
+
350
+ def __init__(self, config):
351
+ super().__init__()
352
+ self.config = config
353
+
354
+ # postprocessing: only required in case of a non-hierarchical backbone (e.g. ViT, BEiT)
355
+ if config.backbone_config is not None and config.backbone_config.model_type in ["swinv2"]:
356
+ self.reassemble_stage = None
357
+ else:
358
+ self.reassemble_stage = DPTReassembleStage(config)
359
+
360
+ self.convs = nn.ModuleList()
361
+ for channel in config.neck_hidden_sizes:
362
+ self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))
363
+
364
+ # fusion
365
+ self.fusion_stage = DPTFeatureFusionStage(config)
366
+
367
+ def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
368
+ """
369
+ Args:
370
+ hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
371
+ List of hidden states from the backbone.
372
+ """
373
+ if not isinstance(hidden_states, (tuple, list)):
374
+ raise TypeError("hidden_states should be a tuple or list of tensors")
375
+
376
+ if len(hidden_states) != len(self.config.neck_hidden_sizes):
377
+ raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
378
+
379
+ # postprocess hidden states
380
+ if self.reassemble_stage is not None:
381
+ hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
382
+
383
+ features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
384
+
385
+ # fusion blocks
386
+ output = self.fusion_stage(features)
387
+
388
+ return output
389
+
390
+
391
+ DPT_INPUTS_DOCSTRING = r"""
392
+ Args:
393
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
394
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
395
+ for details.
396
+
397
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
398
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
399
+
400
+ - 1 indicates the head is **not masked**,
401
+ - 0 indicates the head is **masked**.
402
+
403
+ output_attentions (`bool`, *optional*):
404
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
405
+ tensors for more detail.
406
+ output_hidden_states (`bool`, *optional*):
407
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
408
+ more detail.
409
+ return_dict (`bool`, *optional*):
410
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
411
+ """
412
+
413
+ _CONFIG_FOR_DOC = "DPTConfig"
414
+
415
+
416
+ class DPTNeckHeadForUnetAfterUpsample(DPTPreTrainedModel):
417
+ def __init__(self, config):
418
+ super().__init__(config)
419
+
420
+ # self.backbone = None
421
+ # if config.backbone_config is not None and config.is_hybrid is False:
422
+ # self.backbone = load_backbone(config)
423
+ # else:
424
+ # self.dpt = DPTModel(config, add_pooling_layer=False)
425
+
426
+ self.feature_upsample_0 = Upsample2D(channels=config.neck_hidden_sizes[0], use_conv=True)
427
+ # self.feature_upsample_1 = Upsample2D(channels=config.neck_hidden_sizes[1], use_conv=True)
428
+ # self.feature_upsample_2 = Upsample2D(channels=config.neck_hidden_sizes[2], use_conv=True)
429
+ # self.feature_upsample_3 = Upsample2D(channels=config.neck_hidden_sizes[3], use_conv=True)
430
+
431
+ # Neck
432
+ self.neck = DPTNeck(config)
433
+ self.neck.reassemble_stage = None
434
+
435
+ # Depth estimation head
436
+ self.head = DPTDepthEstimationHead(config)
437
+
438
+ # Initialize weights and apply final processing
439
+ self.post_init()
440
+
441
+ @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING)
442
+ @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
443
+ def forward(
444
+ self,
445
+ hidden_states,
446
+ head_mask: Optional[torch.FloatTensor] = None,
447
+ labels: Optional[torch.LongTensor] = None,
448
+ output_attentions: Optional[bool] = None,
449
+ output_hidden_states: Optional[bool] = None,
450
+ return_depth_only: bool = False,
451
+ return_dict: Optional[bool] = None,
452
+ ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
453
+ r"""
454
+ labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
455
+ Ground truth depth estimation maps for computing the loss.
456
+
457
+ Returns:
458
+
459
+ Examples:
460
+ ```python
461
+ >>> from transformers import AutoImageProcessor, DPTForDepthEstimation
462
+ >>> import torch
463
+ >>> import numpy as np
464
+ >>> from PIL import Image
465
+ >>> import requests
466
+
467
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
468
+ >>> image = Image.open(requests.get(url, stream=True).raw)
469
+
470
+ >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
471
+ >>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
472
+
473
+ >>> # prepare image for the model
474
+ >>> inputs = image_processor(images=image, return_tensors="pt")
475
+
476
+ >>> with torch.no_grad():
477
+ ... outputs = model(**inputs)
478
+ ... predicted_depth = outputs.predicted_depth
479
+
480
+ >>> # interpolate to original size
481
+ >>> prediction = torch.nn.functional.interpolate(
482
+ ... predicted_depth.unsqueeze(1),
483
+ ... size=image.size[::-1],
484
+ ... mode="bicubic",
485
+ ... align_corners=False,
486
+ ... )
487
+
488
+ >>> # visualize the prediction
489
+ >>> output = prediction.squeeze().cpu().numpy()
490
+ >>> formatted = (output * 255 / np.max(output)).astype("uint8")
491
+ >>> depth = Image.fromarray(formatted)
492
+ ```"""
493
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
494
+ output_hidden_states = (
495
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
496
+ )
497
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
498
+
499
+ # if self.backbone is not None:
500
+ # outputs = self.backbone.forward_with_filtered_kwargs(
501
+ # pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
502
+ # )
503
+ # hidden_states = outputs.feature_maps
504
+ # else:
505
+ # outputs = self.dpt(
506
+ # pixel_values,
507
+ # head_mask=head_mask,
508
+ # output_attentions=output_attentions,
509
+ # output_hidden_states=True, # we need the intermediate hidden states
510
+ # return_dict=return_dict,
511
+ # )
512
+ # hidden_states = outputs.hidden_states if return_dict else outputs[1]
513
+ # # only keep certain features based on config.backbone_out_indices
514
+ # # note that the hidden_states also include the initial embeddings
515
+ # if not self.config.is_hybrid:
516
+ # hidden_states = [
517
+ # feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
518
+ # ]
519
+ # else:
520
+ # backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
521
+ # backbone_hidden_states.extend(
522
+ # feature
523
+ # for idx, feature in enumerate(hidden_states[1:])
524
+ # if idx in self.config.backbone_out_indices[2:]
525
+ # )
526
+
527
+ # hidden_states = backbone_hidden_states
528
+
529
+
530
+ assert len(hidden_states) == 4
531
+
532
+ # upsample hidden_states for unet
533
+ # hidden_states = [getattr(self, "feature_upsample_%s" %i)(hidden_states[i]) for i in range(len(hidden_states))]
534
+ hidden_states[0] = self.feature_upsample_0(hidden_states[0])
535
+
536
+ patch_height, patch_width = None, None
537
+ if self.config.backbone_config is not None and self.config.is_hybrid is False:
538
+ _, _, height, width = hidden_states[3].shape
539
+ height *= 8; width *= 8
540
+ patch_size = self.config.backbone_config.patch_size
541
+ patch_height = height // patch_size
542
+ patch_width = width // patch_size
543
+
544
+ hidden_states = self.neck(hidden_states, patch_height, patch_width)
545
+
546
+ predicted_depth = self.head(hidden_states)
547
+
548
+ loss = None
549
+ if labels is not None:
550
+ raise NotImplementedError("Training is not implemented yet")
551
+
552
+ if return_depth_only:
553
+ return predicted_depth
554
+
555
+ return DepthEstimatorOutput(
556
+ loss=loss,
557
+ prediction=predicted_depth,
558
+ hidden_states=None,
559
+ attentions=None,
560
+ )
561
+
562
+
563
+
564
+ class DPTDepthEstimationHeadIdentity(DPTDepthEstimationHead):
565
+ """
566
+ Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
567
+ the predictions to the input resolution after the first convolutional layer (details can be found in the paper's
568
+ supplementary material).
569
+ """
570
+
571
+ def __init__(self, config):
572
+ super().__init__(config)
573
+
574
+ features = config.fusion_hidden_size
575
+ self.head = nn.Sequential(
576
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
577
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
578
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
579
+ nn.ReLU(),
580
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
581
+ nn.Identity(),
582
+ )
583
+
584
+
585
+ class DPTNeckHeadForUnetAfterUpsampleIdentity(DPTNeckHeadForUnetAfterUpsample):
586
+ def __init__(self, config):
587
+ super().__init__(config)
588
+
589
+ # Depth estimation head
590
+ self.head = DPTDepthEstimationHeadIdentity(config)
591
+
592
+ # Initialize weights and apply final processing
593
+ self.post_init()
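A hedged usage sketch for the head defined above. The config directory refers to the `hf_configs` folder added later in this commit, the import path mirrors this new file's location, and the feature shapes are illustrative stand-ins for the four UNet decoder maps (channel counts follow `neck_hidden_sizes`).

```python
import torch
from transformers import DPTConfig

from genpercept.models.dpt_head import DPTNeckHeadForUnetAfterUpsample

# Config added later in this commit; neck_hidden_sizes = [320, 640, 1280, 1280].
config = DPTConfig.from_pretrained("hf_configs/dpt-sd2.1-unet-after-upsample-general")
head = DPTNeckHeadForUnetAfterUpsample(config)

# Illustrative decoder features, ordered finest to coarsest to match the neck convolutions.
feats = [
    torch.randn(1, 320, 96, 96),
    torch.randn(1, 640, 48, 48),
    torch.randn(1, 1280, 24, 24),
    torch.randn(1, 1280, 12, 12),
]
with torch.no_grad():
    depth = head(feats, return_depth_only=True)
print(depth.shape)  # e.g. torch.Size([1, 384, 384])
```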
{util → genpercept/util}/batchsize.py RENAMED
@@ -1,3 +1,23 @@
1
  import torch
2
  import math
3
 
@@ -33,11 +53,13 @@ def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> i
33
  Automatically search for suitable operating batch size.
34
 
35
  Args:
36
- ensemble_size (int): Number of predictions to be ensembled
37
- input_res (int): Operating resolution of the input image.
 
 
38
 
39
  Returns:
40
- int: Operating batch size
41
  """
42
  if not torch.cuda.is_available():
43
  return 1
@@ -56,4 +78,4 @@ def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> i
56
  bs = math.ceil(ensemble_size / 2)
57
  return bs
58
 
59
- return 1
 
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
  import torch
22
  import math
23
 
 
53
  Automatically search for suitable operating batch size.
54
 
55
  Args:
56
+ ensemble_size (`int`):
57
+ Number of predictions to be ensembled.
58
+ input_res (`int`):
59
+ Operating resolution of the input image.
60
 
61
  Returns:
62
+ `int`: Operating batch size.
63
  """
64
  if not torch.cuda.is_available():
65
  return 1
 
78
  bs = math.ceil(ensemble_size / 2)
79
  return bs
80
 
81
+ return 1
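A small usage sketch for the helper above; the import path follows the `util/` → `genpercept/util/` rename in this commit.

```python
import torch

from genpercept.util.batchsize import find_batch_size

# Pick an operating batch size for 10 ensemble members at a 768 px processing resolution.
bs = find_batch_size(ensemble_size=10, input_res=768, dtype=torch.float16)
print(f"chunking the ensemble into batches of {bs}")
```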
genpercept/util/ensemble.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
+ from functools import partial
22
+ from typing import Optional, Tuple
23
+
24
+ import numpy as np
25
+ import torch
26
+
27
+ from .image_util import get_tv_resample_method, resize_max_res
28
+
29
+
30
+ def inter_distances(tensors: torch.Tensor):
31
+ """
32
+ To calculate the distance between each two depth maps.
33
+ """
34
+ distances = []
35
+ for i, j in torch.combinations(torch.arange(tensors.shape[0])):
36
+ arr1 = tensors[i : i + 1]
37
+ arr2 = tensors[j : j + 1]
38
+ distances.append(arr1 - arr2)
39
+ dist = torch.concatenate(distances, dim=0)
40
+ return dist
41
+
42
+
43
+ def ensemble_depth(
44
+ depth: torch.Tensor,
45
+ scale_invariant: bool = True,
46
+ shift_invariant: bool = True,
47
+ output_uncertainty: bool = False,
48
+ reduction: str = "median",
49
+ regularizer_strength: float = 0.02,
50
+ max_iter: int = 2,
51
+ tol: float = 1e-3,
52
+ max_res: int = 1024,
53
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
54
+ """
55
+ Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the
56
+ number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for
57
+ depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The
58
+ alignment happens when the predictions have one or more degrees of freedom, that is when they are either
59
+ affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only
60
+ `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`)
61
+ alignment is skipped and only ensembling is performed.
62
+
63
+ Args:
64
+ depth (`torch.Tensor`):
65
+ Input ensemble depth maps.
66
+ scale_invariant (`bool`, *optional*, defaults to `True`):
67
+ Whether to treat predictions as scale-invariant.
68
+ shift_invariant (`bool`, *optional*, defaults to `True`):
69
+ Whether to treat predictions as shift-invariant.
70
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
71
+ Whether to output uncertainty map.
72
+ reduction (`str`, *optional*, defaults to `"median"`):
73
+ Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and
74
+ `"median"`.
75
+ regularizer_strength (`float`, *optional*, defaults to `0.02`):
76
+ Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1.
77
+ max_iter (`int`, *optional*, defaults to `2`):
78
+ Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options`
79
+ argument.
80
+ tol (`float`, *optional*, defaults to `1e-3`):
81
+ Alignment solver tolerance. The solver stops when the tolerance is reached.
82
+ max_res (`int`, *optional*, defaults to `1024`):
83
+ Resolution at which the alignment is performed; `None` matches the `processing_resolution`.
84
+ Returns:
85
+ A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape:
86
+ `(1, 1, H, W)`.
87
+ """
88
+ if depth.dim() != 4 or depth.shape[1] != 1:
89
+ raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.")
90
+ if reduction not in ("mean", "median"):
91
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
92
+ if not scale_invariant and shift_invariant:
93
+ raise ValueError("Pure shift-invariant ensembling is not supported.")
94
+
95
+ def init_param(depth: torch.Tensor):
96
+ init_min = depth.reshape(ensemble_size, -1).min(dim=1).values
97
+ init_max = depth.reshape(ensemble_size, -1).max(dim=1).values
98
+
99
+ if scale_invariant and shift_invariant:
100
+ init_s = 1.0 / (init_max - init_min).clamp(min=1e-6)
101
+ init_t = -init_s * init_min
102
+ param = torch.cat((init_s, init_t)).cpu().numpy()
103
+ elif scale_invariant:
104
+ init_s = 1.0 / init_max.clamp(min=1e-6)
105
+ param = init_s.cpu().numpy()
106
+ else:
107
+ raise ValueError("Unrecognized alignment.")
108
+
109
+ return param
110
+
111
+ def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor:
112
+ if scale_invariant and shift_invariant:
113
+ s, t = np.split(param, 2)
114
+ s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1)
115
+ t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1)
116
+ out = depth * s + t
117
+ elif scale_invariant:
118
+ s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1)
119
+ out = depth * s
120
+ else:
121
+ raise ValueError("Unrecognized alignment.")
122
+ return out
123
+
124
+ def ensemble(
125
+ depth_aligned: torch.Tensor, return_uncertainty: bool = False
126
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
127
+ uncertainty = None
128
+ if reduction == "mean":
129
+ prediction = torch.mean(depth_aligned, dim=0, keepdim=True)
130
+ if return_uncertainty:
131
+ uncertainty = torch.std(depth_aligned, dim=0, keepdim=True)
132
+ elif reduction == "median":
133
+ prediction = torch.median(depth_aligned, dim=0, keepdim=True).values
134
+ if return_uncertainty:
135
+ uncertainty = torch.median(
136
+ torch.abs(depth_aligned - prediction), dim=0, keepdim=True
137
+ ).values
138
+ else:
139
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
140
+ return prediction, uncertainty
141
+
142
+ def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:
143
+ cost = 0.0
144
+ depth_aligned = align(depth, param)
145
+
146
+ for i, j in torch.combinations(torch.arange(ensemble_size)):
147
+ diff = depth_aligned[i] - depth_aligned[j]
148
+ cost += (diff**2).mean().sqrt().item()
149
+
150
+ if regularizer_strength > 0:
151
+ prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
152
+ err_near = (0.0 - prediction.min()).abs().item()
153
+ err_far = (1.0 - prediction.max()).abs().item()
154
+ cost += (err_near + err_far) * regularizer_strength
155
+
156
+ return cost
157
+
158
+ def compute_param(depth: torch.Tensor):
159
+ import scipy
160
+
161
+ depth_to_align = depth.to(torch.float32)
162
+ if max_res is not None and max(depth_to_align.shape[2:]) > max_res:
163
+ try:
164
+ depth_to_align = resize_max_res(
165
+ depth_to_align, max_res, get_tv_resample_method("nearest-exact")
166
+ )
167
+ except:
168
+ depth_to_align = resize_max_res(
169
+ depth_to_align, max_res, get_tv_resample_method("bilinear")
170
+ )
171
+
172
+ param = init_param(depth_to_align)
173
+
174
+ res = scipy.optimize.minimize(
175
+ partial(cost_fn, depth=depth_to_align),
176
+ param,
177
+ method="BFGS",
178
+ tol=tol,
179
+ options={"maxiter": max_iter, "disp": False},
180
+ )
181
+
182
+ return res.x
183
+
184
+ requires_aligning = scale_invariant or shift_invariant
185
+ ensemble_size = depth.shape[0]
186
+
187
+ if requires_aligning:
188
+ param = compute_param(depth)
189
+ depth = align(depth, param)
190
+
191
+ depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty)
192
+
193
+ depth_max = depth.max()
194
+ if scale_invariant and shift_invariant:
195
+ depth_min = depth.min()
196
+ elif scale_invariant:
197
+ depth_min = 0
198
+ else:
199
+ raise ValueError("Unrecognized alignment.")
200
+ depth_range = (depth_max - depth_min).clamp(min=1e-6)
201
+ depth = (depth - depth_min) / depth_range
202
+ if output_uncertainty:
203
+ uncertainty /= depth_range
204
+
205
+ return depth, uncertainty # [1,1,H,W], [1,1,H,W]
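A hedged sketch of calling `ensemble_depth` on a stack of affine-invariant predictions; the random tensors stand in for real pipeline outputs, and SciPy must be available for the alignment solver.

```python
import torch

from genpercept.util.ensemble import ensemble_depth

# Ten affine-invariant depth predictions of the same image, shape [E, 1, H, W].
preds = torch.rand(10, 1, 480, 640)

depth, uncertainty = ensemble_depth(
    preds,
    scale_invariant=True,
    shift_invariant=True,
    output_uncertainty=True,
    reduction="median",
)
print(depth.shape, uncertainty.shape)  # both [1, 1, 480, 640]; depth is rescaled to [0, 1]
```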
{util → genpercept/util}/image_util.py RENAMED
@@ -1,15 +1,30 @@
1
  import matplotlib
2
  import numpy as np
3
  import torch
4
- from PIL import Image
5
- from torchvision import transforms
6
 
7
- def norm_to_rgb(norm):
8
- # norm: (3, H, W), range from [-1, 1]
9
- norm_rgb = ((norm + 1) * 0.5) * 255
10
- norm_rgb = np.clip(norm_rgb, a_min=0, a_max=255)
11
- norm_rgb = norm_rgb.astype(np.uint8)
12
- return norm_rgb
13
 
14
  def colorize_depth_maps(
15
  depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
@@ -20,9 +35,9 @@ def colorize_depth_maps(
20
  assert len(depth_map.shape) >= 2, "Invalid dimension"
21
 
22
  if isinstance(depth_map, torch.Tensor):
23
- depth = depth_map.detach().clone().squeeze().numpy()
24
  elif isinstance(depth_map, np.ndarray):
25
- depth = np.squeeze(depth_map.copy())
26
  # reshape to [ (B,) H, W ]
27
  if depth.ndim < 3:
28
  depth = depth[np.newaxis, :, :]
@@ -36,7 +51,7 @@ def colorize_depth_maps(
36
  if valid_mask is not None:
37
  if isinstance(depth_map, torch.Tensor):
38
  valid_mask = valid_mask.detach().numpy()
39
- valid_mask = np.squeeze(valid_mask) # [H, W] or [B, H, W]
40
  if valid_mask.ndim < 3:
41
  valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
42
  else:
@@ -61,18 +76,28 @@ def chw2hwc(chw):
61
  return hwc
62
 
63
 
64
- def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
 
 
 
 
65
  """
66
- Resize image to limit maximum edge length while keeping aspect ratio
67
 
68
  Args:
69
- img (Image.Image): Image to be resized
70
- max_edge_resolution (int): Maximum edge length (px).
 
 
 
 
71
 
72
  Returns:
73
- Image.Image: Resized image.
74
  """
75
- original_width, original_height = img.size
 
 
76
  downscale_factor = min(
77
  max_edge_resolution / original_width, max_edge_resolution / original_height
78
  )
@@ -80,93 +105,26 @@ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
80
  new_width = int(original_width * downscale_factor)
81
  new_height = int(original_height * downscale_factor)
82
 
83
- resized_img = img.resize((new_width, new_height))
84
- return resized_img
85
-
86
- def resize_max_res_integer_16(img: Image.Image, max_edge_resolution: int) -> Image.Image:
87
- """
88
- Resize image to limit maximum edge length while keeping aspect ratio
89
-
90
- Args:
91
- img (Image.Image): Image to be resized
92
- max_edge_resolution (int): Maximum edge length (px).
93
-
94
- Returns:
95
- Image.Image: Resized image.
96
- """
97
- original_width, original_height = img.size
98
- downscale_factor = min(
99
- max_edge_resolution / original_width, max_edge_resolution / original_height
100
- )
101
-
102
- new_width = int(original_width * downscale_factor) // 16 * 16 # make sure it is integer multiples of 16, used for pixart
103
- new_height = int(original_height * downscale_factor) // 16 * 16 # make sure it is integer multiples of 16, used for pixart
104
-
105
- resized_img = img.resize((new_width, new_height))
106
- return resized_img
107
-
108
- def resize_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
109
- """
110
- Resize image to limit maximum edge length while keeping aspect ratio
111
-
112
- Args:
113
- img (Image.Image): Image to be resized
114
- max_edge_resolution (int): Maximum edge length (px).
115
-
116
- Returns:
117
- Image.Image: Resized image.
118
- """
119
-
120
- resized_img = img.resize((max_edge_resolution, max_edge_resolution))
121
  return resized_img
122
 
123
- class ResizeLongestEdge:
124
- def __init__(self, max_size, interpolation=transforms.InterpolationMode.BILINEAR):
125
- self.max_size = max_size
126
- self.interpolation = interpolation
127
-
128
- def __call__(self, img):
129
-
130
- scale = self.max_size / max(img.width, img.height)
131
- new_size = (int(img.height * scale), int(img.width * scale))
132
-
133
- return transforms.functional.resize(img, new_size, self.interpolation)
134
-
135
- class ResizeShortestEdge:
136
- def __init__(self, min_size, interpolation=transforms.InterpolationMode.BILINEAR):
137
- self.min_size = min_size
138
- self.interpolation = interpolation
139
-
140
- def __call__(self, img):
141
-
142
- scale = self.min_size / min(img.width, img.height)
143
- new_size = (int(img.height * scale), int(img.width * scale))
144
-
145
- return transforms.functional.resize(img, new_size, self.interpolation)
146
-
147
- class ResizeHard:
148
- def __init__(self, size, interpolation=transforms.InterpolationMode.BILINEAR):
149
- self.size = size
150
- self.interpolation = interpolation
151
-
152
- def __call__(self, img):
153
-
154
- new_size = (int(self.size), int(self.size))
155
-
156
- return transforms.functional.resize(img, new_size, self.interpolation)
157
-
158
-
159
- class ResizeLongestEdgeInteger:
160
- def __init__(self, max_size, interpolation=transforms.InterpolationMode.BILINEAR, integer=16):
161
- self.max_size = max_size
162
- self.interpolation = interpolation
163
- self.integer = integer
164
-
165
- def __call__(self, img):
166
-
167
- scale = self.max_size / max(img.width, img.height)
168
- new_size_h = int(img.height * scale) // self.integer * self.integer
169
- new_size_w = int(img.width * scale) // self.integer * self.integer
170
- new_size = (new_size_h, new_size_w)
171
 
172
- return transforms.functional.resize(img, new_size, self.interpolation)
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ # Last modified: 2024-05-24
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
17
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
18
+ # More information about the method can be found at https://marigoldmonodepth.github.io
19
+ # --------------------------------------------------------------------------
20
+
21
+
22
  import matplotlib
23
  import numpy as np
24
  import torch
25
+ from torchvision.transforms import InterpolationMode
26
+ from torchvision.transforms.functional import resize
27
 
 
 
 
 
 
 
28
 
29
  def colorize_depth_maps(
30
  depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
 
35
  assert len(depth_map.shape) >= 2, "Invalid dimension"
36
 
37
  if isinstance(depth_map, torch.Tensor):
38
+ depth = depth_map.detach().squeeze().numpy()
39
  elif isinstance(depth_map, np.ndarray):
40
+ depth = depth_map.copy().squeeze()
41
  # reshape to [ (B,) H, W ]
42
  if depth.ndim < 3:
43
  depth = depth[np.newaxis, :, :]
 
51
  if valid_mask is not None:
52
  if isinstance(depth_map, torch.Tensor):
53
  valid_mask = valid_mask.detach().numpy()
54
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
55
  if valid_mask.ndim < 3:
56
  valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
57
  else:
 
76
  return hwc
77
 
78
 
79
+ def resize_max_res(
80
+ img: torch.Tensor,
81
+ max_edge_resolution: int,
82
+ resample_method: InterpolationMode = InterpolationMode.BILINEAR,
83
+ ) -> torch.Tensor:
84
  """
85
+ Resize image to limit maximum edge length while keeping aspect ratio.
86
 
87
  Args:
88
+ img (`torch.Tensor`):
89
+ Image tensor to be resized. Expected shape: [B, C, H, W]
90
+ max_edge_resolution (`int`):
91
+ Maximum edge length (pixel).
92
+ resample_method (`PIL.Image.Resampling`):
93
+ Resampling method used to resize images.
94
 
95
  Returns:
96
+ `torch.Tensor`: Resized image.
97
  """
98
+ assert 4 == img.dim(), f"Invalid input shape {img.shape}"
99
+
100
+ original_height, original_width = img.shape[-2:]
101
  downscale_factor = min(
102
  max_edge_resolution / original_width, max_edge_resolution / original_height
103
  )
 
105
  new_width = int(original_width * downscale_factor)
106
  new_height = int(original_height * downscale_factor)
107
 
108
+ resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
 
109
  return resized_img
110
 
111
 
112
+ def get_tv_resample_method(method_str: str) -> InterpolationMode:
113
+ try:
114
+ resample_method_dict = {
115
+ "bilinear": InterpolationMode.BILINEAR,
116
+ "bicubic": InterpolationMode.BICUBIC,
117
+ "nearest": InterpolationMode.NEAREST_EXACT,
118
+ "nearest-exact": InterpolationMode.NEAREST_EXACT,
119
+ }
120
+ except:
121
+ resample_method_dict = {
122
+ "bilinear": InterpolationMode.BILINEAR,
123
+ "bicubic": InterpolationMode.BICUBIC,
124
+ "nearest": InterpolationMode.NEAREST,
125
+ }
126
+ resample_method = resample_method_dict.get(method_str, None)
127
+ if resample_method is None:
128
+ raise ValueError(f"Unknown resampling method: {resample_method}")
129
+ else:
130
+ return resample_method
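For completeness, a short sketch exercising the two helpers above on a dummy tensor.

```python
import torch

from genpercept.util.image_util import get_tv_resample_method, resize_max_res

rgb = torch.rand(1, 3, 1080, 1920)  # [B, C, H, W]
resample = get_tv_resample_method("bilinear")

resized = resize_max_res(rgb, max_edge_resolution=768, resample_method=resample)
print(resized.shape)  # torch.Size([1, 3, 432, 768]); aspect ratio preserved
```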
hf_configs/dpt-sd2.1-unet-after-upsample-general/config.json ADDED
@@ -0,0 +1,48 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "add_projection": true,
4
+ "architectures": [
5
+ "DPTForDepthEstimation"
6
+ ],
7
+ "attention_probs_dropout_prob": null,
8
+ "auxiliary_loss_weight": 0.4,
9
+ "backbone_featmap_shape": null,
10
+ "backbone_out_indices": null,
11
+ "fusion_hidden_size": 256,
12
+ "head_in_index": -1,
13
+ "hidden_act": "gelu",
14
+ "hidden_dropout_prob": null,
15
+ "hidden_size": 768,
16
+ "image_size": null,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": null,
19
+ "is_hybrid": false,
20
+ "layer_norm_eps": null,
21
+ "model_type": "dpt",
22
+ "neck_hidden_sizes": [
23
+ 320,
24
+ 640,
25
+ 1280,
26
+ 1280
27
+ ],
28
+ "neck_ignore_stages": [],
29
+ "num_attention_heads": null,
30
+ "num_channels": null,
31
+ "num_hidden_layers": null,
32
+ "patch_size": null,
33
+ "qkv_bias": null,
34
+ "readout_type": "project",
35
+ "reassemble_factors": [
36
+ 4,
37
+ 2,
38
+ 1,
39
+ 0.5
40
+ ],
41
+ "semantic_classifier_dropout": 0.1,
42
+ "semantic_loss_ignore_index": 255,
43
+ "torch_dtype": "float32",
44
+ "transformers_version": null,
45
+ "use_auxiliary_head": true,
46
+ "use_batch_norm_in_fusion_residual": false,
47
+ "use_bias_in_fusion_residual": false
48
+ }
hf_configs/dpt-sd2.1-unet-after-upsample-general/preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "do_normalize": true,
3
+ "do_pad": true,
4
+ "do_rescale": false,
5
+ "do_resize": true,
6
+ "ensure_multiple_of": 1,
7
+ "image_mean": [
8
+ 123.675,
9
+ 116.28,
10
+ 103.53
11
+ ],
12
+ "image_processor_type": "DPTImageProcessor",
13
+ "image_std": [
14
+ 58.395,
15
+ 57.12,
16
+ 57.375
17
+ ],
18
+ "keep_aspect_ratio": false,
19
+ "resample": 2,
20
+ "rescale_factor": 0.00392156862745098,
21
+ "size": {
22
+ "height": 392,
23
+ "width": 392
24
+ },
25
+ "size_divisor": 14
26
+ }
27
+
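A quick, hypothetical sanity check that this processor config loads with `transformers` (the all-zero image is only illustrative).

```python
import numpy as np
from PIL import Image
from transformers import DPTImageProcessor

processor = DPTImageProcessor.from_pretrained(
    "hf_configs/dpt-sd2.1-unet-after-upsample-general"
)
dummy = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
batch = processor(images=dummy, return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 392, 392])
```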
hf_configs/scheduler_beta_1.0_1.0/scheduler_config.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.29.2",
4
+ "beta_end": 1.0,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 1.0,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "v_prediction",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "skip_prk_steps": true,
16
+ "steps_offset": 1,
17
+ "thresholding": false,
18
+ "timestep_spacing": "leading",
19
+ "trained_betas": null
20
+ }
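Loading this schedule with diffusers is a one-liner; the degenerate `beta_start = beta_end = 1.0` setting presumably collapses the noise schedule so that a single v-prediction step maps directly to the target.

```python
from diffusers import DDIMScheduler

# Path relative to this repository.
scheduler = DDIMScheduler.from_pretrained("hf_configs/scheduler_beta_1.0_1.0")
print(scheduler.config.beta_start, scheduler.config.beta_end)  # 1.0 1.0
print(scheduler.config.prediction_type)                        # v_prediction
```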
pipeline_genpercept.py DELETED
@@ -1,355 +0,0 @@
1
- # --------------------------------------------------------
2
- # Diffusion Models Trained with Large Data Are Transferable Visual Models (https://arxiv.org/abs/2403.06090)
3
- # Github source: https://github.com/aim-uofa/GenPercept
4
- # Copyright (c) 2024 Zhejiang University
5
- # Licensed under The CC0 1.0 License [see LICENSE for details]
6
- # By Guangkai Xu
7
- # Based on Marigold, diffusers codebases
8
- # https://github.com/prs-eth/marigold
9
- # https://github.com/huggingface/diffusers
10
- # --------------------------------------------------------
11
-
12
- import torch
13
- import numpy as np
14
- import torch.nn.functional as F
15
- import matplotlib.pyplot as plt
16
-
17
- from tqdm.auto import tqdm
18
- from PIL import Image
19
- from typing import List, Dict, Union
20
- from torch.utils.data import DataLoader, TensorDataset
21
-
22
- from diffusers import (
23
- DiffusionPipeline,
24
- UNet2DConditionModel,
25
- AutoencoderKL,
26
- )
27
- from diffusers.utils import BaseOutput
28
-
29
- from util.image_util import chw2hwc, colorize_depth_maps, resize_max_res, norm_to_rgb, resize_res
30
- from util.batchsize import find_batch_size
31
-
32
- class GenPerceptOutput(BaseOutput):
33
-
34
- pred_np: np.ndarray
35
- pred_colored: Image.Image
36
-
37
- class GenPerceptPipeline(DiffusionPipeline):
38
-
39
- vae_scale_factor = 0.18215
40
- task_infos = {
41
- 'depth': dict(task_channel_num=1, interpolate='bilinear', ),
42
- 'seg': dict(task_channel_num=3, interpolate='nearest', ),
43
- 'sr': dict(task_channel_num=3, interpolate='nearest', ),
44
- 'normal': dict(task_channel_num=3, interpolate='bilinear', ),
45
- }
46
-
47
- def __init__(
48
- self,
49
- unet: UNet2DConditionModel,
50
- vae: AutoencoderKL,
51
- customized_head=None,
52
- empty_text_embed=None,
53
- ):
54
- super().__init__()
55
-
56
- self.empty_text_embed = empty_text_embed
57
-
58
- # register
59
- register_dict = dict(
60
- unet=unet,
61
- vae=vae,
62
- customized_head=customized_head,
63
- )
64
- self.register_modules(**register_dict)
65
-
66
- @torch.no_grad()
67
- def __call__(
68
- self,
69
- input_image: Union[Image.Image, torch.Tensor],
70
- mode: str = 'depth',
71
- resize_hard = False,
72
- processing_res: int = 768,
73
- match_input_res: bool = False,
74
- batch_size: int = 0,
75
- color_map: str = "Spectral",
76
- show_progress_bar: bool = True,
77
- ) -> GenPerceptOutput:
78
- """
79
- Function invoked when calling the pipeline.
80
-
81
- Args:
82
- input_image (Image):
83
- Input RGB (or gray-scale) image.
84
- processing_res (int, optional):
85
- Maximum resolution of processing.
86
- If set to 0: will not resize at all.
87
- Defaults to 768.
88
- match_input_res (bool, optional):
89
- Resize depth prediction to match input resolution.
90
- Only valid if `limit_input_res` is not None.
91
- Defaults to True.
92
- batch_size (int, optional):
93
- Inference batch size.
94
- If set to 0, the script will automatically decide the proper batch size.
95
- Defaults to 0.
96
- show_progress_bar (bool, optional):
97
- Display a progress bar of diffusion denoising.
98
- Defaults to True.
99
- color_map (str, optional):
100
- Colormap used to colorize the depth map.
101
- Defaults to "Spectral".
102
- Returns:
103
- `GenPerceptOutput`
104
- """
105
-
106
- device = self.device
107
-
108
- task_channel_num = self.task_infos[mode]['task_channel_num']
109
-
110
- if not match_input_res:
111
- assert (
112
- processing_res is not None
113
- ), "Value error: `resize_output_back` is only valid with "
114
- assert processing_res >= 0
115
-
116
- # ----------------- Image Preprocess -----------------
117
-
118
- if type(input_image) == torch.Tensor: # [B, 3, H, W]
119
- rgb_norm = input_image.to(device)
120
- input_size = input_image.shape[2:]
121
- bs_imgs = rgb_norm.shape[0]
122
- assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
123
- rgb_norm = rgb_norm.to(self.dtype)
124
- else:
125
- # if len(rgb_paths) > 0 and 'kitti' in rgb_paths[0]:
126
- # # kb crop
127
- # height = input_image.size[1]
128
- # width = input_image.size[0]
129
- # top_margin = int(height - 352)
130
- # left_margin = int((width - 1216) / 2)
131
- # input_image = input_image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
132
-
133
- # TODO: check the kitti evaluation resolution here.
134
- input_size = (input_image.size[1], input_image.size[0])
135
- # Resize image
136
- if processing_res > 0:
137
- if resize_hard:
138
- input_image = resize_res(
139
- input_image, max_edge_resolution=processing_res
140
- )
141
- else:
142
- input_image = resize_max_res(
143
- input_image, max_edge_resolution=processing_res
144
- )
145
- input_image = input_image.convert("RGB")
146
- image = np.asarray(input_image)
147
-
148
- # Normalize rgb values
149
- rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
150
- rgb_norm = rgb / 255.0 * 2.0 - 1.0
151
- rgb_norm = torch.from_numpy(rgb_norm).to(self.unet.dtype)
152
- rgb_norm = rgb_norm[None].to(device)
153
- assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
154
- bs_imgs = 1
155
-
156
- # ----------------- Predicting depth -----------------
157
-
158
- single_rgb_dataset = TensorDataset(rgb_norm)
159
- if batch_size > 0:
160
- _bs = batch_size
161
- else:
162
- _bs = find_batch_size(
163
- ensemble_size=1,
164
- input_res=max(rgb_norm.shape[1:]),
165
- dtype=self.dtype,
166
- )
167
-
168
- single_rgb_loader = DataLoader(
169
- single_rgb_dataset, batch_size=_bs, shuffle=False
170
- )
171
-
172
- # Predict depth maps (batched)
173
- pred_list = []
174
- if show_progress_bar:
175
- iterable = tqdm(
176
- single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
177
- )
178
- else:
179
- iterable = single_rgb_loader
180
-
181
- for batch in iterable:
182
- (batched_img, ) = batch
183
- pred = self.single_infer(
184
- rgb_in=batched_img,
185
- mode=mode,
186
- )
187
- pred_list.append(pred.detach().clone())
188
- preds = torch.concat(pred_list, axis=0).squeeze() # [bs_imgs, task_channel_num, H, W]
189
- preds = preds.view(bs_imgs, task_channel_num, preds.shape[-2], preds.shape[-1])
190
-
191
- if match_input_res:
192
- preds = F.interpolate(preds, input_size, mode=self.task_infos[mode]['interpolate'])
193
-
194
- # ----------------- Post processing -----------------
195
- if mode == 'depth':
196
- if len(preds.shape) == 4:
197
- preds = preds[:, 0] # [bs_imgs, H, W]
198
- # Scale prediction to [0, 1]
199
- min_d = preds.view(bs_imgs, -1).min(dim=1)[0]
200
- max_d = preds.view(bs_imgs, -1).max(dim=1)[0]
201
- preds = (preds - min_d[:, None, None]) / (max_d[:, None, None] - min_d[:, None, None])
202
- preds = preds.cpu().numpy().astype(np.float32)
203
- # Colorize
204
- pred_colored_img_list = []
205
- for i in range(bs_imgs):
206
- pred_colored_chw = colorize_depth_maps(
207
- preds[i], 0, 1, cmap=color_map
208
- ).squeeze() # [3, H, W], value in (0, 1)
209
- pred_colored_chw = (pred_colored_chw * 255).astype(np.uint8)
210
- pred_colored_hwc = chw2hwc(pred_colored_chw)
211
- pred_colored_img = Image.fromarray(pred_colored_hwc)
212
- pred_colored_img_list.append(pred_colored_img)
213
-
214
- return GenPerceptOutput(
215
- pred_np=np.squeeze(preds),
216
- pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
217
- )
218
-
219
- elif mode == 'seg' or mode == 'sr':
220
- if not self.customized_head:
221
- # shift to [0, 1]
222
- preds = (preds + 1.0) / 2.0
223
- # shift to [0, 255]
224
- preds = preds * 255
225
- # Clip output range
226
- preds = preds.clip(0, 255).cpu().numpy().astype(np.uint8)
227
- else:
228
- raise NotImplementedError
229
-
230
- pred_colored_img_list = []
231
- for i in range(preds.shape[0]):
232
- pred_colored_hwc = chw2hwc(preds[i])
233
- pred_colored_img = Image.fromarray(pred_colored_hwc)
234
- pred_colored_img_list.append(pred_colored_img)
235
-
236
- return GenPerceptOutput(
237
- pred_np=np.squeeze(preds),
238
- pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
239
- )
240
-
241
- elif mode == 'normal':
242
- if not self.customized_head:
243
- preds = preds.clip(-1, 1).cpu().numpy() # [-1, 1]
244
- else:
245
- raise NotImplementedError
246
-
247
- pred_colored_img_list = []
248
- for i in range(preds.shape[0]):
249
- pred_colored_chw = norm_to_rgb(preds[i])
250
- pred_colored_hwc = chw2hwc(pred_colored_chw)
251
- normal_colored_img_i = Image.fromarray(pred_colored_hwc)
252
- pred_colored_img_list.append(normal_colored_img_i)
253
-
254
- return GenPerceptOutput(
255
- pred_np=np.squeeze(preds),
256
- pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
257
- )
258
-
259
- else:
260
- raise NotImplementedError
261
-
262
- @torch.no_grad()
263
- def single_infer(
264
- self,
265
- rgb_in: torch.Tensor,
266
- mode: str = 'depth',
267
- ) -> torch.Tensor:
268
- """
269
- Perform an individual depth prediction without ensembling.
270
-
271
- Args:
272
- rgb_in (torch.Tensor):
273
- Input RGB image.
274
- num_inference_steps (int):
275
- Number of diffusion denoising steps (DDIM) during inference.
276
- show_pbar (bool):
277
- Display a progress bar of diffusion denoising.
278
-
279
- Returns:
280
- torch.Tensor: Predicted depth map.
281
- """
282
- device = rgb_in.device
283
- bs_imgs = rgb_in.shape[0]
284
- timesteps = torch.tensor([1]).long().repeat(bs_imgs).to(device)
285
-
286
- # Encode image
287
- rgb_latent = self.encode_rgb(rgb_in)
288
-
289
- batch_embed = self.empty_text_embed
290
- batch_embed = batch_embed.repeat((rgb_latent.shape[0], 1, 1)).to(device) # [bs_imgs, 77, 1024]
291
-
292
- # Forward!
293
- if self.customized_head:
294
- unet_features = self.unet(rgb_latent, timesteps, encoder_hidden_states=batch_embed, return_feature_only=True)[0][::-1]
295
- pred = self.customized_head(unet_features)
296
- else:
297
- unet_output = self.unet(
298
- rgb_latent, timesteps, encoder_hidden_states=batch_embed
299
- ) # [bs_imgs, 4, h, w]
300
- unet_pred = unet_output.sample
301
- pred_latent = - unet_pred
302
- pred_latent.to(device)
303
- pred = self.decode_pred(pred_latent)
304
- if mode == 'depth':
305
- # mean of output channels
306
- pred = pred.mean(dim=1, keepdim=True)
307
- # clip prediction
308
- pred = torch.clip(pred, -1.0, 1.0)
309
- return pred
310
-
311
-
312
- def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
313
- """
314
- Encode RGB image into latent.
315
-
316
- Args:
317
- rgb_in (torch.Tensor):
318
- Input RGB image to be encoded.
319
-
320
- Returns:
321
- torch.Tensor: Image latent
322
- """
323
- try:
324
- # encode
325
- h_temp = self.vae.encoder(rgb_in)
326
- moments = self.vae.quant_conv(h_temp)
327
- except:
328
- # encode
329
- h_temp = self.vae.encoder(rgb_in.float())
330
- moments = self.vae.quant_conv(h_temp.float())
331
-
332
- mean, logvar = torch.chunk(moments, 2, dim=1)
333
- # scale latent
334
- rgb_latent = mean * self.vae_scale_factor
335
- return rgb_latent
336
-
337
- def decode_pred(self, pred_latent: torch.Tensor) -> torch.Tensor:
338
- """
339
- Decode pred latent into pred label.
340
-
341
- Args:
342
- pred_latent (torch.Tensor):
343
- prediction latent to be decoded.
344
-
345
- Returns:
346
- torch.Tensor: Decoded prediction label.
347
- """
348
- # scale latent
349
- pred_latent = pred_latent / self.vae_scale_factor
350
- # decode
351
- z = self.vae.post_quant_conv(pred_latent)
352
- pred = self.vae.decoder(z)
353
-
354
- return pred
355
-
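The deleted pipeline_genpercept.py above implemented the standalone GenPerceptPipeline: a single UNet forward pass at a fixed timestep, with the negated output decoded by the VAE into a depth, normal, segmentation, or matting map. For reference, a hedged usage sketch based only on the deleted __call__ signature; `pipe` stands for an already-assembled pipeline instance and the image path is illustrative.

# Usage sketch of the removed GenPerceptPipeline, inferred from its __call__
# signature above; `pipe` is an assumed, already-constructed pipeline instance.
from PIL import Image

image = Image.open("path/to/image.jpg")  # illustrative path
out = pipe(
    image,
    mode="depth",          # one of 'depth', 'normal', 'seg', 'sr'
    processing_res=768,    # longer edge used during inference
    match_input_res=True,  # resize the prediction back to the input size
    color_map="Spectral",  # colormap applied to depth predictions
)
out.pred_colored.save("pred_colored.png")  # colorized PIL image
print(out.pred_np.shape)                   # raw prediction as a numpy array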
requirements.txt CHANGED
@@ -21,3 +21,7 @@ spaces
21
  gradio>=4.32.2
22
  gradio_client>=0.17.0
23
  gradio_imageslider>=0.0.20
24
+ omegaconf
25
+ tabulate
26
+ wandb
27
+ pandas
seg_images/seg_1.jpg ADDED
seg_images/seg_2.jpg ADDED
seg_images/seg_3.jpg ADDED
seg_images/seg_4.jpg ADDED
seg_images/seg_5.jpg ADDED
util/__init__.py DELETED
File without changes
util/seed_all.py DELETED
@@ -1,13 +0,0 @@
1
- import numpy as np
2
- import random
3
- import torch
4
-
5
-
6
- def seed_all(seed: int = 0):
7
- """
8
- Set random seeds of all components.
9
- """
10
- random.seed(seed)
11
- np.random.seed(seed)
12
- torch.manual_seed(seed)
13
- torch.cuda.manual_seed_all(seed)