thinh-researcher committed on
Commit 6e9c433 · 1 Parent(s): 2cadd70
Files changed (49)
  1. .gitignore +28 -0
  2. README.md +72 -1
  3. definition.py +21 -0
  4. docs/figures/proposed_method_v5.drawio.png +0 -0
  5. docs/references/Dataset.bib +124 -0
  6. docs/references/References.bib +190 -0
  7. docs/references/SOTAs.bib +355 -0
  8. requirements-lock.txt +103 -0
  9. requirements.txt +18 -0
  10. s_multimae/__init__.py +0 -0
  11. s_multimae/configs/__init__.py +0 -0
  12. s_multimae/configs/base_config.py +164 -0
  13. s_multimae/configs/data_augmentation_config.py +19 -0
  14. s_multimae/configs/experiment_config.py +31 -0
  15. s_multimae/configs/experiment_configs/__init__.py +0 -0
  16. s_multimae/configs/experiment_configs/expv1_dynamic.py +277 -0
  17. s_multimae/da/__init__.py +0 -0
  18. s_multimae/da/base_da.py +33 -0
  19. s_multimae/da/dav6.py +147 -0
  20. s_multimae/data_augmentation.py +19 -0
  21. s_multimae/model/__init__.py +0 -0
  22. s_multimae/model/components.py +117 -0
  23. s_multimae/model/multimae.py +938 -0
  24. s_multimae/model_pl.py +105 -0
  25. s_multimae/rgbd_model.py +60 -0
  26. s_multimae/utils.py +236 -0
  27. s_multimae/visualize_2d_posemb.py +58 -0
  28. s_multimae/visualizer.py +711 -0
  29. streamlit_apps/__init__.py +0 -0
  30. streamlit_apps/app.py +91 -0
  31. streamlit_apps/app_utils/__init__.py +0 -0
  32. streamlit_apps/app_utils/app_env.py +16 -0
  33. streamlit_apps/app_utils/app_utils.py +83 -0
  34. streamlit_apps/app_utils/base_model.py +54 -0
  35. streamlit_apps/app_utils/color_selection_ui.py +10 -0
  36. streamlit_apps/app_utils/depth_model.py +77 -0
  37. streamlit_apps/app_utils/depth_selection_ui.py +27 -0
  38. streamlit_apps/app_utils/device.py +5 -0
  39. streamlit_apps/app_utils/dpt/__init__.py +0 -0
  40. streamlit_apps/app_utils/dpt/base_model.py +16 -0
  41. streamlit_apps/app_utils/dpt/blocks.py +383 -0
  42. streamlit_apps/app_utils/dpt/midas_net.py +78 -0
  43. streamlit_apps/app_utils/dpt/models.py +124 -0
  44. streamlit_apps/app_utils/dpt/transforms.py +231 -0
  45. streamlit_apps/app_utils/dpt/vit.py +576 -0
  46. streamlit_apps/app_utils/image_inference.py +88 -0
  47. streamlit_apps/app_utils/model.py +84 -0
  48. streamlit_apps/app_utils/smultimae_model.py +43 -0
  49. streamlit_apps/app_utils/sod_selection_ui.py +111 -0
.gitignore ADDED
@@ -0,0 +1,28 @@
1
+ env
2
+ __pycache__/
3
+ *.pth
4
+ *.pt
5
+
6
+ datasets/**/benchmark/*
7
+ datasets/**/test/*
8
+ datasets/**/train/*
9
+ datasets/**/dev/*
10
+ datasets/**/*.zip
11
+
12
+ sources/deployment/*
13
+ sources/experiment/*
14
+ sources/pickle/*
15
+ sources/csv/*/*
16
+ sources/json/*
17
+ sotas/*
18
+ continue_training/*
19
+
20
+ weights/*
21
+
22
+ !*.gitkeep
23
+ logs
24
+ wandb
25
+ tmp/
26
+
27
+ wandb_cache
28
+ script.md
README.md CHANGED
@@ -9,4 +9,75 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
9
  pinned: false
10
  ---
11
 
12
+ # S-MultiMAE
13
+
14
+ This repository provides the official implementation of `S-MultiMAE: A Multi-Ground Truth approach for RGB-D Saliency Detection`.
15
+
16
+ _Nguyen Truong Thinh Huynh, Van Linh Pham, Xuan Toan Mai and Tuan Anh Tran_
17
+
18
+ ![Overview of the proposed method](docs/figures/proposed_method_v5.drawio.png)
19
+
20
+ ## Model weights
21
+
22
+ | Backbone | #params | Training paradigm | Weights | Input size |
23
+ | -------- | ----------- | ----------------- | ---------------------------------------------------------------------------------------------- | ---------- |
24
+ | ViT-L | 328,318,529 | Multi-GT | [Download](https://drive.google.com/file/d/1YhAuu3DI2adPLQgbgoSt74ilZbpuKihh/view?usp=sharing) | 224x224 |
25
+ | ViT-B | 107,654,977 | Multi-GT | [Download](https://drive.google.com/file/d/13Omafif3pvPKgg3Isp_srkHf8CSPx33d/view?usp=sharing) | 224x224 |
26
+
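+ As a quick sanity check after downloading, the sketch below counts the parameters in a checkpoint (it assumes a standard PyTorch `.pth` file holding either a raw `state_dict` or a Lightning-style dict with a `state_dict` key; the exact layout may differ):
+
+ ```python
+ import torch
+
+ ckpt = torch.load("weights/s-multimae-cfgv4_0_2007-top1.pth", map_location="cpu")
+ state_dict = ckpt.get("state_dict", ckpt)
+ print(sum(p.numel() for p in state_dict.values()))  # should roughly match the #params column above
+ ```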
27
+ ## How to run
28
+
29
+ ### Create a virtual environment
30
+
31
+ We recommend using Python 3.10 or higher.
32
+
33
+ ```bash
34
+ python3.10 -m venv env
35
+ source env/bin/activate
36
+ pip install -r requirements.txt
37
+ ```
38
+
39
+ ### Download trained weights
40
+
41
+ - Download the model weights and put them in the `weights` folder. You may also need to download the weights of the [DPT model]() (an RGB-to-depth model). The `weights` folder should look like this:
42
+
43
+ ```bash
44
+ ├── weights
45
+ │ ├── omnidata_rgb2depth_dpt_hybrid.pth
46
+ │ ├── s-multimae-cfgv4_0_2006-top1.pth
47
+ │ ├── s-multimae-cfgv4_0_2007-top1.pth
48
+ ```
49
+
50
+ ### Run
51
+
52
+ - Run the Streamlit app (it will be served at http://localhost:9113):
53
+
54
+ ```bash
55
+ streamlit run streamlit_apps/app.py --server.port 9113 --browser.gatherUsageStats False --server.fileWatcherType none
56
+ ```
57
+
58
+ ## Datasets
59
+
60
+ ### COME15K dataset
61
+
+ The table below shows, for each split, the percentage of samples annotated with a given number of ground-truth saliency maps:
+
62
+ | | 1 GT | 2 GTs | 3 GTs | 4 GTs | 5 GTs |
63
+ | --------------------- | ------ | ----- | ------ | ----- | ----- |
64
+ | COME8K (8025 samples) | 77.61% | 1.71% | 18.28% | 2.24% | 0.16% |
65
+ | COME-E (4600 samples) | 70.5% | 1.87% | 21.15% | 5.70% | 0.78% |
66
+ | COME-H (3000 samples) | 62.30% | 2.00% | 25.63% | 8.37% | 1.70% |
67
+
+ If you use the COME15K dataset, please cite:
68
+ ```bibtex
69
+ @inproceedings{cascaded_rgbd_sod,
70
+ title={RGB-D Saliency Detection via Cascaded Mutual Information Minimization},
71
+ author={Zhang, Jing and Fan, Deng-Ping and Dai, Yuchao and Yu, Xin and Zhong, Yiran and Barnes, Nick and Shao, Ling},
72
+ booktitle={International Conference on Computer Vision (ICCV)},
73
+ year={2021}
74
+ }
75
+ ```
76
+
77
+ ## References
78
+
79
+ All references are cited in these files:
80
+
81
+ - [Datasets](./docs/references/Dataset.bib)
82
+ - [SOTAs](./docs/references/SOTAs.bib)
83
+ - [Others](./docs/references/References.bib)
definition.py ADDED
@@ -0,0 +1,21 @@
1
+ """
2
+ Do not import other modules!
3
+ """
4
+
5
+
6
+ class PRETRAINED_BACKBONE:
7
+ MULTIMAE = "multimae"
8
+
9
+ S_MULTIMAE = "s-multimae"
10
+ LARGE_S_MULTIMAE = "large-s-multimae"
11
+
12
+ MAE = "mae"
13
+ LARGE_MAE = "large-mae"
14
+ HUGE_MAE = "huge-mae"
15
+
16
+ FINETUNE_LARGE_S_MULTIMAE = "finetune-large-s-multimae"
17
+ FINETUNE_S_MULTIMAE = "finetune-s-multimae"
18
+
19
+ VIT = "vit" # train from supervised model
20
+
21
+ NONE = None # train from scratch
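The experiment configs select a backbone through these constants; a minimal usage sketch (mirroring how `expv1_dynamic.py` uses them):

```python
from definition import PRETRAINED_BACKBONE

# As in the provided configs: cfgv4_0_2007 uses the ViT-B backbone,
# cfgv4_0_2006 the ViT-L backbone.
base_backbone = PRETRAINED_BACKBONE.S_MULTIMAE         # "s-multimae"
large_backbone = PRETRAINED_BACKBONE.LARGE_S_MULTIMAE  # "large-s-multimae"
print(base_backbone, large_backbone)
```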
docs/figures/proposed_method_v5.drawio.png ADDED
docs/references/Dataset.bib ADDED
@@ -0,0 +1,124 @@
1
+ % Encoding: UTF-8
2
+
3
+ % DES
4
+ @inproceedings{cheng2014depth,
5
+ title={Depth enhanced saliency detection method},
6
+ author={Cheng, Yupeng and Fu, Huazhu and Wei, Xingxing and Xiao, Jiangjian and Cao, Xiaochun},
7
+ booktitle={Proceedings of international conference on internet multimedia computing and service},
8
+ pages={23--27},
9
+ year={2014}
10
+ }
11
+
12
+ % DUT-RGBD
13
+ @inproceedings{piao2019depth,
14
+ title={Depth-induced multi-scale recurrent attention network for saliency detection},
15
+ author={Piao, Yongri and Ji, Wei and Li, Jingjing and Zhang, Miao and Lu, Huchuan},
16
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
17
+ pages={7254--7263},
18
+ year={2019}
19
+ }
20
+
21
+ % LFSD
22
+ @inproceedings{li2014saliency,
23
+ title={Saliency detection on light field},
24
+ author={Li, Nianyi and Ye, Jinwei and Ji, Yu and Ling, Haibin and Yu, Jingyi},
25
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
26
+ pages={2806--2813},
27
+ year={2014}
28
+ }
29
+
30
+ % NJU2K
31
+ @inproceedings{ju2014depth,
32
+ title={Depth saliency based on anisotropic center-surround difference},
33
+ author={Ju, Ran and Ge, Ling and Geng, Wenjing and Ren, Tongwei and Wu, Gangshan},
34
+ booktitle={2014 IEEE international conference on image processing (ICIP)},
35
+ pages={1115--1119},
36
+ year={2014},
37
+ organization={IEEE}
38
+ }
39
+
40
+ % SSD
41
+ @inproceedings{zhu2017three,
42
+ title={A three-pathway psychobiological framework of salient object detection using stereoscopic technology},
43
+ author={Zhu, Chunbiao and Li, Ge},
44
+ booktitle={Proceedings of the IEEE international conference on computer vision workshops},
45
+ pages={3008--3014},
46
+ year={2017}
47
+ }
48
+
49
+ % Holo50K
50
+ @article{hua2020holopix50k,
51
+ title={Holopix50k: A large-scale in-the-wild stereo image dataset},
52
+ author={Hua, Yiwen and Kohli, Puneet and Uplavikar, Pritish and Ravi, Anand and Gunaseelan, Saravana and Orozco, Jason and Li, Edward},
53
+ journal={arXiv preprint arXiv:2003.11172},
54
+ year={2020}
55
+ }
56
+
57
+ % NLPR
58
+ @inproceedings{peng2014rgbd,
59
+ title={RGBD salient object detection: A benchmark and algorithms},
60
+ author={Peng, Houwen and Li, Bing and Xiong, Weihua and Hu, Weiming and Ji, Rongrong},
61
+ booktitle={European conference on computer vision},
62
+ pages={92--109},
63
+ year={2014},
64
+ organization={Springer}
65
+ }
66
+
67
+
68
+ % SIP
69
+ @article{fan2020rethinking,
70
+ title={Rethinking RGB-D salient object detection: Models, data sets, and large-scale benchmarks},
71
+ author={Fan, Deng-Ping and Lin, Zheng and Zhang, Zhao and Zhu, Menglong and Cheng, Ming-Ming},
72
+ journal={IEEE Transactions on neural networks and learning systems},
73
+ volume={32},
74
+ number={5},
75
+ pages={2075--2089},
76
+ year={2020},
77
+ publisher={IEEE}
78
+ }
79
+
80
+ % STERE
81
+ @inproceedings{niu2012leveraging,
82
+ title={Leveraging stereopsis for saliency analysis},
83
+ author={Niu, Yuzhen and Geng, Yujie and Li, Xueqing and Liu, Feng},
84
+ booktitle={2012 IEEE Conference on Computer Vision and Pattern Recognition},
85
+ pages={454--461},
86
+ year={2012},
87
+ organization={IEEE}
88
+ }
89
+
90
+ % RGB-Thermal
91
+ @inproceedings{ha2017mfnet,
92
+ title={MFNet: Towards real-time semantic segmentation for autonomous vehicles with multi-spectral scenes},
93
+ author={Ha, Qishen and Watanabe, Kohei and Karasawa, Takumi and Ushiku, Yoshitaka and Harada, Tatsuya},
94
+ booktitle={2017 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
95
+ pages={5108--5115},
96
+ year={2017},
97
+ organization={IEEE}
98
+ }
99
+
100
+ %RGB-Polarization
101
+ @article{xiang2021polarization,
102
+ title={Polarization-driven semantic segmentation via efficient attention-bridged fusion},
103
+ author={Xiang, Kaite and Yang, Kailun and Wang, Kaiwei},
104
+ journal={Optics Express},
105
+ volume={29},
106
+ number={4},
107
+ pages={4802--4820},
108
+ year={2021},
109
+ publisher={Optica Publishing Group}
110
+ }
111
+
112
+ % ImageNet
113
+ @article{russakovsky2015imagenet,
114
+ title={Imagenet large scale visual recognition challenge},
115
+ author={Russakovsky, Olga and Deng, Jia and Su, Hao and Krause, Jonathan and Satheesh, Sanjeev and Ma, Sean and Huang, Zhiheng and Karpathy, Andrej and Khosla, Aditya and Bernstein, Michael and others},
116
+ journal={International journal of computer vision},
117
+ volume={115},
118
+ number={3},
119
+ pages={211--252},
120
+ year={2015},
121
+ publisher={Springer}
122
+ }
123
+
124
+ @Comment{jabref-meta: databaseType:bibtex;}
docs/references/References.bib ADDED
@@ -0,0 +1,190 @@
1
+ % Encoding: UTF-8
2
+
3
+ % An Empirical Study of Training Self-Supervised Vision Transformers
4
+ @inproceedings{chen2021empirical,
5
+ title={An empirical study of training self-supervised vision transformers},
6
+ author={Chen, Xinlei and Xie, Saining and He, Kaiming},
7
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
8
+ pages={9640--9649},
9
+ year={2021}
10
+ }
11
+
12
+ % 2D positional embedding
13
+ @article{raisi20202d,
14
+ title={2D positional embedding-based transformer for scene text recognition},
15
+ author={Raisi, Zobeir and Naiel, Mohamed A and Fieguth, Paul and Wardell, Steven and Zelek, John},
16
+ journal={Journal of Computational Vision and Imaging Systems},
17
+ volume={6},
18
+ number={1},
19
+ pages={1--4},
20
+ year={2020}
21
+ }
22
+
23
+ % Layer Normalization
24
+ @article{ba2016layer,
25
+ title={Layer normalization},
26
+ author={Ba, Jimmy Lei and Kiros, Jamie Ryan and Hinton, Geoffrey E},
27
+ journal={arXiv preprint arXiv:1607.06450},
28
+ year={2016}
29
+ }
30
+
31
+ % Batch Normalization
32
+ @inproceedings{ioffe2015batch,
33
+ title={Batch normalization: Accelerating deep network training by reducing internal covariate shift},
34
+ author={Ioffe, Sergey and Szegedy, Christian},
35
+ booktitle={International conference on machine learning},
36
+ pages={448--456},
37
+ year={2015},
38
+ organization={PMLR}
39
+ }
40
+
41
+ % ReLU
42
+ @article{fukushima1975cognitron,
43
+ title={Cognitron: A self-organizing multilayered neural network},
44
+ author={Fukushima, Kunihiko},
45
+ journal={Biological cybernetics},
46
+ volume={20},
47
+ number={3},
48
+ pages={121--136},
49
+ year={1975},
50
+ publisher={Springer}
51
+ }
52
+
53
+ % Weight Normalization
54
+ @article{salimans2016weight,
55
+ title={Weight normalization: A simple reparameterization to accelerate training of deep neural networks},
56
+ author={Salimans, Tim and Kingma, Durk P},
57
+ journal={Advances in neural information processing systems},
58
+ volume={29},
59
+ year={2016}
60
+ }
61
+
62
+ % Stochastic depth
63
+ @inproceedings{huang2016deep,
64
+ title={Deep networks with stochastic depth},
65
+ author={Huang, Gao and Sun, Yu and Liu, Zhuang and Sedra, Daniel and Weinberger, Kilian Q},
66
+ booktitle={European conference on computer vision},
67
+ pages={646--661},
68
+ year={2016},
69
+ organization={Springer}
70
+ }
71
+
72
+ % Stereo Matching Algorithm
73
+ @article{zhong2020displacement,
74
+ title={Displacement-invariant cost computation for efficient stereo matching},
75
+ author={Zhong, Yiran and Loop, Charles and Byeon, Wonmin and Birchfield, Stan and Dai, Yuchao and Zhang, Kaihao and Kamenev, Alexey and Breuel, Thomas and Li, Hongdong and Kautz, Jan},
76
+ journal={arXiv preprint arXiv:2012.00899},
77
+ year={2020}
78
+ }
79
+
80
+ % wandb
81
+ @misc{wandb,
82
+ title = {Experiment Tracking with Weights and Biases},
83
+ year = {2020},
84
+ note = {Software available from wandb.com},
85
+ url={https://www.wandb.com/},
86
+ author = {Biewald, Lukas},
87
+ }
88
+
89
+ %
90
+ @article{borji2015salient,
91
+ title={Salient object detection: A benchmark},
92
+ author={Borji, Ali and Cheng, Ming-Ming and Jiang, Huaizu and Li, Jia},
93
+ journal={IEEE transactions on image processing},
94
+ volume={24},
95
+ number={12},
96
+ pages={5706--5722},
97
+ year={2015},
98
+ publisher={IEEE}
99
+ }
100
+
101
+ % SOD metrics
102
+ @misc{sodmetrics,
103
+ title = {PySODMetrics: A simple and efficient implementation of SOD metrics},
104
+ howpublished = {\url{https://github.com/lartpang/PySODMetrics}},
105
+ note = {Accessed: 2022-10-31}
106
+ }
107
+
108
+ % MAE
109
+ @inproceedings{perazzi2012saliency,
110
+ title={Saliency filters: Contrast based filtering for salient region detection},
111
+ author={Perazzi, Federico and Kr{\"a}henb{\"u}hl, Philipp and Pritch, Yael and Hornung, Alexander},
112
+ booktitle={2012 IEEE conference on computer vision and pattern recognition},
113
+ pages={733--740},
114
+ year={2012},
115
+ organization={IEEE}
116
+ }
117
+
118
+ % F-measure
119
+ @inproceedings{achanta2009frequency,
120
+ title={Frequency-tuned salient region detection},
121
+ author={Achanta, Radhakrishna and Hemami, Sheila and Estrada, Francisco and Susstrunk, Sabine},
122
+ booktitle={2009 IEEE conference on computer vision and pattern recognition},
123
+ pages={1597--1604},
124
+ year={2009},
125
+ organization={IEEE}
126
+ }
127
+
128
+ % E-measure
129
+ @article{fan2018enhanced,
130
+ title={Enhanced-alignment measure for binary foreground map evaluation},
131
+ author={Fan, Deng-Ping and Gong, Cheng and Cao, Yang and Ren, Bo and Cheng, Ming-Ming and Borji, Ali},
132
+ journal={arXiv preprint arXiv:1805.10421},
133
+ year={2018}
134
+ }
135
+
136
+ % S-measure
137
+ @inproceedings{fan2017structure,
138
+ title={Structure-measure: A new way to evaluate foreground maps},
139
+ author={Fan, Deng-Ping and Cheng, Ming-Ming and Liu, Yun and Li, Tao and Borji, Ali},
140
+ booktitle={Proceedings of the IEEE international conference on computer vision},
141
+ pages={4548--4557},
142
+ year={2017}
143
+ }
144
+
145
+ % GELU
146
+ @article{hendrycks2016gaussian,
147
+ title={Gaussian error linear units (gelus)},
148
+ author={Hendrycks, Dan and Gimpel, Kevin},
149
+ journal={arXiv preprint arXiv:1606.08415},
150
+ year={2016}
151
+ }
152
+
153
+ % Instance normalization
154
+ @article{ulyanov2016instance,
155
+ title={Instance normalization: The missing ingredient for fast stylization},
156
+ author={Ulyanov, Dmitry and Vedaldi, Andrea and Lempitsky, Victor},
157
+ journal={arXiv preprint arXiv:1607.08022},
158
+ year={2016}
159
+ }
160
+
161
+ % Group normalization
162
+ @inproceedings{wu2018group,
163
+ title={Group normalization},
164
+ author={Wu, Yuxin and He, Kaiming},
165
+ booktitle={Proceedings of the European conference on computer vision (ECCV)},
166
+ pages={3--19},
167
+ year={2018}
168
+ }
169
+
170
+ % timm
171
+ @misc{rw2019timm,
172
+ author = {Ross Wightman},
173
+ title = {PyTorch Image Models},
174
+ year = {2019},
175
+ publisher = {GitHub},
176
+ journal = {GitHub repository},
177
+ doi = {10.5281/zenodo.4414861},
178
+ howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
179
+ }
180
+
181
+ % taskonomy
182
+ @inproceedings{zamir2018taskonomy,
183
+ title={Taskonomy: Disentangling task transfer learning},
184
+ author={Zamir, Amir R and Sax, Alexander and Shen, William and Guibas, Leonidas J and Malik, Jitendra and Savarese, Silvio},
185
+ booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
186
+ pages={3712--3722},
187
+ year={2018}
188
+ }
189
+
190
+ @Comment{jabref-meta: databaseType:bibtex;}
docs/references/SOTAs.bib ADDED
@@ -0,0 +1,355 @@
1
+ % Encoding: UTF-8
2
+
3
+ % COME15K CMINet
4
+ @inproceedings{cascaded_rgbd_sod,
5
+ title={RGB-D Saliency Detection via Cascaded Mutual Information Minimization},
6
+ author={Zhang, Jing and Fan, Deng-Ping and Dai, Yuchao and Yu, Xin and Zhong, Yiran and Barnes, Nick and Shao, Ling},
7
+ booktitle={International Conference on Computer Vision (ICCV)},
8
+ year={2021}
9
+ }
10
+
11
+ % A2dele
12
+ @inproceedings{piao2020a2dele,
13
+ title={A2dele: Adaptive and attentive depth distiller for efficient RGB-D salient object detection},
14
+ author={Piao, Yongri and Rong, Zhengkun and Zhang, Miao and Ren, Weisong and Lu, Huchuan},
15
+ booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
16
+ pages={9060--9069},
17
+ year={2020}
18
+ }
19
+
20
+ % BBS-Net
21
+ @inproceedings{fan2020bbs,
22
+ title={BBS-Net: RGB-D salient object detection with a bifurcated backbone strategy network},
23
+ author={Fan, Deng-Ping and Zhai, Yingjie and Borji, Ali and Yang, Jufeng and Shao, Ling},
24
+ booktitle={European conference on computer vision},
25
+ pages={275--292},
26
+ year={2020},
27
+ organization={Springer}
28
+ }
29
+
30
+ % MobileSal
31
+ @article{wu2021mobilesal,
32
+ title={MobileSal: Extremely efficient RGB-D salient object detection},
33
+ author={Wu, Yu-Huan and Liu, Yun and Xu, Jun and Bian, Jia-Wang and Gu, Yu-Chao and Cheng, Ming-Ming},
34
+ journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
35
+ year={2021},
36
+ publisher={IEEE}
37
+ }
38
+
39
+ % ATSA
40
+ @inproceedings{zhang2020asymmetric,
41
+ title={Asymmetric two-stream architecture for accurate RGB-D saliency detection},
42
+ author={Zhang, Miao and Fei, Sun Xiao and Liu, Jie and Xu, Shuang and Piao, Yongri and Lu, Huchuan},
43
+ booktitle={European Conference on Computer Vision},
44
+ pages={374--390},
45
+ year={2020},
46
+ organization={Springer}
47
+ }
48
+
49
+ % CDNet
50
+ @article{jin2021cdnet,
51
+ title={CDNet: Complementary depth network for RGB-D salient object detection},
52
+ author={Jin, Wen-Da and Xu, Jun and Han, Qi and Zhang, Yi and Cheng, Ming-Ming},
53
+ journal={IEEE Transactions on Image Processing},
54
+ volume={30},
55
+ pages={3376--3390},
56
+ year={2021},
57
+ publisher={IEEE}
58
+ }
59
+
60
+ % CoNet
61
+ @inproceedings{ji2020accurate,
62
+ title={Accurate RGB-D salient object detection via collaborative learning},
63
+ author={Ji, Wei and Li, Jingjing and Zhang, Miao and Piao, Yongri and Lu, Huchuan},
64
+ booktitle={European Conference on Computer Vision},
65
+ pages={52--69},
66
+ year={2020},
67
+ organization={Springer}
68
+ }
69
+
70
+ % SPNet
71
+ @inproceedings{zhou2021specificity,
72
+ title={Specificity-preserving rgb-d saliency detection},
73
+ author={Zhou, Tao and Fu, Huazhu and Chen, Geng and Zhou, Yi and Fan, Deng-Ping and Shao, Ling},
74
+ booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
75
+ pages={4681--4691},
76
+ year={2021}
77
+ }
78
+
79
+ % C2DFNet
80
+ @article{zhang2022c,
81
+ title={C2DFNet: Criss-Cross Dynamic Filter Network for RGB-D Salient Object Detection},
82
+ author={Zhang, Miao and Yao, Shunyu and Hu, Beiqi and Piao, Yongri and Ji, Wei},
83
+ journal={IEEE Transactions on Multimedia},
84
+ year={2022},
85
+ publisher={IEEE}
86
+ }
87
+
88
+ % SPSN
89
+ @inproceedings{lee2022spsn,
90
+ title={SPSN: Superpixel Prototype Sampling Network for RGB-D Salient Object Detection},
91
+ author={Lee, Minhyeok and Park, Chaewon and Cho, Suhwan and Lee, Sangyoun},
92
+ booktitle={European Conference on Computer Vision},
93
+ pages={630--647},
94
+ year={2022},
95
+ organization={Springer}
96
+ }
97
+
98
+ % ConvNeXt
99
+ @inproceedings{liu2022convnet,
100
+ title={A convnet for the 2020s},
101
+ author={Liu, Zhuang and Mao, Hanzi and Wu, Chao-Yuan and Feichtenhofer, Christoph and Darrell, Trevor and Xie, Saining},
102
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
103
+ pages={11976--11986},
104
+ year={2022}
105
+ }
106
+
107
+ % GPT-2
108
+ @article{radford2019language,
109
+ title={Language models are unsupervised multitask learners},
110
+ author={Radford, Alec and Wu, Jeffrey and Child, Rewon and Luan, David and Amodei, Dario and Sutskever, Ilya and others},
111
+ journal={OpenAI blog},
112
+ volume={1},
113
+ number={8},
114
+ pages={9},
115
+ year={2019}
116
+ }
117
+
118
+ % BERT
119
+ @article{devlin2018bert,
120
+ title={Bert: Pre-training of deep bidirectional transformers for language understanding},
121
+ author={Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina},
122
+ journal={arXiv preprint arXiv:1810.04805},
123
+ year={2018}
124
+ }
125
+
126
+ % UNet
127
+ @inproceedings{ronneberger2015u,
128
+ title={U-net: Convolutional networks for biomedical image segmentation},
129
+ author={Ronneberger, Olaf and Fischer, Philipp and Brox, Thomas},
130
+ booktitle={International Conference on Medical image computing and computer-assisted intervention},
131
+ pages={234--241},
132
+ year={2015},
133
+ organization={Springer}
134
+ }
135
+
136
+ % MobileNetV2
137
+ @inproceedings{sandler2018mobilenetv2,
138
+ title={Mobilenetv2: Inverted residuals and linear bottlenecks},
139
+ author={Sandler, Mark and Howard, Andrew and Zhu, Menglong and Zhmoginov, Andrey and Chen, Liang-Chieh},
140
+ booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
141
+ pages={4510--4520},
142
+ year={2018}
143
+ }
144
+
145
+ % Xception
146
+ @inproceedings{chollet2017xception,
147
+ title={Xception: Deep learning with depthwise separable convolutions},
148
+ author={Chollet, Fran{\c{c}}ois},
149
+ booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
150
+ pages={1251--1258},
151
+ year={2017}
152
+ }
153
+
154
+ % MobileNets
155
+ @article{howard2017mobilenets,
156
+ title={Mobilenets: Efficient convolutional neural networks for mobile vision applications},
157
+ author={Howard, Andrew G and Zhu, Menglong and Chen, Bo and Kalenichenko, Dmitry and Wang, Weijun and Weyand, Tobias and Andreetto, Marco and Adam, Hartwig},
158
+ journal={arXiv preprint arXiv:1704.04861},
159
+ year={2017}
160
+ }
161
+
162
+ % ResNet
163
+ @inproceedings{he2016deep,
164
+ title={Deep residual learning for image recognition},
165
+ author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian},
166
+ booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
167
+ pages={770--778},
168
+ year={2016}
169
+ }
170
+
171
+ % ResNeXt
172
+ @inproceedings{xie2017aggregated,
173
+ title={Aggregated residual transformations for deep neural networks},
174
+ author={Xie, Saining and Girshick, Ross and Doll{\'a}r, Piotr and Tu, Zhuowen and He, Kaiming},
175
+ booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
176
+ pages={1492--1500},
177
+ year={2017}
178
+ }
179
+
180
+ % MultiMAE
181
+ @article{bachmann2022multimae,
182
+ title={MultiMAE: Multi-modal Multi-task Masked Autoencoders},
183
+ author={Bachmann, Roman and Mizrahi, David and Atanov, Andrei and Zamir, Amir},
184
+ journal={arXiv preprint arXiv:2204.01678},
185
+ year={2022}
186
+ }
187
+
188
+ % MAE
189
+ @inproceedings{he2022masked,
190
+ title={Masked autoencoders are scalable vision learners},
191
+ author={He, Kaiming and Chen, Xinlei and Xie, Saining and Li, Yanghao and Doll{\'a}r, Piotr and Girshick, Ross},
192
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
193
+ pages={16000--16009},
194
+ year={2022}
195
+ }
196
+
197
+ % VisionTransformer, ViT
198
+ @article{dosovitskiy2020image,
199
+ title={An image is worth 16x16 words: Transformers for image recognition at scale},
200
+ author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},
201
+ journal={arXiv preprint arXiv:2010.11929},
202
+ year={2020}
203
+ }
204
+
205
+ % DANet
206
+ @inproceedings{zhao2020single,
207
+ title={A single stream network for robust and real-time RGB-D salient object detection},
208
+ author={Zhao, Xiaoqi and Zhang, Lihe and Pang, Youwei and Lu, Huchuan and Zhang, Lei},
209
+ booktitle={European Conference on Computer Vision},
210
+ pages={646--662},
211
+ year={2020},
212
+ organization={Springer}
213
+ }
214
+
215
+ % DCF
216
+ @inproceedings{Ji_2021_DCF,
217
+ author = {Ji, Wei and Li, Jingjing and Yu, Shuang and Zhang, Miao and Piao, Yongri and Yao, Shunyu and Bi, Qi and Ma, Kai and Zheng, Yefeng and Lu, Huchuan and Cheng, Li},
218
+ title = {Calibrated RGB-D Salient Object Detection},
219
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
220
+ year = {2021},
221
+ pages = {9471-9481}
222
+ }
223
+
224
+ % MVSalNet
225
+ @inproceedings{zhou2022mvsalnet,
226
+ title={MVSalNet: Multi-view Augmentation for RGB-D Salient Object Detection},
227
+ author={Zhou, Jiayuan and Wang, Lijun and Lu, Huchuan and Huang, Kaining and Shi, Xinchu and Liu, Bocong},
228
+ booktitle={European Conference on Computer Vision},
229
+ pages={270--287},
230
+ year={2022},
231
+ organization={Springer}
232
+ }
233
+
234
+ % DSA2F
235
+ @Article{Sun2021DeepRS,
236
+ title={Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion},
237
+ author={P. Sun and Wenhu Zhang and Huanyu Wang and Songyuan Li and Xi Li},
238
+ journal={IEEE Conf. Comput. Vis. Pattern Recog.},
239
+ year={2021}
240
+ }
241
+
242
+ % FRDT
243
+ @inproceedings{zhang2020feature,
244
+ title={Feature reintegration over differential treatment: A top-down and adaptive fusion network for RGB-D salient object detection},
245
+ author={Zhang, Miao and Zhang, Yu and Piao, Yongri and Hu, Beiqi and Lu, Huchuan},
246
+ booktitle={Proceedings of the 28th ACM international conference on multimedia},
247
+ pages={4107--4115},
248
+ year={2020}
249
+ }
250
+
251
+ % HAINet
252
+ @article{li2021hierarchical,
253
+ title={Hierarchical alternate interaction network for RGB-D salient object detection},
254
+ author={Li, Gongyang and Liu, Zhi and Chen, Minyu and Bai, Zhen and Lin, Weisi and Ling, Haibin},
255
+ journal={IEEE Transactions on Image Processing},
256
+ volume={30},
257
+ pages={3528--3542},
258
+ year={2021},
259
+ publisher={IEEE}
260
+ }
261
+
262
+ % JLDCF
263
+ @inproceedings{fu2020jl,
264
+ title={JL-DCF: Joint learning and densely-cooperative fusion framework for RGB-D salient object detection},
265
+ author={Fu, Keren and Fan, Deng-Ping and Ji, Ge-Peng and Zhao, Qijun},
266
+ booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
267
+ pages={3052--3062},
268
+ year={2020}
269
+ }
270
+
271
+ % SSLSOD
272
+ @inproceedings{zhao2022self,
273
+ title={Self-supervised pretraining for rgb-d salient object detection},
274
+ author={Zhao, Xiaoqi and Pang, Youwei and Zhang, Lihe and Lu, Huchuan and Ruan, Xiang},
275
+ booktitle={AAAI Conference on Artificial Intelligence},
276
+ volume={3},
277
+ year={2022}
278
+ }
279
+
280
+ % DFTR
281
+ @article{zhudftr,
282
+ title={DFTR: Depth-supervised Fusion Transformer for Salient Object Detection},
283
+ author={Zhu, Heqin and Sun, Xu and Li, Yuexiang and Ma, Kai and Zhou, S Kevin and Zheng, Yefeng}
284
+ }
285
+
286
+ % PGAR
287
+ @inproceedings{chen2020progressively,
288
+ title={Progressively guided alternate refinement network for RGB-D salient object detection},
289
+ author={Chen, Shuhan and Fu, Yun},
290
+ booktitle={European Conference on Computer Vision},
291
+ pages={520--538},
292
+ year={2020},
293
+ organization={Springer}
294
+ }
295
+
296
+ % DCMF
297
+ @article{wang2022learning,
298
+ title={Learning Discriminative Cross-Modality Features for RGB-D Saliency Detection},
299
+ author={Wang, Fengyun and Pan, Jinshan and Xu, Shoukun and Tang, Jinhui},
300
+ journal={IEEE Transactions on Image Processing},
301
+ volume={31},
302
+ pages={1285--1297},
303
+ year={2022},
304
+ publisher={IEEE}
305
+ }
306
+
307
+ % RD3D
308
+ @inproceedings{chen2021rgb,
309
+ title={RGB-D salient object detection via 3D convolutional neural networks},
310
+ author={Chen, Qian and Liu, Ze and Zhang, Yi and Fu, Keren and Zhao, Qijun and Du, Hongwei},
311
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
312
+ volume={35},
313
+ number={2},
314
+ pages={1063--1071},
315
+ year={2021}
316
+ }
317
+
318
+ % ReDWeb-S
319
+ % S2MA
320
+ @inproceedings{liu2020learning,
321
+ title={Learning selective self-mutual attention for RGB-D saliency detection},
322
+ author={Liu, Nian and Zhang, Ni and Han, Junwei},
323
+ booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
324
+ pages={13756--13765},
325
+ year={2020}
326
+ }
327
+
328
+ % SSF
329
+ @inproceedings{zhang2020select,
330
+ title={Select, supplement and focus for RGB-D saliency detection},
331
+ author={Zhang, Miao and Ren, Weisong and Piao, Yongri and Rong, Zhengkun and Lu, Huchuan},
332
+ booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
333
+ pages={3472--3481},
334
+ year={2020}
335
+ }
336
+
337
+ % UCNet
338
+ @inproceedings{zhang2020uc,
339
+ title={UC-Net: Uncertainty inspired RGB-D saliency detection via conditional variational autoencoders},
340
+ author={Zhang, Jing and Fan, Deng-Ping and Dai, Yuchao and Anwar, Saeed and Saleh, Fatemeh Sadat and Zhang, Tong and Barnes, Nick},
341
+ booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
342
+ pages={8582--8591},
343
+ year={2020}
344
+ }
345
+
346
+ % TriTransNet
347
+ @inproceedings{liu2021tritransnet,
348
+ title={TriTransNet: RGB-D salient object detection with a triplet transformer embedding network},
349
+ author={Liu, Zhengyi and Wang, Yuan and Tu, Zhengzheng and Xiao, Yun and Tang, Bin},
350
+ booktitle={Proceedings of the 29th ACM international conference on multimedia},
351
+ pages={4481--4490},
352
+ year={2021}
353
+ }
354
+
355
+ @Comment{jabref-meta: databaseType:bibtex;}
requirements-lock.txt ADDED
@@ -0,0 +1,103 @@
1
+ aiohttp==3.9.3
2
+ aiosignal==1.3.1
3
+ albumentations==1.4.3
4
+ altair==5.3.0
5
+ async-timeout==4.0.3
6
+ attrs==23.2.0
7
+ black==24.3.0
8
+ blinker==1.7.0
9
+ cachetools==5.3.3
10
+ certifi==2024.2.2
11
+ charset-normalizer==3.3.2
12
+ click==8.1.7
13
+ contourpy==1.2.1
14
+ cycler==0.12.1
15
+ docstring_parser==0.16
16
+ einops==0.7.0
17
+ filelock==3.13.3
18
+ fonttools==4.51.0
19
+ frozenlist==1.4.1
20
+ fsspec==2024.3.1
21
+ gitdb==4.0.11
22
+ GitPython==3.1.43
23
+ huggingface-hub==0.22.2
24
+ idna==3.6
25
+ imageio==2.34.0
26
+ Jinja2==3.1.3
27
+ joblib==1.3.2
28
+ jsonschema==4.21.1
29
+ jsonschema-specifications==2023.12.1
30
+ kiwisolver==1.4.5
31
+ lazy_loader==0.4
32
+ lightning-utilities==0.11.2
33
+ markdown-it-py==3.0.0
34
+ MarkupSafe==2.1.5
35
+ matplotlib==3.8.4
36
+ mdurl==0.1.2
37
+ mpmath==1.3.0
38
+ multidict==6.0.5
39
+ mypy-extensions==1.0.0
40
+ networkx==3.3
41
+ numpy==1.26.4
42
+ nvidia-cublas-cu12==12.1.3.1
43
+ nvidia-cuda-cupti-cu12==12.1.105
44
+ nvidia-cuda-nvrtc-cu12==12.1.105
45
+ nvidia-cuda-runtime-cu12==12.1.105
46
+ nvidia-cudnn-cu12==8.9.2.26
47
+ nvidia-cufft-cu12==11.0.2.54
48
+ nvidia-curand-cu12==10.3.2.106
49
+ nvidia-cusolver-cu12==11.4.5.107
50
+ nvidia-cusparse-cu12==12.1.0.106
51
+ nvidia-nccl-cu12==2.18.1
52
+ nvidia-nvjitlink-cu12==12.4.127
53
+ nvidia-nvtx-cu12==12.1.105
54
+ opencv-python==4.9.0.80
55
+ opencv-python-headless==4.9.0.80
56
+ packaging==24.0
57
+ pandas==2.2.1
58
+ pathspec==0.12.1
59
+ pillow==10.3.0
60
+ platformdirs==4.2.0
61
+ protobuf==4.25.3
62
+ pyarrow==15.0.2
63
+ pycocotools==2.0.7
64
+ pydeck==0.8.1b0
65
+ Pygments==2.17.2
66
+ pyparsing==3.1.2
67
+ python-dateutil==2.9.0.post0
68
+ pytorch-lightning==2.2.1
69
+ pytz==2024.1
70
+ PyYAML==6.0.1
71
+ referencing==0.34.0
72
+ requests==2.31.0
73
+ rich==13.7.1
74
+ rpds-py==0.18.0
75
+ safetensors==0.4.2
76
+ scikit-image==0.22.0
77
+ scikit-learn==1.4.1.post1
78
+ scipy==1.13.0
79
+ six==1.16.0
80
+ smmap==5.0.1
81
+ streamlit==1.33.0
82
+ sympy==1.12
83
+ tenacity==8.2.3
84
+ termcolor==2.4.0
85
+ threadpoolctl==3.4.0
86
+ tifffile==2024.2.12
87
+ timm==0.9.16
88
+ toml==0.10.2
89
+ tomli==2.0.1
90
+ toolz==0.12.1
91
+ torch==2.1.0
92
+ torchmetrics==1.3.2
93
+ torchvision==0.16.0
94
+ tornado==6.4
95
+ tqdm==4.66.2
96
+ triton==2.1.0
97
+ typed-argument-parser==1.9.0
98
+ typing-inspect==0.9.0
99
+ typing_extensions==4.11.0
100
+ tzdata==2024.1
101
+ urllib3==2.2.1
102
+ watchdog==4.0.0
103
+ yarl==1.9.4
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ torch==2.1.0
2
+ torchvision
3
+ opencv-python
4
+ pycocotools
5
+ matplotlib
6
+ Pillow
7
+ numpy
8
+ einops
9
+ timm
10
+ albumentations
11
+ termcolor
12
+ tqdm
13
+ pandas
14
+ typed-argument-parser
15
+ pytorch-lightning
16
+ streamlit
17
+ black
18
+ huggingface-hub
s_multimae/__init__.py ADDED
File without changes
s_multimae/configs/__init__.py ADDED
File without changes
s_multimae/configs/base_config.py ADDED
@@ -0,0 +1,164 @@
1
+ from functools import partial
2
+ import os
3
+ from typing import Dict, Optional, Tuple, List
4
+ import torch
5
+ from torch import nn
6
+ import math
7
+ import albumentations as A
8
+
9
+ from definition import PRETRAINED_BACKBONE
10
+ from .data_augmentation_config import DataAugmentationConfig
11
+
12
+
13
+ class base_cfg:
14
+ def __init__(
15
+ self,
16
+ epoch: int,
17
+ datasets_set: int,
18
+ experiment_name: Optional[str] = None,
19
+ ):
20
+ self.experiment_name = experiment_name = (
21
+ self.__class__.__name__ if experiment_name is None else experiment_name
22
+ )
23
+ self.datasets_set = datasets_set
24
+
25
+ # Trainv3
26
+ self.devices: List[int] = [0, 1]
27
+ # How often to check the validation set. Pass a float in the range [0.0, 1.0] to check after that fraction of each training epoch.
28
+ self.val_check_interval: float = 1.0
29
+
30
+ # Perform a validation loop after every N training epochs.
31
+ self.check_val_every_n_epoch: int = 2
32
+
33
+ self.precision = 16
34
+ self.transform1 = [A.HorizontalFlip(p=0.5)]
35
+
36
+ self.save_top_k = 2
37
+
38
+ # ConvNeXtAdapter
39
+ self.dec_kernel = 1 # decoder kernel size
40
+
41
+ # Version 1: as usual
42
+ # Version 2: mean, std
43
+ self.model_version = 1
44
+
45
+ self.visualized_num_dev_samples = 0
46
+
47
+ # PytorchLightning Trainer
48
+ self.sync_batchnorm = True
49
+
50
+ self.normalized_depth: bool = True
51
+
52
+ self.test_image_size: int = 224
53
+ self.image_size: int = 224
54
+
55
+ """Whether using fp16 instead of fp32 (default)"""
56
+ self.is_fp16: bool = True
57
+
58
+ self.is_padding: bool = (
59
+ False # deprecated due to randomly switching between padding and non-padding
60
+ )
61
+
62
+ # """For debug only"""
63
+ # self.max_train_samples: Optional[int] = None
64
+ # self.max_dev_samples: Optional[int] = None
65
+
66
+ """Whether using padding for test"""
67
+ self.is_padding_for_test: bool = False
68
+
69
+ """Seed"""
70
+ self.seed: int = 2022
71
+
72
+ """ MultiMAE """
73
+ self.decoder_depth: int = 4
74
+ self.encoder_depth: int = 12
75
+ self.is_inference_with_no_depth: bool = False
76
+ self.inputs = ["rgb", "depth"]
77
+ self.outputs = ["sod"]
78
+ self.decoder_main_tasks: List[List[str]] = [["rgb"]]
79
+ self.learnable_pos_emb: bool = False
80
+ self.learnable_additional_gt_tokens: bool = False
81
+ self.decoder_interpolate_mode: str = "bilinear" # ['bilinear', 'nearest']
82
+ self.dim_tokens: int = 768
83
+ self.act_fn = partial(nn.ReLU, inplace=True)
84
+ self.num_heads: int = 12
85
+ self.freeze_encoder: bool = False
86
+
87
+ """Data Augmentation"""
88
+ self.data_augmentation_version: int = 2
89
+ self.data_augmentation_config = DataAugmentationConfig()
90
+
91
+ self.ckpt_path: Optional[str] = None
92
+ self.description: str = "" # Override this
93
+ self.embed_dim: int = 6144
94
+
95
+ """Pretrained Backbone"""
96
+ self.pretrained_backbone: Optional[PRETRAINED_BACKBONE] = (
97
+ PRETRAINED_BACKBONE.MULTIMAE
98
+ )
99
+
100
+ """
101
+ Required only when self.pretrained_backbone in [PRETRAINED_BACKBONE.S_MULTIMAE, PRETRAINED_BACKBONE.LARGE_S_MULTIMAE].
102
+ Example: 'v1.0.4_e499' stands for version 1.0.4, epoch 499, trained 500 epochs
103
+ """
104
+ self.pretrained_backbone_version: Optional[str] = None
105
+
106
+ """Ground truth
107
+ V1: 1 head, each head has 1 class, BCE
108
+ V2: 1 head, each head has 5 classes, CE
109
+ V3: 5 heads, each head has 1 class, BCE
110
+ V4: 1 head, each head has 5 classes, BCE
111
+ V5: additional global token indicates individual thinker
112
+ """
113
+ self.ground_truth_version = 1
114
+ self.additional_gt_tokens_mlp_channels = []
115
+ self.num_classes = 1
116
+ self.actual_num_classes = 1
117
+
118
+ self.is_cache = False
119
+
120
+ """Learning rate
121
+ LR strategy:
122
+ V1: The ratio of unpretrained and pretrained is always 1:lr_scale
123
+ V2: The ratio of unpretrained and pretrained is changed gradually from 1:lr_scale -> 1:1
124
+ """
125
+ self.lr_strategy_version = 1
126
+ self.lr: float
127
+ self.end_lr: float = 1e-11
128
+ self.lr_scale: int
129
+ self.lr_power: float = 0.9
130
+
131
+ # Deprecated from v3
132
+ self.save_checkpoints_after_each_n_epochs: int = 10 # Not used in trainv3
133
+
134
+ self.weight_decay = 0.05
135
+ self.num_workers = 2
136
+ self.num_epochs_every_restart = 100
137
+
138
+ self.betas: Tuple[float, float] = (0.9, 0.999)
139
+
140
+ self.input_patch_size: int = 16
141
+ self.output_patch_size: int = 16 # must be a square number
142
+
143
+ """Warmup batchsize"""
144
+ self.warmup_min_batch_size: Optional[int] = None
145
+ self.warmup_epoch_batch_size: Optional[int] = None
146
+
147
+ self.batch_size: int
148
+ self.val_batch_size: int
149
+ self.test_batch_size: int = 100
150
+ self.nepochs: int
151
+
152
+ def todict(self):
153
+ d = dict()
154
+ for k, v in self.__dict__.items():
155
+ if not k.startswith("_"):
156
+ d[k] = v
157
+ return d
158
+
159
+ @property
160
+ def total_iters_per_epoch(self):
161
+ return math.ceil(
162
+ (self.num_training_samples_per_epoch)
163
+ / (self.batch_size * len(self.devices))
164
+ )
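As a worked example of the `total_iters_per_epoch` property above (numbers borrowed from `cfgv4_0_2006`, which trains with `batch_size=20` on a single device, and from the 8,025-sample COME8K training split; illustrative only):

```python
import math

num_training_samples_per_epoch = 8025  # COME8K training split
batch_size, num_devices = 20, 1        # cfgv4_0_2006: batch_size=20, devices=[0]
print(math.ceil(num_training_samples_per_epoch / (batch_size * num_devices)))  # 402
```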
s_multimae/configs/data_augmentation_config.py ADDED
@@ -0,0 +1,19 @@
1
+ class RandomGaussianBlurConfig:
2
+ def __init__(self, p=0.5, max_gaussian_kernel=19) -> None:
3
+ self.p = p
4
+ self.max_gaussian_kernel = max_gaussian_kernel
5
+
6
+
7
+ class DataAugmentationConfig:
8
+ def __init__(self) -> None:
9
+ self.mean_normalization = [0.5, 0.5, 0.5]
10
+ self.std_normalization = [0.5, 0.5, 0.5]
11
+ self.image_gaussian_config = RandomGaussianBlurConfig(
12
+ p=0.5,
13
+ max_gaussian_kernel=19,
14
+ )
15
+ self.depth_gaussian_config = RandomGaussianBlurConfig(
16
+ p=0.5,
17
+ max_gaussian_kernel=36,
18
+ )
19
+ self.random_horizontal_flip_prob = 0.5
s_multimae/configs/experiment_config.py ADDED
@@ -0,0 +1,31 @@
1
+ from functools import partial
2
+ from typing import Dict, Optional, Type
3
+
4
+ from .base_config import base_cfg
5
+ import importlib, inspect, os
6
+ from glob import glob
7
+
8
+ arg_cfg: Dict[str, Type[base_cfg]] = dict()
9
+
10
+ modules = []
11
+ for p in glob("s_multimae/configs/experiment_configs/*.py"):
12
+ if not os.path.basename(p).startswith("__"): # skip __init__.py and other dunder files
13
+ module_name = os.path.splitext(os.path.basename(p))[0]
14
+ modules.append(f"s_multimae.configs.experiment_configs.{module_name}")
15
+
16
+ for module in modules:
17
+ for name, cls in inspect.getmembers(
18
+ importlib.import_module(module), inspect.isclass
19
+ ):
20
+ if name.startswith("cfg"):
21
+ arg_cfg[name] = cls
22
+
23
+
24
+ def get_config_by_set_version(set_version: int) -> base_cfg:
25
+ if set_version not in [1, 2, 3, 4]:
26
+ raise Exception(f"Unsupported set version {set_version}")
27
+ return arg_cfg[f"cfg_set_{set_version}"]()
28
+
29
+
30
+ def get_config(cfg_name: str, epoch: Optional[int] = None) -> base_cfg:
31
+ return arg_cfg[cfg_name](epoch)
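A usage sketch for the config registry above (the class name is taken from `expv1_dynamic.py`; any class whose name starts with `cfg` is registered automatically):

```python
from s_multimae.configs.experiment_config import get_config

cfg = get_config("cfgv4_0_2006")  # ViT-L, multi-GT experiment
print(cfg.experiment_name, cfg.description, cfg.batch_size)
```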
s_multimae/configs/experiment_configs/__init__.py ADDED
File without changes
s_multimae/configs/experiment_configs/expv1_dynamic.py ADDED
@@ -0,0 +1,277 @@
1
+ import albumentations as A
2
+ import cv2
3
+
4
+ from typing import Optional
5
+ from definition import PRETRAINED_BACKBONE
6
+ from ..base_config import base_cfg
7
+
8
+
9
+ class cfgv4_0_2006(base_cfg):
10
+ def __init__(self, epoch: Optional[int] = None):
11
+ super().__init__(epoch, datasets_set=1)
12
+
13
+ self.check_val_every_n_epoch = 1
14
+ self.num_workers = 4
15
+ self.devices = [0]
16
+
17
+ self.description = "ViT Large [DoT]"
18
+
19
+ """MultiMAE"""
20
+ self.pretrained_backbone = PRETRAINED_BACKBONE.LARGE_S_MULTIMAE
21
+ self.pretrained_backbone_version = "v2.0.5-pr"
22
+
23
+ # Large MAE
24
+ self.dim_tokens = 1024
25
+ self.encoder_depth = 24
26
+ self.num_heads = 16
27
+
28
+ self.clip_grad = None
29
+ self.normalized_depth = False
30
+
31
+ """Decoders"""
32
+ self.decoder_main_tasks = [["rgb", "depth"]]
33
+ self.decoder_depth = 10
34
+
35
+ # ConvNeXtAdapter
36
+ self.dec_kernel = 3
37
+
38
+ # debug
39
+ # self.max_train_samples = 80
40
+ # self.max_dev_samples = 80
41
+
42
+ # Diversity of thought
43
+ self.ground_truth_version = 6
44
+ self.num_classes = 5 # ignored
45
+ self.actual_num_classes = 5 # ignored
46
+ self.additional_gt_tokens_mlp_channels = [768 * 2]
47
+
48
+ """Learning rate"""
49
+ self.lr = 1e-5
50
+ self.end_lr = 1e-8
51
+ self.lr_scale = 100
52
+
53
+ self.batch_size = 20
54
+ self.val_batch_size = 200
55
+ self.nepochs = 400
56
+ self.num_epochs_every_restart = 100
57
+
58
+ self.data_augmentation_version = 6
59
+ self.train_function_version = 3
60
+ self.weight_decay = 5e-2
61
+ self.transform1 = [
62
+ A.HorizontalFlip(p=0.5),
63
+ ]
64
+
65
+
66
+ class cfgv4_0_2007(base_cfg):
67
+ def __init__(self, epoch: Optional[int] = None):
68
+ super().__init__(epoch, datasets_set=1)
69
+
70
+ self.check_val_every_n_epoch = 1
71
+ self.num_workers = 4
72
+ self.devices = [0]
73
+
74
+ self.description = "ViT Base [DoT]"
75
+
76
+ """MultiMAE"""
77
+ self.pretrained_backbone = PRETRAINED_BACKBONE.S_MULTIMAE
78
+ self.pretrained_backbone_version = "v2.0.1-pr"
79
+
80
+ self.clip_grad = None
81
+ self.normalized_depth = False
82
+
83
+ """Decoders"""
84
+ self.decoder_main_tasks = [["rgb", "depth"]]
85
+ self.decoder_depth = 10
86
+
87
+ # ConvNeXtAdapter
88
+ self.dec_kernel = 3
89
+
90
+ # debug
91
+ # self.max_train_samples = 80
92
+ # self.max_dev_samples = 80
93
+
94
+ # Diversity of thought
95
+ self.ground_truth_version = 6
96
+ self.num_classes = 5 # ignored
97
+ self.actual_num_classes = 5 # ignored
98
+ self.additional_gt_tokens_mlp_channels = [768 * 2]
99
+
100
+ """Learning rate"""
101
+ self.lr = 1e-5
102
+ self.end_lr = 1e-8
103
+ self.lr_scale = 100
104
+
105
+ self.batch_size = 40
106
+ self.val_batch_size = 200
107
+ self.nepochs = 400
108
+ self.num_epochs_every_restart = 100
109
+
110
+ self.data_augmentation_version = 6
111
+ self.train_function_version = 3
112
+ self.weight_decay = 5e-2
113
+ self.transform1 = [
114
+ A.HorizontalFlip(p=0.5),
115
+ A.OneOf(
116
+ [
117
+ A.Compose(
118
+ [
119
+ A.RandomCropFromBorders(
120
+ crop_left=0.3,
121
+ crop_right=0.3,
122
+ crop_top=0.3,
123
+ crop_bottom=0.3,
124
+ p=0.2,
125
+ ),
126
+ A.ShiftScaleRotate(
127
+ shift_limit=0.0625,
128
+ scale_limit=0.1,
129
+ rotate_limit=45,
130
+ p=0.1,
131
+ ),
132
+ A.Perspective(
133
+ p=0.2,
134
+ scale=(0.05, 0.1),
135
+ ),
136
+ ]
137
+ ),
138
+ A.Compose(
139
+ [
140
+ A.RandomCropFromBorders(
141
+ crop_left=0.3,
142
+ crop_right=0.3,
143
+ crop_top=0.3,
144
+ crop_bottom=0.3,
145
+ p=0.2,
146
+ ),
147
+ A.ShiftScaleRotate(
148
+ shift_limit=0.0625,
149
+ scale_limit=0.1,
150
+ rotate_limit=45,
151
+ p=0.1,
152
+ border_mode=cv2.BORDER_CONSTANT,
153
+ value=(255, 255, 255),
154
+ mask_value=0,
155
+ ),
156
+ A.Perspective(
157
+ p=0.2,
158
+ scale=(0.05, 0.1),
159
+ pad_mode=cv2.BORDER_CONSTANT,
160
+ pad_val=(255, 255, 255),
161
+ mask_pad_val=0,
162
+ ),
163
+ ]
164
+ ),
165
+ ]
166
+ ),
167
+ ]
168
+
169
+
170
+ class cfgv4_0_2002(base_cfg):
171
+ def __init__(self, epoch: Optional[int] = None):
172
+ super().__init__(epoch, datasets_set=1)
173
+
174
+ self.check_val_every_n_epoch = 1
175
+ self.num_workers = 4
176
+ self.devices = [3]
177
+
178
+ # self.description = "Trainv3-DAv6-DiversityOfThought-NotMuchAug"
179
+ self.description = "DEBUG"
180
+
181
+ """MultiMAE"""
182
+ self.pretrained_backbone = PRETRAINED_BACKBONE.S_MULTIMAE
183
+ self.pretrained_backbone_version = "v2.0.1-pr"
184
+
185
+ # Large MAE
186
+ # self.dim_tokens = 1024
187
+ # self.encoder_depth = 24
188
+ # self.num_heads = 16
189
+
190
+ self.clip_grad = None
191
+ self.normalized_depth = False
192
+
193
+ """Decoders"""
194
+ self.decoder_main_tasks = [["rgb", "depth"]]
195
+ self.decoder_depth = 10
196
+
197
+ # ConvNeXtAdapter
198
+ self.dec_kernel = 3
199
+
200
+ # debug
201
+ self.max_train_samples = 20
202
+ self.max_dev_samples = 20
203
+
204
+ # Diversity of thought
205
+ self.ground_truth_version = 6
206
+ self.num_classes = 5 # ignored
207
+ self.actual_num_classes = 5 # ignored
208
+ self.additional_gt_tokens_mlp_channels = [768 * 2]
209
+
210
+ """Learning rate"""
211
+ self.lr = 1e-5
212
+ self.end_lr = 1e-8
213
+ self.lr_scale = 100
214
+
215
+ self.batch_size = 5
216
+ self.val_batch_size = 5
217
+ self.nepochs = 400
218
+ self.num_epochs_every_restart = 100
219
+
220
+ self.data_augmentation_version = 6
221
+ self.train_function_version = 3
222
+ self.weight_decay = 5e-2
223
+ self.transform1 = [
224
+ A.HorizontalFlip(p=0.5),
225
+ A.OneOf(
226
+ [
227
+ A.Compose(
228
+ [
229
+ A.RandomCropFromBorders(
230
+ crop_left=0.3,
231
+ crop_right=0.3,
232
+ crop_top=0.3,
233
+ crop_bottom=0.3,
234
+ p=0.2,
235
+ ),
236
+ A.ShiftScaleRotate(
237
+ shift_limit=0.0625,
238
+ scale_limit=0.1,
239
+ rotate_limit=45,
240
+ p=0.1,
241
+ ),
242
+ A.Perspective(
243
+ p=0.2,
244
+ scale=(0.05, 0.1),
245
+ ),
246
+ ]
247
+ ),
248
+ A.Compose(
249
+ [
250
+ A.RandomCropFromBorders(
251
+ crop_left=0.3,
252
+ crop_right=0.3,
253
+ crop_top=0.3,
254
+ crop_bottom=0.3,
255
+ p=0.2,
256
+ ),
257
+ A.ShiftScaleRotate(
258
+ shift_limit=0.0625,
259
+ scale_limit=0.1,
260
+ rotate_limit=45,
261
+ p=0.1,
262
+ border_mode=cv2.BORDER_CONSTANT,
263
+ value=(255, 255, 255),
264
+ mask_value=0,
265
+ ),
266
+ A.Perspective(
267
+ p=0.2,
268
+ scale=(0.05, 0.1),
269
+ pad_mode=cv2.BORDER_CONSTANT,
270
+ pad_val=(255, 255, 255),
271
+ mask_pad_val=0,
272
+ ),
273
+ ]
274
+ ),
275
+ ]
276
+ ),
277
+ ]
s_multimae/da/__init__.py ADDED
File without changes
s_multimae/da/base_da.py ADDED
@@ -0,0 +1,33 @@
1
+ import abc
2
+ from typing import List, Optional, Tuple
3
+ from torch import nn, Tensor
4
+ from PIL import Image
5
+
6
+
7
+ class BaseDataAugmentation(nn.Module):
8
+ def __init__(self):
9
+ super(BaseDataAugmentation, self).__init__()
10
+
11
+ @abc.abstractmethod
12
+ def forward(
13
+ self,
14
+ image: Image.Image,
15
+ depth: Image.Image,
16
+ gt: Optional[Image.Image] = None,
17
+ ranking_gt: Optional[Image.Image] = None,
18
+ multi_gts: Optional[List[Image.Image]] = None,
19
+ is_transform: bool = True, # is augmented?
20
+ is_debug: bool = False,
21
+ ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
22
+ """
23
+ Usual case:
24
+ If gt is provided, return [image, depth, gt]
25
+ Otherwise, return [image, depth]
26
+
27
+ When ranking_gt is provided, gt will be ignored
28
+ Return [image, depth, ranking_gt]
29
+
30
+ For debugging:
31
+ Return [image, depth, gt|ranking_gt, unnormalized, Optional[ranking_gts]]
32
+ """
33
+ pass
s_multimae/da/dav6.py ADDED
@@ -0,0 +1,147 @@
1
+ from typing import List, Optional, Tuple
2
+ from PIL import Image
3
+ import numpy as np
4
+ from torchvision import transforms
5
+ import albumentations as A
6
+ import torch
7
+ from torch import Tensor
8
+
9
+ from ..configs.base_config import base_cfg
10
+ from .base_da import BaseDataAugmentation
11
+
12
+
13
+ class DataAugmentationV6(BaseDataAugmentation):
14
+ def __init__(
15
+ self,
16
+ cfg: base_cfg,
17
+ is_padding=True,
18
+ ):
19
+ super().__init__()
20
+ self.image_size = cfg.image_size
21
+ self.is_padding = is_padding
22
+ self.cfg = cfg
23
+
+         self.to_tensor = transforms.ToTensor()
+
+         self.additional_targets = {
+             "depth": "image",
+             "gt": "mask",
+             "ranking_gt": "mask",
+             "multi_gts": "mask",
+         }
+
+         # For rgb+depth+gt
+         self.transform1 = A.Compose(
+             cfg.transform1,
+             additional_targets=self.additional_targets,
+         )
+
+         # For rgb only
+         self.transform2 = A.Compose(
+             [
+                 A.GaussianBlur(p=0.5, blur_limit=(3, 19)),
+                 A.RandomBrightnessContrast(p=0.5),
+                 A.ColorJitter(p=0.5),
+             ]
+         )
+
+         # For depth only
+         self.transform3 = A.Compose([A.GaussianBlur(p=0.5, blur_limit=(3, 37))])
+
+         # For rgb+depth+gt
+         self.transform4 = A.Compose(
+             [A.Resize(self.image_size, self.image_size)],
+             additional_targets=self.additional_targets,
+             is_check_shapes=False,
+         )
+
+         # For rgb only
+         self.transform5 = A.Compose([A.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+
+         # For depth only
+         self.transform6 = A.Compose([A.Normalize(0.5, 0.5)])
+
+     def forward(
+         self,
+         image: Image.Image,
+         depth: Image.Image,
+         gt: Optional[Image.Image] = None,
+         ranking_gt: Optional[Image.Image] = None,
+         multi_gts: Optional[List[Image.Image]] = None,
+         is_transform: bool = True,  # whether to apply augmentation
+         is_debug: bool = False,
+     ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+         ## 1. Convert to numpy arrays: image, depth, gt, ranking_gts
+         image = np.array(image)
+         depth = np.array(depth)
+         d = dict(image=image, depth=depth)
+         if gt is not None:
+             gt = np.array(gt)
+             d["gt"] = gt
+
+         if not is_transform:
+             # Dev or Test
+             d = self.transform4(**d)
+             d["image"] = self.transform5(image=d["image"])["image"]
+             # d["depth"] = self.transform6(image=depth)["image"]
+             if gt is not None:
+                 return self.to_tensors([d["image"], d["depth"], d["gt"]])
+             else:
+                 return self.to_tensors([d["image"], d["depth"]])
+
+         d["depth"] = 255 - d["depth"]  # inverse depth
+
+         # if ranking_gt is not None and multi_gts is not None:
+         #     print('[WARN] Both ranking_gt and multi_gts are not none, but we prioritize multi_gts')
+
+         if ranking_gt is not None:
+             ranking_gt = np.array(ranking_gt)
+
+         if multi_gts is not None:
+             multi_gts = np.stack(multi_gts, axis=2)
+             d["multi_gts"] = multi_gts
+
+         ## 2. First transformation for image (Contrast, GaussianBlur, ...) and depth (GaussianBlur, ...)
+         d["image"] = self.transform2(image=d["image"])["image"]
+         d["depth"] = self.transform3(image=d["depth"])["image"]
+
+         ## 3. Transformations defined in config: change perspective, rotation, size, ...
+         d = self.transform1(**d)
+
+         ## 4. Resize
+         d = self.transform4(**d)
+
+         ## 5. Back up the image before normalizing it (debug only)
+         if is_debug:
+             unnormalized_image = d["image"]
+
+         ## 6. Construct response
+         d["depth"] = 255 - d["depth"]  # inverse depth
+         d["image"] = self.transform5(image=d["image"])["image"]
+         # d["depth"] = self.transform6(image=depth)["image"]
+         rs = self.to_tensors([d["image"], d["depth"]])
+         if multi_gts is not None:
+             rs += self.to_tensors([d["multi_gts"]])
+         elif ranking_gt is not None:
+             rs += [torch.from_numpy(d["ranking_gt"]).to(torch.long)]
+         else:
+             rs += self.to_tensors([d["gt"]])
+
+         ## 7. For debug only
+         if is_debug:
+             rs.append(unnormalized_image)
+
+         if ranking_gt is not None:
+             ranking_gts = []
+             for i in range(self.cfg.num_classes):
+                 ranking_gts.append(
+                     np.array(d["ranking_gt"] == i).astype(np.uint8) * 255
+                 )
+             rs.append(ranking_gts)
+         if multi_gts is not None:
+             rs.append(d["multi_gts"])
+
+         return rs
+
+     def to_tensors(self, lst: List[Tensor]) -> List[Tensor]:
+         return [self.to_tensor(e) for e in lst]
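
The pipeline in `DataAugmentationV6.forward` applies photometric jitter to the RGB image only, blurring to the depth map only, and then the shared geometric transforms from `cfg.transform1` jointly over RGB, depth and GT (via `additional_targets`) before resizing and normalizing. The standalone sketch below mirrors that ordering on dummy arrays; the horizontal flip is only a stand-in for `cfg.transform1`, which is defined in the experiment configs and not shown here.

```python
import numpy as np
import albumentations as A

# Dummy sample standing in for an RGB image, its depth map and a binary GT mask.
rgb = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
depth = np.random.randint(0, 256, (480, 640), dtype=np.uint8)
gt = np.random.randint(0, 2, (480, 640), dtype=np.uint8) * 255

additional_targets = {"depth": "image", "gt": "mask"}

# Photometric augmentation applied to RGB only (mirrors transform2).
rgb_only = A.Compose([
    A.GaussianBlur(p=0.5, blur_limit=(3, 19)),
    A.RandomBrightnessContrast(p=0.5),
    A.ColorJitter(p=0.5),
])
# Blur applied to depth only (mirrors transform3).
depth_only = A.Compose([A.GaussianBlur(p=0.5, blur_limit=(3, 37))])
# Shared geometric augmentation + resize, applied jointly so RGB, depth and GT stay aligned
# (mirrors transform1 followed by transform4).
shared = A.Compose(
    [A.HorizontalFlip(p=0.5), A.Resize(224, 224)],
    additional_targets=additional_targets,
)

depth = 255 - depth                      # inverse depth, as in forward()
rgb = rgb_only(image=rgb)["image"]
depth = depth_only(image=depth)["image"]
out = shared(image=rgb, depth=depth, gt=gt)
out["depth"] = 255 - out["depth"]        # invert back before normalization
print(out["image"].shape, out["depth"].shape, out["gt"].shape)
```
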
s_multimae/data_augmentation.py ADDED
@@ -0,0 +1,19 @@
+ from torch import nn
+
+ from .configs.base_config import base_cfg
+ from .da.dav6 import DataAugmentationV6
+ from .da.base_da import BaseDataAugmentation
+
+
+ def get_data_augmentation(
+     cfg: base_cfg,
+     image_size: int,
+     is_padding: bool,
+ ) -> BaseDataAugmentation:
+     if cfg.data_augmentation_version == 6:
+         print("Using DataAugmentationV6")
+         return DataAugmentationV6(cfg)
+     else:
+         raise NotImplementedError(
+             f"Unsupported DataAugmentation version {cfg.data_augmentation_version}"
+         )
s_multimae/model/__init__.py ADDED
File without changes
s_multimae/model/components.py ADDED
@@ -0,0 +1,117 @@
+ import torch
+ from torch import Tensor
+ import math
+ import warnings
+ from typing import Tuple, Union
+ from einops import rearrange
+
+ import os
+
+ from definition import PRETRAINED_BACKBONE
+ from ..configs.base_config import base_cfg
+
+
+ def pair(t: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
+     return t if isinstance(t, tuple) else (t, t)
+
+
+ def build_2d_sincos_posemb(h: int, w: int, embed_dim=1024, temperature=10000.0):
+     """Sine-cosine positional embeddings from MoCo-v3
+
+     Source: https://github.com/facebookresearch/moco-v3/blob/main/vits.py
+     """
+     grid_w = torch.arange(w, dtype=torch.float32)
+     grid_h = torch.arange(h, dtype=torch.float32)
+     grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
+     assert (
+         embed_dim % 4 == 0
+     ), "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
+     pos_dim = embed_dim // 4
+     omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+     omega = 1.0 / (temperature**omega)
+     out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
+     out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
+     pos_emb = torch.cat(
+         [torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1
+     )[None, :, :]
+     pos_emb = rearrange(pos_emb, "b (h w) d -> b d h w", h=h, w=w, d=embed_dim)
+     return pos_emb
+
+
+ def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float, b: float):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn(
+             "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+             "The distribution of values may be incorrect.",
+             stacklevel=2,
+         )
+
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+
+         # Uniformly fill tensor with values from [l, u], then translate to
+         # [2l-1, 2u-1].
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+         # Use inverse cdf transform for normal distribution to get truncated
+         # standard normal
+         tensor.erfinv_()
+
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.0))
+         tensor.add_(mean)
+
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
+
+
+ def trunc_normal_(tensor: Tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+     # type: (Tensor, float, float, float, float) -> Tensor
+     r"""Fills the input Tensor with values drawn from a truncated
+     normal distribution. The values are effectively drawn from the
+     normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+     with values outside :math:`[a, b]` redrawn until they are within
+     the bounds. The method used for generating the random values works
+     best when :math:`a \leq \text{mean} \leq b`.
+     Args:
+         tensor: an n-dimensional `Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     Examples:
+         >>> w = torch.empty(3, 5)
+         >>> nn.init.trunc_normal_(w)
+     """
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+ def drop_path(x: Tensor, drop_prob: float = 0.0, training: bool = False):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+     'survival rate' as the argument.
+     """
+     if drop_prob == 0.0 or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (
+         x.ndim - 1
+     )  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+     random_tensor.floor_()  # binarize
+     output = x.div(keep_prob) * random_tensor
+     return output
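
A quick sanity check of the two helpers above (assuming the repository root is on `PYTHONPATH`, since `components.py` imports from `definition.py`): the sin-cos table comes back as a `(1, embed_dim, h, w)` grid that the input adapters later flatten into `(1, h*w, embed_dim)` tokens, and `drop_path` zeroes whole samples while rescaling the survivors by `1 / keep_prob`.

```python
import torch
from s_multimae.model.components import build_2d_sincos_posemb, drop_path

# A 224x224 image with 16x16 patches gives a 14x14 grid of tokens.
pos_emb = build_2d_sincos_posemb(h=14, w=14, embed_dim=768)
print(pos_emb.shape)  # torch.Size([1, 768, 14, 14])

# Stochastic depth: with drop_prob=0.5 roughly half of the samples in a batch
# have this residual branch zeroed out, the rest are rescaled by 1/keep_prob.
x = torch.ones(8, 196, 768)
y = drop_path(x, drop_prob=0.5, training=True)
print(y[:, 0, 0])  # each entry is either 0.0 or 2.0
```
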
s_multimae/model/multimae.py ADDED
@@ -0,0 +1,938 @@
1
+ import math
2
+ import re
3
+ from collections import OrderedDict
4
+ from functools import partial
5
+ from typing import Dict, Iterable, List, Optional, Tuple, Union
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torchvision.ops import MLP
9
+ from einops import rearrange, repeat
10
+ from torch import Tensor, nn
11
+
12
+ from definition import PRETRAINED_BACKBONE
13
+ from ..configs.base_config import base_cfg
14
+ from ..utils import count_parameters
15
+ from .components import (
16
+ build_2d_sincos_posemb,
17
+ drop_path,
18
+ pair,
19
+ trunc_normal_,
20
+ )
21
+
22
+
23
+ class PatchedInputAdapter(nn.Module):
24
+ """Adapter for spatial inputs, like images or feature maps.
25
+ Creates tokens from patches over the image.
26
+
27
+ :param num_channels: Number of input channels of the image/feature map
28
+ :param stride_level: Stride level compared to the full-sized image.
29
+ E.g. 4 for 1/4th the size of the image.
30
+ :param patch_size_full: Int or tuple of the patch size over the full image size.
31
+ Patch size for smaller inputs will be computed accordingly.
32
+ :param dim_tokens: Dimension of output tokens. Can be set using init method.
33
+ :param sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
34
+ :param learnable_pos_emb: Set to True to learn positional embeddings instead
35
+ :param image_size: Default image size. Used to initialize size of positional embeddings.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ num_channels: int,
41
+ stride_level: int,
42
+ patch_size_full: Union[int, Tuple[int, int]],
43
+ dim_tokens: Optional[int] = None,
44
+ sincos_pos_emb: bool = True,
45
+ learnable_pos_emb: bool = False,
46
+ image_size: Union[int, Tuple[int]] = 224,
47
+ ):
48
+ super().__init__()
49
+ self.num_channels = num_channels
50
+ self.stride_level = stride_level
51
+ self.patch_size_full = pair(patch_size_full)
52
+ self.dim_tokens = dim_tokens
53
+ self.sincos_pos_emb = sincos_pos_emb
54
+ self.learnable_pos_emb = learnable_pos_emb
55
+ self.image_size = pair(image_size)
56
+ self.num_patches = (self.image_size[0] // patch_size_full) * (
57
+ self.image_size[1] // patch_size_full
58
+ )
59
+
60
+ # Actual patch height and width, taking into account stride of input
61
+ self.P_H = max(1, self.patch_size_full[0] // stride_level)
62
+ self.P_W = max(1, self.patch_size_full[1] // stride_level)
63
+
64
+ if self.dim_tokens is not None:
65
+ self.init(dim_tokens=dim_tokens)
66
+
67
+ def init(self, dim_tokens: int = 768):
68
+ """
69
+ Initialize parts of encoder that are dependent on dimension of tokens.
70
+ Should be called when setting up MultiMAE.
71
+
72
+ :param dim_tokens: Dimension of tokens
73
+ """
74
+ self.dim_tokens = dim_tokens
75
+
76
+ # Task embedding identifying from which task a given token comes from
77
+ # Fixed-size positional embeddings. Can be interpolated to different input sizes
78
+ h_posemb = self.image_size[0] // (self.stride_level * self.P_H)
79
+ w_posemb = self.image_size[1] // (self.stride_level * self.P_W)
80
+ if self.sincos_pos_emb:
81
+ self.pos_emb = build_2d_sincos_posemb(
82
+ h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens
83
+ )
84
+ self.pos_emb = nn.Parameter(
85
+ self.pos_emb, requires_grad=self.learnable_pos_emb
86
+ )
87
+ else:
88
+ self.pos_emb = nn.Parameter(
89
+ torch.zeros(1, self.dim_tokens, h_posemb, w_posemb)
90
+ )
91
+ trunc_normal_(self.pos_emb, std=0.02)
92
+
93
+ # Image -> tokens projection
94
+ self.proj = nn.Conv2d(
95
+ in_channels=self.num_channels,
96
+ out_channels=self.dim_tokens,
97
+ kernel_size=(self.P_H, self.P_W),
98
+ stride=(self.P_H, self.P_W),
99
+ )
100
+
101
+ @torch.jit.ignore
102
+ def no_weight_decay(self):
103
+ return {"pos_emb"}
104
+
105
+ def forward(self, x: Tensor) -> Tensor:
106
+ """
107
+ Forward pass through input adapter, transforming image to sequence of tokens.
108
+ Adds task and positional encodings.
109
+
110
+ :param x: Input image tensor
111
+ """
112
+ B, C, H, W = x.shape
113
+ assert (
114
+ self.dim_tokens is not None
115
+ ), "Need to call init(dim_tokens) function first"
116
+ assert (H % self.P_H == 0) and (
117
+ W % self.P_W == 0
118
+ ), f"Image sizes {H}x{W} must be divisible by patch sizes {self.P_H}x{self.P_W}"
119
+ N_H, N_W = H // self.P_H, W // self.P_W # Number of patches in height and width
120
+
121
+ # Create patches [B, C, H, W] -> [B, (H*W), C]
122
+ projected_x = self.proj(x)
123
+ x_patch = rearrange(projected_x, "b d nh nw -> b (nh nw) d")
124
+
125
+ # Create positional embedding
126
+ x_pos_emb = F.interpolate(
127
+ self.pos_emb, size=(N_H, N_W), mode="bicubic", align_corners=False
128
+ )
129
+ x_pos_emb = rearrange(x_pos_emb, "b d nh nw -> b (nh nw) d")
130
+
131
+ # Add patches and positional embeddings
132
+ x = x_patch + x_pos_emb
133
+
134
+ return x
135
+
136
+
137
+ class DropPath(nn.Module):
138
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
139
+
140
+ def __init__(self, drop_prob=None):
141
+ super(DropPath, self).__init__()
142
+ self.drop_prob = drop_prob
143
+
144
+ def forward(self, x: Tensor) -> Tensor:
145
+ return drop_path(x, self.drop_prob, self.training)
146
+
147
+ def extra_repr(self) -> str:
148
+ return "p={}".format(self.drop_prob)
149
+
150
+
151
+ class ConvNeXtBlock(nn.Module):
152
+ r"""ConvNeXt Block. There are two equivalent implementations:
153
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
154
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
155
+ We use (2) as we find it slightly faster in PyTorch
156
+
157
+ Args:
158
+ dim (int): Number of input channels.
159
+ drop_path: Stochastic depth rate. Default: 0.0
160
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 0 (disabled for isotropic ConvNeXt).
161
+
162
+ Code from: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
163
+ """
164
+
165
+ def __init__(self, dim, drop_path=0.0, layer_scale_init_value=0.0):
166
+ super().__init__()
167
+ self.dwconv = nn.Conv2d(
168
+ dim, dim, kernel_size=7, padding=3, groups=dim
169
+ ) # depthwise conv
170
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
171
+ self.pwconv1 = nn.Linear(
172
+ dim, 4 * dim
173
+ ) # pointwise/1x1 convs, implemented with linear layers
174
+ self.act = nn.GELU()
175
+ self.pwconv2 = nn.Linear(4 * dim, dim)
176
+ self.gamma = (
177
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
178
+ if layer_scale_init_value > 0
179
+ else None
180
+ )
181
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
182
+
183
+ def forward(self, x: Tensor) -> Tensor:
184
+ input = x
185
+ x = self.dwconv(x)
186
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
187
+ x = self.norm(x)
188
+ x = self.pwconv1(x)
189
+ x = self.act(x)
190
+ x = self.pwconv2(x)
191
+ if self.gamma is not None:
192
+ x = self.gamma * x
193
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
194
+
195
+ x = input + self.drop_path(x)
196
+ return x
197
+
198
+
199
+ class ConvNeXtAdapter(nn.Module):
200
+ """Output adapter with ConvNext blocks for semantic segmentation
201
+
202
+ :param num_classes: Number of classes
203
+ :param num_heads: Number of attention heads
204
+ :param embed_dim: Token dimension after projection, and before reshaping operation.
205
+ :param preds_per_patch: Increases size of feature map by reshaping each patch Each patch gets reshaped
206
+ from embed_dim x 1 x 1 to (embed_dim / preds_per_patch) x (preds_per_patch ** 0.5) x (preds_per_patch ** 0.5)
207
+ :param main_tasks: Tasks to use for the adapter. Only tokens coming from these tasks are kept.
208
+ :param patch_size: Size of patches
209
+ :param depth: Number of ConvNeXt blocks
210
+ :interpolate_mode: Interpolation mode for final upsampling
211
+ """
212
+
213
+ def __init__(
214
+ self,
215
+ image_size: int,
216
+ num_classes: int,
217
+ embed_dim: int = 6144,
218
+ preds_per_patch: int = 16,
219
+ main_tasks: Iterable[str] = ("rgb",),
220
+ patch_size: int = 16,
221
+ depth: int = 4,
222
+ interpolate_mode: str = "bilinear",
223
+ act_fn: nn.Module = nn.GELU,
224
+ dec_kernel: int = 1,
225
+ ):
226
+ super().__init__()
227
+ self.main_tasks = main_tasks
228
+ self.patch_size = patch_size
229
+ self.embed_dim = embed_dim
230
+ self.preds_per_patch = preds_per_patch
231
+ self.class_dim = embed_dim // preds_per_patch
232
+ self.num_classes = num_classes
233
+ self.interpolate_mode = interpolate_mode
234
+ self.image_size = image_size
235
+
236
+ self.blocks = nn.Sequential(
237
+ *[ConvNeXtBlock(dim=self.class_dim) for _ in range(depth)]
238
+ )
239
+ if dec_kernel == 1:
240
+ self.final_layer_1 = nn.Sequential(
241
+ nn.Conv2d(self.class_dim, self.class_dim // 4, 1),
242
+ nn.BatchNorm2d(self.class_dim // 4),
243
+ act_fn(),
244
+ nn.Upsample(scale_factor=2, mode=self.interpolate_mode),
245
+ )
246
+
247
+ self.final_layer_2 = nn.Sequential(
248
+ nn.Conv2d(self.class_dim // 4, self.class_dim // 16, 1),
249
+ nn.BatchNorm2d(self.class_dim // 16),
250
+ act_fn(),
251
+ nn.Upsample(size=image_size, mode=self.interpolate_mode),
252
+ )
253
+
254
+ self.final_layer = nn.Conv2d(self.class_dim // 16, self.num_classes, 1)
255
+ elif dec_kernel == 3:
256
+ self.final_layer_1 = nn.Sequential(
257
+ nn.Conv2d(
258
+ self.class_dim,
259
+ self.class_dim // 4,
260
+ kernel_size=3,
261
+ stride=1,
262
+ padding=1,
263
+ ),
264
+ nn.BatchNorm2d(self.class_dim // 4),
265
+ act_fn(),
266
+ nn.Upsample(scale_factor=2, mode=self.interpolate_mode),
267
+ )
268
+
269
+ self.final_layer_2 = nn.Sequential(
270
+ nn.Conv2d(
271
+ self.class_dim // 4,
272
+ self.class_dim // 16,
273
+ kernel_size=3,
274
+ stride=1,
275
+ padding=1,
276
+ ),
277
+ nn.BatchNorm2d(self.class_dim // 16),
278
+ act_fn(),
279
+ nn.Upsample(size=image_size, mode=self.interpolate_mode),
280
+ )
281
+
282
+ self.final_layer = nn.Conv2d(
283
+ self.class_dim // 16,
284
+ self.num_classes,
285
+ kernel_size=3,
286
+ stride=1,
287
+ padding=1,
288
+ )
289
+ else:
290
+ raise Exception(f"Unsupported dec_kernel {dec_kernel}")
291
+
292
+ self.apply(self._init_weights)
293
+
294
+ def init(self, dim_tokens_enc: int = 768):
295
+ """
296
+ Initialize parts of decoder that are dependent on dimension of encoder tokens.
297
+ Should be called when setting up MultiMAE.
298
+
299
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
300
+ """
301
+ self.in_channels = dim_tokens_enc * len(self.main_tasks)
302
+
303
+ # Projection of encoder tokens to the patch dimension
304
+ self.proj_dec = nn.Linear(self.in_channels, self.embed_dim)
305
+ self._init_weights(self.proj_dec)
306
+
307
+ def _init_weights(self, m: nn.Module):
308
+ if isinstance(m, nn.Linear):
309
+ trunc_normal_(m.weight, std=0.02)
310
+ if isinstance(m, nn.Linear) and m.bias is not None:
311
+ nn.init.constant_(m.bias, 0)
312
+ elif isinstance(m, nn.LayerNorm):
313
+ nn.init.constant_(m.bias, 0)
314
+ nn.init.constant_(m.weight, 1.0)
315
+
316
+ def adapt_tokens(self, encoder_tokens: Tensor, input_info: Dict):
317
+ # Adapt tokens
318
+ x = []
319
+ for task in self.main_tasks:
320
+ start_idx = input_info["tasks"][task]["start_idx"]
321
+ end_idx = input_info["tasks"][task]["end_idx"]
322
+ x.append(encoder_tokens[:, start_idx:end_idx])
323
+
324
+ x = torch.cat(x, dim=-1)
325
+ return x
326
+
327
+ def forward(self, encoder_tokens: Tensor, input_info: Dict) -> Tensor:
328
+ H, W = input_info["image_size"]
329
+ N_H, N_W = H // self.patch_size, W // self.patch_size
330
+
331
+ x = self.adapt_tokens(encoder_tokens, input_info)
332
+
333
+ x = self.proj_dec(x)
334
+ x = rearrange(
335
+ x,
336
+ "b n (p c) -> b (n p) c",
337
+ n=N_H * N_W,
338
+ p=self.preds_per_patch,
339
+ c=self.class_dim,
340
+ )
341
+ x = rearrange(
342
+ x,
343
+ "b (nh nw ph pw) c -> b c (nh ph) (nw pw)",
344
+ nh=N_H,
345
+ nw=N_W,
346
+ ph=int(self.preds_per_patch**0.5),
347
+ pw=int(self.preds_per_patch**0.5),
348
+ )
349
+
350
+ x = self.blocks(x)
351
+
352
+ # for block in self.blocks:
353
+ # x = block(x)
354
+ # print(x.shape)
355
+
356
+ # print(x.shape)
357
+ x = self.final_layer_1(x)
358
+ # print(x.shape)
359
+ x = self.final_layer_2(x)
360
+ # print(x.shape)
361
+ x = self.final_layer(x)
362
+ # print(x.shape)
363
+
364
+ # Interpolate to sod res
365
+ # x = F.interpolate(x, size=(H, W), mode=self.interpolate_mode)
366
+
367
+ return x
368
+
369
+
370
+ class Attention(nn.Module):
371
+ def __init__(
372
+ self,
373
+ dim: int,
374
+ num_heads=8,
375
+ qkv_bias=False,
376
+ attn_drop=0.0,
377
+ proj_drop=0.0,
378
+ ):
379
+ super().__init__()
380
+ self.num_heads = num_heads
381
+ head_dim = dim // num_heads
382
+ self.scale = head_dim**-0.5
383
+
384
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
385
+ self.attn_drop = nn.Dropout(attn_drop)
386
+ self.proj = nn.Linear(dim, dim)
387
+ self.proj_drop = nn.Dropout(proj_drop)
388
+
389
+ def forward(self, x: Tensor) -> Tensor:
390
+ B, N, C = x.shape
391
+ qkv = (
392
+ self.qkv(x)
393
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
394
+ .permute(2, 0, 3, 1, 4)
395
+ )
396
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
397
+
398
+ attn = (q @ k.transpose(-2, -1)) * self.scale
399
+ attn = attn.softmax(dim=-1)
400
+ attn = self.attn_drop(attn)
401
+
402
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
403
+ x = self.proj(x)
404
+ x = self.proj_drop(x)
405
+ return x
406
+
407
+
408
+ class Mlp(nn.Module):
409
+ def __init__(
410
+ self,
411
+ in_features: int,
412
+ hidden_features: Optional[int] = None,
413
+ out_features: Optional[int] = None,
414
+ act_layer: nn.Module = nn.GELU,
415
+ drop: float = 0.0,
416
+ ):
417
+ super().__init__()
418
+ out_features = out_features or in_features
419
+ hidden_features = hidden_features or in_features
420
+ self.fc1 = nn.Linear(in_features, hidden_features)
421
+ self.act = act_layer()
422
+ self.fc2 = nn.Linear(hidden_features, out_features)
423
+ self.drop = nn.Dropout(drop)
424
+
425
+ def forward(self, x: Tensor) -> Tensor:
426
+ x = self.fc1(x)
427
+ x = self.act(x)
428
+ # x = self.drop(x)
429
+ # commit this for the orignal BERT implement
430
+ x = self.fc2(x)
431
+ x = self.drop(x)
432
+ return x
433
+
434
+
435
+ class Block(nn.Module):
436
+ def __init__(
437
+ self,
438
+ dim: int,
439
+ num_heads: int,
440
+ mlp_ratio=4.0,
441
+ qkv_bias=False,
442
+ drop=0.0,
443
+ attn_drop=0.0,
444
+ drop_path=0.0,
445
+ act_layer=nn.GELU,
446
+ norm_layer=nn.LayerNorm,
447
+ ):
448
+ super().__init__()
449
+ self.norm1 = norm_layer(dim)
450
+ self.attn = Attention(
451
+ dim,
452
+ num_heads=num_heads,
453
+ qkv_bias=qkv_bias,
454
+ attn_drop=attn_drop,
455
+ proj_drop=drop,
456
+ )
457
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
458
+ self.norm2 = norm_layer(dim)
459
+ mlp_hidden_dim = int(dim * mlp_ratio)
460
+ self.mlp = Mlp(
461
+ in_features=dim,
462
+ hidden_features=mlp_hidden_dim,
463
+ act_layer=act_layer,
464
+ drop=drop,
465
+ )
466
+
467
+ def forward(self, x: Tensor) -> Tensor:
468
+ x = x + self.drop_path(self.attn(self.norm1(x)))
469
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
470
+ return x
471
+
472
+
473
+ class MultiMAE(nn.Module):
474
+ """MultiMAE: Multi-task Multi-modal Masked Autoencoder
475
+ This module performs masking in its forward pass.
476
+ The MultiViT module defined below inherits from this module and performs a regular forward pass,
477
+ and should be used instead for downstream tasks
478
+
479
+
480
+ :param input_adapters: Dictionary of task -> input adapters
481
+ :param output_adapters: Optional dictionary of task -> output adapters
482
+
483
+ :param num_global_tokens: Number of additional global tokens to add (like cls tokens), default is 1
484
+ :param dim_tokens: Dimension of encoder tokens
485
+ :param depth: Depth of encoder
486
+ :param num_heads: Number of attention heads
487
+ :param mlp_ratio: MLP hidden dim ratio
488
+ :param qkv_bias: Set to False to disable bias
489
+ :param drop_rate: Dropout after MLPs and Attention
490
+ :param attn_drop_rate: Attention matrix drop rate
491
+ :param drop_path_rate: DropPath drop rate
492
+ :param norm_layer: Type of normalization layer
493
+ """
494
+
495
+ def __init__(
496
+ self,
497
+ input_adapters: Dict[str, PatchedInputAdapter],
498
+ output_adapters: Dict[str, ConvNeXtAdapter],
499
+ num_global_tokens: int = 1,
500
+ dim_tokens: int = 768,
501
+ depth: int = 12,
502
+ num_heads: int = 12,
503
+ mlp_ratio: float = 4.0,
504
+ qkv_bias: bool = True,
505
+ drop_rate: float = 0.0,
506
+ attn_drop_rate: float = 0.0,
507
+ drop_path_rate: float = 0.0,
508
+ norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
509
+ freeze_encoder: bool = False,
510
+ num_additional_gt_tokens: int = 0, # @deprecated
511
+ actual_num_additional_gt_tokens: int = 0, # @deprecated
512
+ learnable_additional_gt_tokens: bool = False,
513
+ additional_gt_tokens_mlp_channels: List[int] = [],
514
+ ground_truth_version: int = -1,
515
+ A: float = 0.5,
516
+ ):
517
+ super().__init__()
518
+ self.dim_tokens = dim_tokens
519
+ self.ground_truth_version = ground_truth_version
520
+ # Initialize input and output adapters
521
+ for adapter in input_adapters.values():
522
+ adapter.init(dim_tokens=dim_tokens)
523
+ self.input_adapters = nn.ModuleDict(input_adapters)
524
+ for adapter in output_adapters.values():
525
+ adapter.init(dim_tokens_enc=dim_tokens)
526
+ self.output_adapters = nn.ModuleDict(output_adapters)
527
+
528
+ # Additional learnable tokens that can be used by encoder to process/store global information
529
+ self.num_global_tokens = num_global_tokens
530
+ self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, dim_tokens))
531
+ trunc_normal_(self.global_tokens, std=0.02)
532
+
533
+ self.num_additional_gt_tokens = num_additional_gt_tokens # @deprecated
534
+ self.actual_num_additional_gt_tokens = (
535
+ actual_num_additional_gt_tokens # @deprecated
536
+ )
537
+ self.A = A
538
+ self.additional_gt_tokens_mlp_channels = additional_gt_tokens_mlp_channels
539
+ self.learnable_additional_gt_tokens = learnable_additional_gt_tokens
540
+ self.init_gt_tokens()
541
+
542
+ # Transformer encoder
543
+ dpr = [
544
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
545
+ ] # stochastic depth decay rule
546
+ self.encoder = nn.Sequential(
547
+ *[
548
+ Block(
549
+ dim=dim_tokens,
550
+ num_heads=num_heads,
551
+ mlp_ratio=mlp_ratio,
552
+ qkv_bias=qkv_bias,
553
+ drop=drop_rate,
554
+ attn_drop=attn_drop_rate,
555
+ drop_path=dpr[i],
556
+ norm_layer=norm_layer,
557
+ )
558
+ for i in range(depth)
559
+ ]
560
+ )
561
+
562
+ print(f"Encoder {count_parameters(self.encoder)}")
563
+
564
+ if freeze_encoder:
565
+ print("Freeze encoder")
566
+ for param in self.encoder.parameters():
567
+ param.requires_grad = False
568
+
569
+ self.apply(self._init_weights)
570
+ for name, m in self.named_modules():
571
+ if isinstance(m, nn.Linear):
572
+ if "qkv" in name:
573
+ # treat the weights of Q, K, V separately
574
+ val = math.sqrt(
575
+ 6.0 / float(m.weight.shape[0] // 3 + m.weight.shape[1])
576
+ )
577
+ nn.init.uniform_(m.weight, -val, val)
578
+ elif "kv" in name:
579
+ # treat the weights of K, V separately
580
+ val = math.sqrt(
581
+ 6.0 / float(m.weight.shape[0] // 2 + m.weight.shape[1])
582
+ )
583
+ nn.init.uniform_(m.weight, -val, val)
584
+
585
+ if isinstance(m, nn.Conv2d):
586
+ if ".proj" in name:
587
+ # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
588
+ w = m.weight.data
589
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
590
+
591
+ print(f"Total params: {count_parameters(self)}")
592
+
593
+ def init_gt_tokens(self):
594
+ """Just prepare beforehand to save time in training
595
+ In inference, there is no need"""
596
+ addtional_gt_tokens: List[Tensor] = []
597
+ if self.num_additional_gt_tokens == 0:
598
+ self.token_mlp = nn.Identity()
599
+ return
600
+ if len(self.additional_gt_tokens_mlp_channels) > 0:
601
+ self.token_mlp = MLP(
602
+ self.dim_tokens,
603
+ self.additional_gt_tokens_mlp_channels + [self.dim_tokens],
604
+ )
605
+ else:
606
+ self.token_mlp = nn.Identity()
607
+
608
+ if self.ground_truth_version != 6:
609
+ T = 1 / (self.num_additional_gt_tokens * 4)
610
+ for i in range(self.actual_num_additional_gt_tokens):
611
+ t = [
612
+ 2 * math.pi * (offset / self.dim_tokens - i * T)
613
+ for offset in range(self.dim_tokens)
614
+ ]
615
+ addtional_gt_tokens.append(
616
+ nn.Parameter(
617
+ self.A * torch.cos(Tensor(t).unsqueeze(0).unsqueeze(0)),
618
+ requires_grad=self.learnable_additional_gt_tokens,
619
+ )
620
+ )
621
+ self.addtional_gt_tokens = nn.ParameterList(addtional_gt_tokens)
622
+
623
+ def _init_weights(self, m: nn.Module) -> None:
624
+ if isinstance(m, nn.Linear):
625
+ nn.init.xavier_uniform_(m.weight)
626
+ if isinstance(m, nn.Linear) and m.bias is not None:
627
+ nn.init.constant_(m.bias, 0)
628
+ elif isinstance(m, nn.LayerNorm):
629
+ nn.init.constant_(m.bias, 0)
630
+ nn.init.constant_(m.weight, 1.0)
631
+
632
+ @torch.jit.ignore
633
+ def no_weight_decay(self):
634
+ no_wd_set = {"global_tokens"}
635
+
636
+ for task, adapter in self.input_adapters.items():
637
+ if hasattr(adapter, "no_weight_decay"):
638
+ to_skip = adapter.no_weight_decay()
639
+ to_skip = set([f"input_adapters.{task}.{name}" for name in to_skip])
640
+ no_wd_set = no_wd_set | to_skip
641
+
642
+ for task, adapter in self.output_adapters.items():
643
+ if hasattr(adapter, "no_weight_decay"):
644
+ to_skip = adapter.no_weight_decay()
645
+ to_skip = set([f"output_adapters.{task}.{name}" for name in to_skip])
646
+ no_wd_set = no_wd_set | to_skip
647
+
648
+ return no_wd_set
649
+
650
+ def generate_input_info(
651
+ self, input_task_tokens: Dict[str, Tensor], image_size: Tuple[int, int]
652
+ ) -> Dict[str, Tensor]:
653
+ input_info = OrderedDict()
654
+ i = 0
655
+ input_info["tasks"] = {}
656
+ for domain, tensor in input_task_tokens.items():
657
+ num_tokens: Union[int, Tensor] = tensor.shape[1]
658
+
659
+ if type(num_tokens) == Tensor:
660
+ num_tokens = num_tokens.item()
661
+
662
+ d = {
663
+ "num_tokens": num_tokens,
664
+ "has_2d_posemb": True,
665
+ "start_idx": i,
666
+ "end_idx": i + num_tokens,
667
+ }
668
+ i += num_tokens
669
+ input_info["tasks"][domain] = d
670
+
671
+ input_info["image_size"] = image_size
672
+ input_info["num_task_tokens"] = i
673
+ input_info["num_global_tokens"] = self.num_global_tokens
674
+
675
+ return input_info
676
+
677
+
678
+ class MultiViT(MultiMAE):
679
+ def extract_B_H_W(self, x: Dict[str, Tensor]) -> Tuple[int, int, int]:
680
+ # If input x is a Tensor, assume it's RGB
681
+ # x = {'rgb': x} if isinstance(x, Tensor) else x
682
+ # Need image size for tokens->image reconstruction
683
+ if "rgb" in x:
684
+ B, _, H, W = x["rgb"].shape
685
+ elif "sod" in x:
686
+ B, H, W = x["sod"].shape
687
+ H *= self.input_adapters["sod"].stride_level
688
+ W *= self.input_adapters["sod"].stride_level
689
+ else:
690
+ B, _, H, W = list(x.values())[0].shape
691
+ return B, H, W
692
+
693
+ def process_input(
694
+ self,
695
+ x: Dict[str, Tensor],
696
+ gt_index_lst: List[int],
697
+ num_gts_lst: List[int],
698
+ ) -> Tuple[Tensor, Dict[str, Tensor]]:
699
+ """
700
+ len(gt_i) must equal to x.shape[0] when self.num_additional_gt_tokens > 0
701
+ """
702
+ B, H, W = self.extract_B_H_W(x)
703
+
704
+ # Encode selected inputs to tokens
705
+ input_task_tokens: Dict[str, Tensor] = {
706
+ domain: self.input_adapters[domain](tensor)
707
+ for domain, tensor in x.items()
708
+ if domain in self.input_adapters
709
+ }
710
+
711
+ input_info = self.generate_input_info(
712
+ input_task_tokens=input_task_tokens, image_size=(H, W)
713
+ )
714
+ input_tokens = torch.cat(
715
+ [task_tokens for task_tokens in input_task_tokens.values()], dim=1
716
+ )
717
+
718
+ # Add global tokens to input tokens
719
+ global_tokens = repeat(self.global_tokens, "() n d -> b n d", b=B)
720
+
721
+ if self.ground_truth_version == 6:
722
+ # We need two inputs: gt_index, num_gts
723
+ assert len(gt_index_lst) == len(num_gts_lst)
724
+ additional_gt_tokens = []
725
+ for gt_index, num_gts in zip(gt_index_lst, num_gts_lst):
726
+ T = 1 / num_gts
727
+ i = gt_index
728
+ t = [
729
+ 2 * math.pi * (offset / self.dim_tokens - i * T)
730
+ for offset in range(self.dim_tokens)
731
+ ]
732
+ additional_gt_token = self.A * torch.cos(
733
+ Tensor(t).unsqueeze(0).unsqueeze(0)
734
+ )
735
+ additional_gt_tokens.append(additional_gt_token)
736
+ additional_gt_tokens = torch.cat(additional_gt_tokens, dim=0).to(
737
+ input_tokens.device
738
+ )
739
+ additional_gt_tokens = self.token_mlp(additional_gt_tokens)
740
+ input_tokens = torch.cat(
741
+ [input_tokens, global_tokens, additional_gt_tokens], dim=1
742
+ )
743
+ else:
744
+ if self.num_additional_gt_tokens > 0:
745
+
746
+ assert gt_index_lst is not None and len(gt_index_lst) == B
747
+ additional_gt_tokens: Tensor = torch.cat(
748
+ [self.addtional_gt_tokens[gt_i] for gt_i in gt_index_lst], dim=0
749
+ )
750
+ additional_gt_tokens = self.token_mlp(additional_gt_tokens)
751
+ input_tokens = torch.cat(
752
+ [input_tokens, global_tokens, additional_gt_tokens], dim=1
753
+ )
754
+ else:
755
+ input_tokens = torch.cat([input_tokens, global_tokens], dim=1)
756
+
757
+ return input_tokens, input_info
758
+
759
+ def forward(
760
+ self,
761
+ x: Dict[str, Tensor],
762
+ gt_index_lst: Optional[List[int]] = None,
763
+ max_gts_lst: Optional[List[int]] = None,
764
+ ) -> Dict[str, Tensor]:
765
+ """
766
+ Forward pass through input adapters, transformer encoder and output adapters.
767
+
768
+ :param x: Dictionary of tensors
769
+ :param outputs: List of outputs. For ex: outputs=['sod', 'depth']. Make sure 'sod' placed first!
770
+ """
771
+ input_tokens, input_info = self.process_input(x, gt_index_lst, max_gts_lst)
772
+
773
+ # Pass tokens through Transformer
774
+ encoder_tokens = self.encoder(input_tokens)
775
+
776
+ # Decode tokens for each task using task-specific output adapters
777
+ preds = {
778
+ domain: self.output_adapters[domain](
779
+ encoder_tokens=encoder_tokens,
780
+ input_info=input_info,
781
+ )
782
+ for domain in self.output_adapters
783
+ }
784
+
785
+ return preds
786
+
787
+
788
+ def interpolate_pos_embed_multimae(
789
+ model: MultiViT,
790
+ checkpoint_model: Dict[str, Tensor],
791
+ ) -> None:
792
+ pattern = "input_adapters\.(.*)\.pos_emb"
793
+ matched_keys = [k for k in checkpoint_model if bool(re.match(pattern, k))]
794
+
795
+ for key in matched_keys:
796
+ domain = re.match(pattern, key).group(1) # group(0) is entire matched regex
797
+ if getattr(model.input_adapters, domain, None) is not None:
798
+ pos_embed_checkpoint = checkpoint_model[key]
799
+ _, _, orig_H, orig_W = pos_embed_checkpoint.shape
800
+ _, _, new_H, new_W = getattr(model.input_adapters, domain).pos_emb.shape
801
+ if (orig_H != new_H) or (orig_W != new_W):
802
+ print(
803
+ f"Key {key}: Position interpolate from {orig_H}x{orig_W} to {new_H}x{new_W}"
804
+ )
805
+ pos_embed_checkpoint = torch.nn.functional.interpolate(
806
+ pos_embed_checkpoint,
807
+ size=(new_H, new_W),
808
+ mode="bicubic",
809
+ align_corners=False,
810
+ )
811
+ checkpoint_model[key] = pos_embed_checkpoint
812
+
813
+
814
+ def construct_adapters(cfg: base_cfg):
815
+ INPUT_ADAPTERS = {
816
+ "rgb": PatchedInputAdapter(
817
+ num_channels=3,
818
+ stride_level=1,
819
+ patch_size_full=cfg.input_patch_size,
820
+ image_size=cfg.image_size,
821
+ learnable_pos_emb=cfg.learnable_pos_emb,
822
+ ),
823
+ "depth": PatchedInputAdapter(
824
+ num_channels=1,
825
+ stride_level=1,
826
+ patch_size_full=cfg.input_patch_size,
827
+ image_size=cfg.image_size,
828
+ learnable_pos_emb=cfg.learnable_pos_emb,
829
+ ),
830
+ }
831
+
832
+ num_classes = cfg.num_classes
833
+ if cfg.ground_truth_version in [5, 6]:
834
+ num_classes = 1
835
+
836
+ OUTPUT_ADAPTERS = {
837
+ "sod": partial(
838
+ ConvNeXtAdapter,
839
+ num_classes=num_classes,
840
+ image_size=cfg.image_size,
841
+ embed_dim=cfg.embed_dim,
842
+ patch_size=cfg.input_patch_size,
843
+ preds_per_patch=cfg.output_patch_size,
844
+ depth=cfg.decoder_depth,
845
+ interpolate_mode=cfg.decoder_interpolate_mode,
846
+ main_tasks=cfg.decoder_main_tasks,
847
+ act_fn=cfg.act_fn,
848
+ dec_kernel=cfg.dec_kernel,
849
+ ),
850
+ "rgb": partial(
851
+ ConvNeXtAdapter,
852
+ num_classes=3,
853
+ image_size=cfg.image_size,
854
+ embed_dim=cfg.embed_dim,
855
+ patch_size=cfg.input_patch_size,
856
+ preds_per_patch=cfg.output_patch_size,
857
+ depth=cfg.decoder_depth,
858
+ interpolate_mode=cfg.decoder_interpolate_mode,
859
+ main_tasks=cfg.decoder_main_tasks,
860
+ act_fn=cfg.act_fn,
861
+ dec_kernel=cfg.dec_kernel,
862
+ ),
863
+ "depth": partial(
864
+ ConvNeXtAdapter,
865
+ num_classes=1,
866
+ image_size=cfg.image_size,
867
+ embed_dim=cfg.embed_dim,
868
+ patch_size=cfg.input_patch_size,
869
+ preds_per_patch=cfg.output_patch_size,
870
+ depth=cfg.decoder_depth,
871
+ interpolate_mode=cfg.decoder_interpolate_mode,
872
+ main_tasks=cfg.decoder_main_tasks,
873
+ act_fn=cfg.act_fn,
874
+ dec_kernel=cfg.dec_kernel,
875
+ ),
876
+ }
877
+
878
+ if cfg.ground_truth_version == 3:
879
+ for i in range(cfg.num_classes):
880
+ OUTPUT_ADAPTERS[f"sod{i}"] = partial(
881
+ ConvNeXtAdapter,
882
+ num_classes=1,
883
+ image_size=cfg.image_size,
884
+ embed_dim=cfg.embed_dim,
885
+ patch_size=cfg.input_patch_size,
886
+ preds_per_patch=cfg.output_patch_size,
887
+ depth=cfg.decoder_depth,
888
+ interpolate_mode=cfg.decoder_interpolate_mode,
889
+ main_tasks=cfg.decoder_main_tasks,
890
+ act_fn=cfg.act_fn,
891
+ dec_kernel=cfg.dec_kernel,
892
+ )
893
+ return INPUT_ADAPTERS, OUTPUT_ADAPTERS
894
+
895
+
896
+ def generate_smultimae_model(cfg: base_cfg) -> Tuple[MultiViT, List[Dict]]:
897
+ """MULTIMAE"""
898
+ assert len(cfg.decoder_main_tasks) == len(
899
+ cfg.outputs
900
+ ), "Length of decoder main tasks must match length of outputs"
901
+
902
+ INPUT_ADAPTERS, OUTPUT_ADAPTERS = construct_adapters(cfg)
903
+
904
+ input_adapters = dict()
905
+ for input_key in cfg.inputs:
906
+ input_adapters[input_key] = INPUT_ADAPTERS[input_key]
907
+
908
+ output_adapters = dict()
909
+ for output_key, decoder_main_tasks_per_output in zip(
910
+ cfg.outputs, cfg.decoder_main_tasks
911
+ ):
912
+ output_adapters[output_key] = OUTPUT_ADAPTERS[output_key](
913
+ main_tasks=decoder_main_tasks_per_output
914
+ )
915
+
916
+ num_additional_gt_tokens = 0 # @deprecated
917
+ actual_num_additional_gt_tokens = 0 # @deprecated
918
+ if cfg.ground_truth_version in [5, 6]: # @deprecated
919
+ num_additional_gt_tokens = cfg.num_classes # @deprecated
920
+ actual_num_additional_gt_tokens = cfg.actual_num_classes # @deprecated
921
+ model = MultiViT(
922
+ input_adapters=input_adapters,
923
+ output_adapters=output_adapters,
924
+ freeze_encoder=cfg.freeze_encoder,
925
+ drop_path_rate=0.1,
926
+ dim_tokens=cfg.dim_tokens,
927
+ depth=cfg.encoder_depth,
928
+ num_heads=cfg.num_heads,
929
+ mlp_ratio=4,
930
+ qkv_bias=True,
931
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
932
+ num_additional_gt_tokens=num_additional_gt_tokens, # @deprecated
933
+ actual_num_additional_gt_tokens=actual_num_additional_gt_tokens, # @deprecated
934
+ ground_truth_version=cfg.ground_truth_version,
935
+ )
936
+
937
+ # return load_pretrained_backbone(cfg, model)
938
+ return model, []
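
For `ground_truth_version == 6`, `MultiViT.process_input` conditions the encoder on which of the available ground truths should be reproduced by appending one extra token next to the global tokens: a cosine of amplitude `A` whose phase is shifted by `gt_index / num_gts`. The snippet below is a minimal re-derivation of that formula as a standalone function (the defaults `dim_tokens=768` and `A=0.5` follow the constructor defaults above); it is a sketch for illustration, not part of the repository.

```python
import math
import torch
from torch import Tensor


def gt_index_token(gt_index: int, num_gts: int, dim_tokens: int = 768, A: float = 0.5) -> Tensor:
    """Phase-shifted cosine token that tells the model which ground truth
    (out of num_gts) to produce; re-derived from MultiViT.process_input."""
    T = 1 / num_gts
    t = [2 * math.pi * (offset / dim_tokens - gt_index * T) for offset in range(dim_tokens)]
    return A * torch.cos(Tensor(t)).reshape(1, 1, dim_tokens)


# Different gt_index values give tokens with distinct phases; at inference time
# the model is queried once per index to obtain multiple saliency hypotheses.
tok0 = gt_index_token(0, num_gts=5)
tok1 = gt_index_token(1, num_gts=5)
print(tok0.shape, torch.allclose(tok0, tok1))  # torch.Size([1, 1, 768]) False
```
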
s_multimae/model_pl.py ADDED
@@ -0,0 +1,105 @@
+ from collections import defaultdict
+ import os
+ from typing import Any, Dict, List, Optional, Tuple
+ import cv2
+ import torch
+ from torch import Tensor, nn
+ import torch.nn.functional as F
+ import pytorch_lightning as pl
+ import numpy as np
+
+ from .configs.base_config import base_cfg
+ from .rgbd_model import RGBDModel
+
+
+ class ModelPL(pl.LightningModule):
+     def __init__(self, cfg: base_cfg):
+         super().__init__()
+         self.cfg = cfg
+         self.model = RGBDModel(cfg)
+
+     def forward(self, images: Tensor, depths: Tensor):
+         return self.model.forward(images, depths)
+
+     def __inference_v1(
+         self, outputs: Dict[str, Tensor], image_sizes: List[Tuple[int, int]]
+     ):
+         res_lst: List[List[np.ndarray]] = []
+         for output, image_size in zip(outputs["sod"], image_sizes):
+             output: Tensor = F.interpolate(
+                 output.unsqueeze(0),
+                 size=(image_size[1], image_size[0]),
+                 mode="bilinear",
+                 align_corners=False,
+             )
+             res: np.ndarray = output.sigmoid().data.cpu().numpy().squeeze()
+             res = (res - res.min()) / (res.max() - res.min() + 1e-8)
+             if self.cfg.is_fp16:
+                 res = np.float32(res)
+             res_lst.append([(res * 255).astype(np.uint8)])
+         return res_lst
+
+     def __inference_v2(
+         self, outputs: Dict[str, Tensor], image_sizes: List[Tuple[int, int]]
+     ):
+         res_lst: List[List[np.ndarray]] = []
+         for output, image_size in zip(outputs["sod"], image_sizes):
+             output: Tensor = F.interpolate(
+                 output.unsqueeze(0),
+                 size=(image_size[1], image_size[0]),
+                 mode="bilinear",
+                 align_corners=False,
+             )
+             res: np.ndarray = torch.argmax(output, dim=1).cpu().numpy().squeeze()
+             res_lst.append([res])
+         return res_lst
+
+     def __inference_v3v5(
+         self, outputs: Dict[str, Tensor], image_sizes: List[Tuple[int, int]]
+     ):
+         res_lst: List[List[np.ndarray]] = []
+         for bi, image_size in enumerate(image_sizes):
+             res_lst_per_sample: List[np.ndarray] = []
+             for i in range(self.cfg.num_classes):
+                 pred = outputs[f"sod{i}"][bi]
+                 pred: Tensor = F.interpolate(
+                     pred.unsqueeze(0),
+                     size=(image_size[1], image_size[0]),
+                     mode="bilinear",
+                     align_corners=False,
+                 )
+                 res: np.ndarray = pred.sigmoid().data.cpu().numpy().squeeze()
+                 res = (res - res.min()) / (res.max() - res.min() + 1e-8)
+                 if self.cfg.is_fp16:
+                     res = np.float32(res)
+                 res_lst_per_sample.append((res * 255).astype(np.uint8))
+             res_lst.append(res_lst_per_sample)
+         return res_lst
+
+     @torch.no_grad()
+     def inference(
+         self,
+         image_sizes: List[Tuple[int, int]],
+         images: Tensor,
+         depths: Optional[Tensor],
+         max_gts: Optional[List[int]],
+     ) -> List[List[np.ndarray]]:
+         self.model.eval()
+         assert len(image_sizes) == len(
+             images
+         ), "The number of image_sizes must equal the number of images"
+         gpu_images: Tensor = images.to(self.device)
+         gpu_depths: Tensor = depths.to(self.device)
+
+         if self.cfg.ground_truth_version == 6:
+             with torch.cuda.amp.autocast(enabled=self.cfg.is_fp16):
+                 outputs: Dict[str, Tensor] = dict()
+                 for i in range(self.cfg.num_classes):
+                     outputs[f"sod{i}"] = self.model.inference(
+                         gpu_images, gpu_depths, [i] * gpu_images.shape[0], max_gts
+                     )["sod"]
+                 return self.__inference_v3v5(outputs, image_sizes)
+         else:
+             raise Exception(
+                 f"Unsupported ground_truth_version {self.cfg.ground_truth_version}"
+             )
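
The inference helpers above share the same post-processing: interpolate the predicted map back to the original resolution, apply a sigmoid, min-max normalize, and quantize to `uint8`. A standalone sketch with a dummy logit map standing in for one entry of `outputs["sod"]` (note that `image_size` is `(width, height)`, hence the swapped order in `size=`):

```python
import numpy as np
import torch
import torch.nn.functional as F

logits = torch.randn(1, 224, 224)  # (C=1, H, W) raw prediction for one sample
orig_w, orig_h = 640, 480          # original image size as (width, height)

# Resize back to the input resolution, squash to [0, 1], normalize and quantize.
pred = F.interpolate(
    logits.unsqueeze(0), size=(orig_h, orig_w), mode="bilinear", align_corners=False
)
res = pred.sigmoid().cpu().numpy().squeeze()
res = (res - res.min()) / (res.max() - res.min() + 1e-8)
saliency_map = (res * 255).astype(np.uint8)
print(saliency_map.shape, saliency_map.dtype)  # (480, 640) uint8
```
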
s_multimae/rgbd_model.py ADDED
@@ -0,0 +1,60 @@
+ from typing import Dict, List, Optional
+ from torch import nn, Tensor
+
+ from .model.multimae import generate_smultimae_model as generate_smultimae_model_v1
+ from .configs.base_config import base_cfg
+
+
+ class RGBDModel(nn.Module):
+     def __init__(self, cfg: base_cfg):
+         super(RGBDModel, self).__init__()
+
+         self.inputs = cfg.inputs
+         self.outputs = cfg.outputs
+
+         self.is_no_depth = cfg.is_inference_with_no_depth
+
+         if cfg.model_version == 1:
+             self.model, self.opt_params = generate_smultimae_model_v1(cfg)
+         else:
+             raise Exception(f"Unsupported model version {cfg.model_version}")
+
+     def encode_decode(
+         self,
+         images: Tensor,
+         depths: Optional[Tensor],
+         gt_index_lst: Optional[List[int]] = None,
+         max_gts_lst: Optional[List[int]] = None,
+     ) -> Dict[str, Tensor]:
+         """Encode the inputs with the backbone and decode them into a saliency
+         map of the same size as the input.
+
+         Returns:
+             {
+                 "sod": Tensor,
+                 "depth": Optional[Tensor],
+                 "rgb": Optional[Tensor],
+             }
+         """
+         inputs = {"rgb": images}
+         if "depth" in self.inputs:
+             inputs["depth"] = depths
+         return self.model.forward(inputs, gt_index_lst, max_gts_lst)
+
+     def forward(
+         self,
+         images: Tensor,
+         depths: Optional[Tensor],
+         gt_index_lst: Optional[List[int]] = None,
+         max_gts_lst: Optional[List[int]] = None,
+     ) -> Dict[str, Tensor]:
+         return self.encode_decode(images, depths, gt_index_lst, max_gts_lst)
+
+     def inference(
+         self,
+         images: Tensor,
+         depths: Optional[Tensor],
+         gt_index_lst: Optional[List[int]] = None,
+         max_gts_lst: Optional[List[int]] = None,
+     ) -> Dict[str, Tensor]:
+         return self.encode_decode(images, depths, gt_index_lst, max_gts_lst)
s_multimae/utils.py ADDED
@@ -0,0 +1,236 @@
1
+ from PIL import Image
2
+ from glob import glob
3
+ import random
4
+ from typing import Dict, List
5
+ from torch import nn, Tensor
6
+ import os, shutil
7
+ import torch
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import gc, cv2
11
+
12
+ from .visualizer import post_processing_depth
13
+
14
+ """
15
+ This module should not depend on other s_multimae modules.
16
+ """
17
+
18
+ num_format = "{:,}".format
19
+
20
+
21
+ def list_dirs(dir_root: str) -> List[str]:
22
+ return list(
23
+ sorted(
24
+ [
25
+ item
26
+ for item in os.listdir(dir_root)
27
+ if os.path.isdir(f"{dir_root}/{item}")
28
+ ]
29
+ )
30
+ )
31
+
32
+
33
+ def clean_cache() -> None:
34
+ torch.cuda.empty_cache()
35
+ gc.collect()
36
+
37
+
38
+ def count_parameters(model: nn.Module) -> str:
39
+ """Count the number of learnable parameters of a model"""
40
+ return num_format(sum(p.numel() for p in model.parameters() if p.requires_grad))
41
+
42
+
43
+ def ranking_gts_to_dict(
44
+ ranking_gts: List[np.ndarray | str],
45
+ ) -> Dict[str, np.ndarray | str]:
46
+ """
47
+ Return:
48
+ dict(
49
+ gt0=ranking_gts[0],
50
+ gt1=ranking_gts[1],
51
+ gt2=ranking_gts[2],
52
+ gt3=ranking_gts[3],
53
+ gt4=ranking_gts[4],
54
+ )
55
+ """
56
+ return {f"gt{i}": v for i, v in enumerate(ranking_gts)}
57
+
58
+
59
+ def dict_to_ranking_gts(d: Dict[str, np.ndarray], l=5) -> List[np.ndarray]:
60
+ """
61
+ Return: [ranking_gts["gt0"], ranking_gts["gt1"], ...]
62
+ """
63
+ return [d[f"gt{i}"] for i in range(l)]
64
+
65
+
66
+ def random_choice(p: float) -> bool:
67
+ """Return True if random float <= p"""
68
+ return random.random() <= p
69
+
70
+
71
+ def fname_without_ext(p: str) -> str:
72
+ return os.path.splitext(os.path.basename(p))[0]
73
+
74
+
75
+ def list_files(
76
+ dirpaths: List[str] = [
77
+ "datasets/v1/train/RGB",
78
+ "datasets/v1/train/GT",
79
+ "datasets/v1/train/depths",
80
+ ],
81
+ ) -> List[List[str]]:
82
+ assert len(dirpaths) >= 1, "dirnames must contain at least 1 item"
83
+
84
+ fullpaths_lst: List[List[str]] = []
85
+ names_lst: List[List[str]] = []
86
+
87
+ for dirname in dirpaths:
88
+ fullpaths = list(sorted(glob(os.path.join(dirname, "*"))))
89
+ names = [fname_without_ext(fullpath) for fullpath in fullpaths]
90
+ fullpaths_lst.append(fullpaths)
91
+ names_lst.append(names)
92
+
93
+ rs: List[List[str]] = [fullpaths_lst[0]] + [[] for _ in range(len(dirpaths) - 1)]
94
+
95
+ # Ensure integrity
96
+ assert (
97
+ len(set([len(e) for e in names_lst])) == 1
98
+ ), f"Data is not integrity {[len(e) for e in names_lst]} | dirpath = {dirpaths}"
99
+
100
+ for name in names_lst[0]:
101
+ for i, names in enumerate(names_lst[1:]):
102
+ idx = names.index(name)
103
+ rs[i + 1].append(fullpaths_lst[i + 1][idx])
104
+
105
+ return rs
106
+
107
+
108
+ def scale_saliency_maps(inputs: Tensor) -> Tensor:
109
+ """Input: Tensor, shape of (B, C, H, W)"""
110
+ min_v = (
111
+ torch.min(torch.flatten(inputs, 1), dim=1)[0]
112
+ .unsqueeze(1)
113
+ .unsqueeze(1)
114
+ .unsqueeze(1)
115
+ )
116
+ max_v = (
117
+ torch.max(torch.flatten(inputs, 1), dim=1)[0]
118
+ .unsqueeze(1)
119
+ .unsqueeze(1)
120
+ .unsqueeze(1)
121
+ )
122
+ return (inputs - min_v) / (max_v - min_v + 1e-8)
123
+
124
+
125
+ def get_epoch_from_ckpt_path(ckpt_path: str) -> int:
126
+ """Example ckpt_path
127
+ os.path.join(experiment_dir_path, 'cfgv2.3', 'checkpoint_100.pt')
128
+ """
129
+ return int(ckpt_path.split("_")[-1].split(".")[0])
130
+
131
+
132
+ def clean_dir(dir_path: str) -> None:
133
+ """Remove a directory if existed and create an empty directory"""
134
+ if os.path.isdir(dir_path):
135
+ shutil.rmtree(dir_path)
136
+ os.makedirs(dir_path, exist_ok=True)
137
+
138
+
139
+ def get_sota_type(experiment_name: str) -> int:
140
+ """0 for SOTAs, 4 for experiment version 4, e.g. ..."""
141
+ if "cfgv" not in experiment_name:
142
+ return 0
143
+
144
+ half_right = experiment_name.split("cfgv")[1]
145
+ return int(half_right.split("_")[0])
146
+
147
+
148
+ def hex_to_rgb(hex: str) -> np.ndarray:
149
+ """Convert hex color to rgb color
150
+
151
+ Args:
152
+ hex (str): "#00f900"
153
+
154
+ Returns:
155
+ np.ndarray: numpy array of rgb color
156
+ """
157
+ hex = hex[1:]
158
+ rgb = []
159
+ for i in (0, 2, 4):
160
+ decimal = int(hex[i : i + 2], 16)
161
+ rgb.append(decimal)
162
+
163
+ return (np.array(rgb) / 255.0)[::-1]
164
+
165
+
166
+ def normalize(data: np.ndarray) -> np.ndarray:
167
+ return (data - data.min()) / (data.max() - data.min() + 1e-8)
168
+
169
+
170
+ def post_processing_depth(depth_path: str) -> np.ndarray:
171
+ depth = np.array(Image.open(depth_path).convert("L"))
172
+ depth = (normalize(depth) * 255).astype(np.uint8)
173
+ return cv2.applyColorMap(depth, cv2.COLORMAP_SUMMER)
174
+
175
+
176
+ def convert_batch_tensors_to_numpy_images(images: Tensor) -> np.ndarray:
177
+ """images of shape (batch_size, channels, width, height)"""
178
+ images = torch.permute(images, (0, 2, 3, 1))
179
+ images = images.numpy()
180
+ if images.shape[3] == 1:
181
+ return np.squeeze(images, axis=3)
182
+ else:
183
+ return images
184
+
185
+
186
+ def join_horizontally(lst: List[np.ndarray]) -> np.ndarray:
187
+ return np.concatenate(lst, axis=1)
188
+
189
+
190
+ def join_vertically(lst: List[np.ndarray]) -> np.ndarray:
191
+ return np.concatenate(lst, axis=0)
192
+
193
+
194
+ def plot_batch_of_pairs(
195
+ images: Tensor,
196
+ depths: Tensor,
197
+ gts: Tensor,
198
+ save_file_path: str,
199
+ ) -> None:
200
+ images = convert_batch_tensors_to_numpy_images(images)
201
+ depths = convert_batch_tensors_to_numpy_images(depths)
202
+ gts = convert_batch_tensors_to_numpy_images(gts)
203
+ batch_size = images.shape[0]
204
+ samples: List[np.ndarray] = []
205
+
206
+ # fig, axes = plt.subplots(batch_size, 3, figsize=(3*batch_size, 20)) # (number of images, 3)
207
+ for i in range(batch_size):
208
+ samples.append(
209
+ join_horizontally(
210
+ [
211
+ ((images[i] + 1.0) / 2 * 255).astype(np.uint8),
212
+ post_processing_depth(depths[i]),
213
+ post_processing_depth(gts[i]),
214
+ ]
215
+ )
216
+ )
217
+ # axes[i, 0].imshow(images[i])
218
+ # axes[i, 1].imshow(depths[i])
219
+ # axes[i, 2].imshow(gts[i])
220
+ # plt.show()
221
+
222
+ final = join_vertically(samples)
223
+ cv2.imwrite(save_file_path, cv2.cvtColor(final, cv2.COLOR_RGB2BGR))
224
+ print(f"Saved to file {save_file_path}")
225
+
226
+
227
+ def plot_pairs(image: np.ndarray, depth: np.ndarray, gt: np.ndarray) -> None:
228
+ batch_size = 1
229
+ fig, axes = plt.subplots(
230
+ batch_size, 3, figsize=(3 * batch_size, 20)
231
+ ) # (number of images, 3)
232
+ for i in range(batch_size):
233
+ axes[i, 0].imshow(image)
234
+ axes[i, 1].imshow(depth)
235
+ axes[i, 2].imshow(gt)
236
+ plt.show()
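
Two small helpers from `utils.py` worth illustrating (assuming the repository root is on `PYTHONPATH` and the requirements are installed): `hex_to_rgb` scales a `#rrggbb` string to `[0, 1]` and reverses the channel order, so despite its name the result is in BGR order, which matches OpenCV-style drawing; `ranking_gts_to_dict` simply keys a list of masks as `gt0..gtN`.

```python
import numpy as np
from s_multimae.utils import hex_to_rgb, ranking_gts_to_dict

# "#ff8000" is R=255, G=128, B=0; the helper returns ~[0.0, 0.502, 1.0],
# i.e. (B, G, R) scaled to [0, 1].
print(hex_to_rgb("#ff8000"))

# Round-trip a list of dummy masks through the {"gt0": ..., "gt4": ...} layout.
masks = [np.zeros((4, 4), dtype=np.uint8) for _ in range(5)]
print(sorted(ranking_gts_to_dict(masks).keys()))  # ['gt0', 'gt1', 'gt2', 'gt3', 'gt4']
```
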
s_multimae/visualize_2d_posemb.py ADDED
@@ -0,0 +1,58 @@
+ import numpy as np
+ from torch import Tensor
+ import matplotlib.pyplot as plt
+
+ from s_multimae.model.multimae import build_2d_sincos_posemb
+
+
+ def visualize_2d_posemb():
+     NH, NW = 14, 14
+     dim_tokens = 768
+
+     colors = [
+         "Greys",
+         "Purples",
+         "Blues",
+         "Greens",
+         "Oranges",
+         "Reds",
+         "YlOrBr",
+         "YlOrRd",
+         "OrRd",
+         "PuRd",
+         "RdPu",
+         "BuPu",
+         "GnBu",
+         "PuBu",
+         "YlGnBu",
+         "PuBuGn",
+         "BuGn",
+         "YlGn",
+     ]
+
+     pos_emb: Tensor = build_2d_sincos_posemb(NH, NW, dim_tokens)
+     pos_emb_numpy: np.ndarray = (
+         pos_emb.squeeze(0).permute(1, 2, 0).numpy()
+     )  # 14 x 14 x 768
+
+     x = np.linspace(0, NH - 1, NH)
+     y = np.linspace(0, NW - 1, NW)
+     X, Y = np.meshgrid(x, y)
+
+     for color, i in zip(colors, range(0, pos_emb_numpy.shape[2], 100)):
+         ax = plt.axes(projection="3d")
+         Z = pos_emb_numpy[:, :, i]
+
+         # plt.imshow(Z, cmap='viridis')
+         # plt.savefig(f'posemb_visualization/test_{i}.png')
+
+         ax.plot_surface(
+             X,
+             Y,
+             Z,
+             # rstride=1, cstride=1,
+             cmap="viridis",
+             edgecolor="none",
+         )
+         plt.show()
+         plt.savefig(f"posemb_visualization/test_{i}.png")
s_multimae/visualizer.py ADDED
@@ -0,0 +1,711 @@
1
+ import colorsys
2
+ from typing import Union
3
+ import numpy as np
4
+ import cv2
5
+ import matplotlib.colors as mplc
6
+ import pycocotools.mask as mask_util
7
+ import matplotlib.figure as mplfigure
8
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
9
+ import matplotlib as mpl
10
+ from enum import Enum, unique
11
+ from PIL import Image
12
+
13
+ _LARGE_MASK_AREA_THRESH = 120000
14
+ _COLORS = (
15
+ np.array(
16
+ [
17
+ 0.000,
18
+ 0.447,
19
+ 0.741,
20
+ 0.850,
21
+ 0.325,
22
+ 0.098,
23
+ 0.929,
24
+ 0.694,
25
+ 0.125,
26
+ 0.494,
27
+ 0.184,
28
+ 0.556,
29
+ 0.466,
30
+ 0.674,
31
+ 0.188,
32
+ 0.301,
33
+ 0.745,
34
+ 0.933,
35
+ 0.635,
36
+ 0.078,
37
+ 0.184,
38
+ 0.300,
39
+ 0.300,
40
+ 0.300,
41
+ 0.600,
42
+ 0.600,
43
+ 0.600,
44
+ 1.000,
45
+ 0.000,
46
+ 0.000,
47
+ 1.000,
48
+ 0.500,
49
+ 0.000,
50
+ 0.749,
51
+ 0.749,
52
+ 0.000,
53
+ 0.000,
54
+ 1.000,
55
+ 0.000,
56
+ 0.000,
57
+ 0.000,
58
+ 1.000,
59
+ 0.667,
60
+ 0.000,
61
+ 1.000,
62
+ 0.333,
63
+ 0.333,
64
+ 0.000,
65
+ 0.333,
66
+ 0.667,
67
+ 0.000,
68
+ 0.333,
69
+ 1.000,
70
+ 0.000,
71
+ 0.667,
72
+ 0.333,
73
+ 0.000,
74
+ 0.667,
75
+ 0.667,
76
+ 0.000,
77
+ 0.667,
78
+ 1.000,
79
+ 0.000,
80
+ 1.000,
81
+ 0.333,
82
+ 0.000,
83
+ 1.000,
84
+ 0.667,
85
+ 0.000,
86
+ 1.000,
87
+ 1.000,
88
+ 0.000,
89
+ 0.000,
90
+ 0.333,
91
+ 0.500,
92
+ 0.000,
93
+ 0.667,
94
+ 0.500,
95
+ 0.000,
96
+ 1.000,
97
+ 0.500,
98
+ 0.333,
99
+ 0.000,
100
+ 0.500,
101
+ 0.333,
102
+ 0.333,
103
+ 0.500,
104
+ 0.333,
105
+ 0.667,
106
+ 0.500,
107
+ 0.333,
108
+ 1.000,
109
+ 0.500,
110
+ 0.667,
111
+ 0.000,
112
+ 0.500,
113
+ 0.667,
114
+ 0.333,
115
+ 0.500,
116
+ 0.667,
117
+ 0.667,
118
+ 0.500,
119
+ 0.667,
120
+ 1.000,
121
+ 0.500,
122
+ 1.000,
123
+ 0.000,
124
+ 0.500,
125
+ 1.000,
126
+ 0.333,
127
+ 0.500,
128
+ 1.000,
129
+ 0.667,
130
+ 0.500,
131
+ 1.000,
132
+ 1.000,
133
+ 0.500,
134
+ 0.000,
135
+ 0.333,
136
+ 1.000,
137
+ 0.000,
138
+ 0.667,
139
+ 1.000,
140
+ 0.000,
141
+ 1.000,
142
+ 1.000,
143
+ 0.333,
144
+ 0.000,
145
+ 1.000,
146
+ 0.333,
147
+ 0.333,
148
+ 1.000,
149
+ 0.333,
150
+ 0.667,
151
+ 1.000,
152
+ 0.333,
153
+ 1.000,
154
+ 1.000,
155
+ 0.667,
156
+ 0.000,
157
+ 1.000,
158
+ 0.667,
159
+ 0.333,
160
+ 1.000,
161
+ 0.667,
162
+ 0.667,
163
+ 1.000,
164
+ 0.667,
165
+ 1.000,
166
+ 1.000,
167
+ 1.000,
168
+ 0.000,
169
+ 1.000,
170
+ 1.000,
171
+ 0.333,
172
+ 1.000,
173
+ 1.000,
174
+ 0.667,
175
+ 1.000,
176
+ 0.333,
177
+ 0.000,
178
+ 0.000,
179
+ 0.500,
180
+ 0.000,
181
+ 0.000,
182
+ 0.667,
183
+ 0.000,
184
+ 0.000,
185
+ 0.833,
186
+ 0.000,
187
+ 0.000,
188
+ 1.000,
189
+ 0.000,
190
+ 0.000,
191
+ 0.000,
192
+ 0.167,
193
+ 0.000,
194
+ 0.000,
195
+ 0.333,
196
+ 0.000,
197
+ 0.000,
198
+ 0.500,
199
+ 0.000,
200
+ 0.000,
201
+ 0.667,
202
+ 0.000,
203
+ 0.000,
204
+ 0.833,
205
+ 0.000,
206
+ 0.000,
207
+ 1.000,
208
+ 0.000,
209
+ 0.000,
210
+ 0.000,
211
+ 0.167,
212
+ 0.000,
213
+ 0.000,
214
+ 0.333,
215
+ 0.000,
216
+ 0.000,
217
+ 0.500,
218
+ 0.000,
219
+ 0.000,
220
+ 0.667,
221
+ 0.000,
222
+ 0.000,
223
+ 0.833,
224
+ 0.000,
225
+ 0.000,
226
+ 1.000,
227
+ 0.000,
228
+ 0.000,
229
+ 0.000,
230
+ 0.143,
231
+ 0.143,
232
+ 0.143,
233
+ 0.857,
234
+ 0.857,
235
+ 0.857,
236
+ 1.000,
237
+ 1.000,
238
+ 1.000,
239
+ ]
240
+ )
241
+ .astype(np.float32)
242
+ .reshape(-1, 3)
243
+ )
244
+
245
+
246
+ def random_color(rgb=False, maximum=255):
247
+ """
248
+ Args:
249
+ rgb (bool): whether to return RGB colors or BGR colors.
250
+ maximum (int): either 255 or 1
251
+
252
+ Returns:
253
+ ndarray: a vector of 3 numbers
254
+ """
255
+ idx = np.random.randint(0, len(_COLORS))
256
+ ret = _COLORS[idx] * maximum
257
+ if not rgb:
258
+ ret = ret[::-1]
259
+ return ret
260
+
261
+
262
+ @unique
263
+ class ColorMode(Enum):
264
+ """
265
+ Enum of different color modes to use for instance visualizations.
266
+ """
267
+
268
+ IMAGE = 0
269
+ """
270
+ Picks a random color for every instance and overlay segmentations with low opacity.
271
+ """
272
+ SEGMENTATION = 1
273
+ """
274
+ Let instances of the same category have similar colors
275
+ (from metadata.thing_colors), and overlay them with
276
+ high opacity. This provides more attention on the quality of segmentation.
277
+ """
278
+ IMAGE_BW = 2
279
+ """
280
+ Same as IMAGE, but convert all areas without masks to gray-scale.
281
+ Only available for drawing per-instance mask predictions.
282
+ """
283
+
284
+
285
+ class VisImage:
286
+ def __init__(self, img, scale=1.0):
287
+ """
288
+ Args:
289
+ img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
290
+ scale (float): scale the input image
291
+ """
292
+ self.img = img
293
+ self.scale = scale
294
+ self.width, self.height = img.shape[1], img.shape[0]
295
+ self._setup_figure(img)
296
+
297
+ def _setup_figure(self, img):
298
+ """
299
+ Args:
300
+ Same as in :meth:`__init__()`.
301
+
302
+ Returns:
303
+ fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
304
+ ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
305
+ """
306
+ fig = mplfigure.Figure(frameon=False)
307
+ self.dpi = fig.get_dpi()
308
+ # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
309
+ # (https://github.com/matplotlib/matplotlib/issues/15363)
310
+ fig.set_size_inches(
311
+ (self.width * self.scale + 1e-2) / self.dpi,
312
+ (self.height * self.scale + 1e-2) / self.dpi,
313
+ )
314
+ self.canvas = FigureCanvasAgg(fig)
315
+ # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
316
+ ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
317
+ ax.axis("off")
318
+ self.fig = fig
319
+ self.ax = ax
320
+ self.reset_image(img)
321
+
322
+ def reset_image(self, img):
323
+ """
324
+ Args:
325
+ img: same as in __init__
326
+ """
327
+ img = img.astype("uint8")
328
+ self.ax.imshow(
329
+ img, extent=(0, self.width, self.height, 0), interpolation="nearest"
330
+ )
331
+
332
+ def save(self, filepath):
333
+ """
334
+ Args:
335
+ filepath (str): a string that contains the absolute path, including the file name, where
336
+ the visualized image will be saved.
337
+ """
338
+ self.fig.savefig(filepath)
339
+
340
+ def get_image(self):
341
+ """
342
+ Returns:
343
+ ndarray:
344
+ the visualized image of shape (H, W, 3) (RGB) in uint8 type.
345
+ The shape is scaled w.r.t the input image using the given `scale` argument.
346
+ """
347
+ canvas = self.canvas
348
+ s, (width, height) = canvas.print_to_buffer()
349
+ # buf = io.BytesIO() # works for cairo backend
350
+ # canvas.print_rgba(buf)
351
+ # width, height = self.width, self.height
352
+ # s = buf.getvalue()
353
+
354
+ buffer = np.frombuffer(s, dtype="uint8")
355
+
356
+ img_rgba = buffer.reshape(height, width, 4)
357
+ rgb, alpha = np.split(img_rgba, [3], axis=2)
358
+ return rgb.astype("uint8")
359
+
360
+
361
+ class GenericMask:
362
+ """
363
+ Attribute:
364
+ polygons (list[ndarray]): list[ndarray]: polygons for this mask.
365
+ Each ndarray has format [x, y, x, y, ...]
366
+ mask (ndarray): a binary mask
367
+ """
368
+
369
+ def __init__(self, mask_or_polygons, height, width):
370
+ self._mask = self._polygons = self._has_holes = None
371
+ self.height = height
372
+ self.width = width
373
+
374
+ m = mask_or_polygons
375
+ if isinstance(m, dict):
376
+ # RLEs
377
+ assert "counts" in m and "size" in m
378
+ if isinstance(m["counts"], list): # uncompressed RLEs
379
+ h, w = m["size"]
380
+ assert h == height and w == width
381
+ m = mask_util.frPyObjects(m, h, w)
382
+ self._mask = mask_util.decode(m)[:, :]
383
+ return
384
+
385
+ if isinstance(m, list): # list[ndarray]
386
+ self._polygons = [np.asarray(x).reshape(-1) for x in m]
387
+ return
388
+
389
+ if isinstance(m, np.ndarray): # assumed to be a binary mask
390
+ assert m.shape[1] != 2, m.shape
391
+ assert m.shape == (
392
+ height,
393
+ width,
394
+ ), f"mask shape: {m.shape}, target dims: {height}, {width}"
395
+ self._mask = m.astype("uint8")
396
+ return
397
+
398
+ raise ValueError(
399
+ "GenericMask cannot handle object {} of type '{}'".format(m, type(m))
400
+ )
401
+
402
+ @property
403
+ def mask(self):
404
+ if self._mask is None:
405
+ self._mask = self.polygons_to_mask(self._polygons)
406
+ return self._mask
407
+
408
+ @property
409
+ def polygons(self):
410
+ if self._polygons is None:
411
+ self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
412
+ return self._polygons
413
+
414
+ @property
415
+ def has_holes(self):
416
+ if self._has_holes is None:
417
+ if self._mask is not None:
418
+ self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
419
+ else:
420
+ self._has_holes = (
421
+ False # if original format is polygon, does not have holes
422
+ )
423
+ return self._has_holes
424
+
425
+ def mask_to_polygons(self, mask):
426
+ # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
427
+ # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
428
+ # Internal contours (holes) are placed in hierarchy-2.
429
+ # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
430
+ mask = np.ascontiguousarray(
431
+ mask
432
+ ) # some versions of cv2 do not support non-contiguous arrays
433
+ res = cv2.findContours(
434
+ mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
435
+ )
436
+ hierarchy = res[-1]
437
+ if hierarchy is None: # empty mask
438
+ return [], False
439
+ has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
440
+ res = res[-2]
441
+ res = [x.flatten() for x in res]
442
+ # These coordinates from OpenCV are integers in range [0, W-1 or H-1].
443
+ # We add 0.5 to turn them into real-value coordinate space. A better solution
444
+ # would be to first +0.5 and then dilate the returned polygon by 0.5.
445
+ res = [x + 0.5 for x in res if len(x) >= 6]
446
+ return res, has_holes
447
+
448
+ def polygons_to_mask(self, polygons):
449
+ rle = mask_util.frPyObjects(polygons, self.height, self.width)
450
+ rle = mask_util.merge(rle)
451
+ return mask_util.decode(rle)[:, :]
452
+
453
+ def area(self):
454
+ return self.mask.sum()
455
+
456
+ def bbox(self):
457
+ p = mask_util.frPyObjects(self.polygons, self.height, self.width)
458
+ p = mask_util.merge(p)
459
+ bbox = mask_util.toBbox(p)
460
+ bbox[2] += bbox[0]
461
+ bbox[3] += bbox[1]
462
+ return bbox
463
+
464
+
465
+ class Visualizer:
466
+ """
467
+ Visualizer that draws data about detection/segmentation on images.
468
+
469
+ It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
470
+ that draw primitive objects to images, as well as high-level wrappers like
471
+ `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
472
+ that draw composite data in some pre-defined style.
473
+
474
+ Note that the exact visualization style for the high-level wrappers are subject to change.
475
+ Style such as color, opacity, label contents, visibility of labels, or even the visibility
476
+ of objects themselves (e.g. when the object is too small) may change according
477
+ to different heuristics, as long as the results still look visually reasonable.
478
+
479
+ To obtain a consistent style, you can implement custom drawing functions with the
480
+ abovementioned primitive methods instead. If you need more customized visualization
481
+ styles, you can process the data yourself following their format documented in
482
+ tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
483
+ intend to satisfy everyone's preference on drawing styles.
484
+
485
+ This visualizer focuses on high rendering quality rather than performance. It is not
486
+ designed to be used for real-time applications.
487
+ """
488
+
489
+ # TODO implement a fast, rasterized version using OpenCV
490
+
491
+ def __init__(
492
+ self,
493
+ img_rgb: Union[Image.Image, np.ndarray],
494
+ scale=1.0,
495
+ instance_mode=ColorMode.IMAGE,
496
+ ):
497
+ """
498
+ Args:
499
+ img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
500
+ the height and width of the image respectively. C is the number of
501
+ color channels. The image is required to be in RGB format since that
502
+ is a requirement of the Matplotlib library. The image is also expected
503
+ to be in the range [0, 255].
504
+ instance_mode (ColorMode): defines one of the pre-defined style for drawing
505
+ instances on an image.
506
+ """
507
+ if type(img_rgb) == np.ndarray:
508
+ img_rgb = img_rgb[:, :, ::-1]
509
+ else:
510
+ img_rgb = np.array(img_rgb)[:, :, ::-1]
511
+ self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
512
+ self.output = VisImage(self.img, scale=scale)
513
+
514
+ # too small texts are useless, therefore clamp the default font size to at least 10
515
+ self._default_font_size = max(
516
+ np.sqrt(self.output.height * self.output.width) // 90, 10 // scale
517
+ )
518
+ self._instance_mode = instance_mode
519
+
520
+ def draw_binary_mask(
521
+ self,
522
+ binary_mask,
523
+ color=None,
524
+ *,
525
+ edge_color=None,
526
+ text=None,
527
+ alpha=0.5,
528
+ area_threshold=10,
529
+ ):
530
+ """
531
+ Args:
532
+ binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
533
+ W is the image width. Each value in the array is either a 0 or 1 value of uint8
534
+ type.
535
+ color: color of the mask. Refer to `matplotlib.colors` for a full list of
536
+ formats that are accepted. If None, will pick a random color.
537
+ edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
538
+ full list of formats that are accepted.
539
+ text (str): if not None, the text will be drawn on the object
540
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
541
+ area_threshold (float): a connected component smaller than this area will not be shown.
542
+
543
+ Returns:
544
+ output (VisImage): image object with mask drawn.
545
+ """
546
+ if color is None:
547
+ color = random_color(rgb=True, maximum=1)
548
+ color = mplc.to_rgb(color)
549
+
550
+ has_valid_segment = False
551
+ binary_mask = binary_mask.astype("uint8") # opencv needs uint8
552
+ mask = GenericMask(binary_mask, self.output.height, self.output.width)
553
+ shape2d = (binary_mask.shape[0], binary_mask.shape[1])
554
+
555
+ if not mask.has_holes:
556
+ # draw polygons for regular masks
557
+ for segment in mask.polygons:
558
+ area = mask_util.area(
559
+ mask_util.frPyObjects([segment], shape2d[0], shape2d[1])
560
+ )
561
+ if area < (area_threshold or 0):
562
+ continue
563
+ has_valid_segment = True
564
+ segment = segment.reshape(-1, 2)
565
+ self.draw_polygon(
566
+ segment, color=color, edge_color=edge_color, alpha=alpha
567
+ )
568
+ else:
569
+ # TODO: Use Path/PathPatch to draw vector graphics:
570
+ # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
571
+ rgba = np.zeros(shape2d + (4,), dtype="float32")
572
+ rgba[:, :, :3] = color
573
+ rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
574
+ has_valid_segment = True
575
+ self.output.ax.imshow(
576
+ rgba, extent=(0, self.output.width, self.output.height, 0)
577
+ )
578
+
579
+ if text is not None and has_valid_segment:
580
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
581
+ self._draw_text_in_mask(binary_mask, text, lighter_color)
582
+ return self.output
583
+
584
+ def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
585
+ """
586
+ Args:
587
+ segment: numpy array of shape Nx2, containing all the points in the polygon.
588
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
589
+ formats that are accepted.
590
+ edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
591
+ full list of formats that are accepted. If not provided, a darker shade
592
+ of the polygon color will be used instead.
593
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
594
+
595
+ Returns:
596
+ output (VisImage): image object with polygon drawn.
597
+ """
598
+ if edge_color is None:
599
+ # make edge color darker than the polygon color
600
+ if alpha > 0.8:
601
+ edge_color = self._change_color_brightness(
602
+ color, brightness_factor=-0.7
603
+ )
604
+ else:
605
+ edge_color = color
606
+ edge_color = mplc.to_rgb(edge_color) + (1,)
607
+
608
+ polygon = mpl.patches.Polygon(
609
+ segment,
610
+ fill=True,
611
+ facecolor=mplc.to_rgb(color) + (alpha,),
612
+ edgecolor=edge_color,
613
+ linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
614
+ )
615
+ self.output.ax.add_patch(polygon)
616
+ return self.output
617
+
618
+ """
619
+ Internal methods:
620
+ """
621
+
622
+ def _change_color_brightness(self, color, brightness_factor):
623
+ """
624
+ Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
625
+ less or more saturation than the original color.
626
+
627
+ Args:
628
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
629
+ formats that are accepted.
630
+ brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
631
+ 0 will correspond to no change, a factor in [-1.0, 0) range will result in
632
+ a darker color and a factor in (0, 1.0] range will result in a lighter color.
633
+
634
+ Returns:
635
+ modified_color (tuple[double]): a tuple containing the RGB values of the
636
+ modified color. Each value in the tuple is in the [0.0, 1.0] range.
637
+ """
638
+ assert brightness_factor >= -1.0 and brightness_factor <= 1.0
639
+ color = mplc.to_rgb(color)
640
+ polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
641
+ modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
642
+ modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
643
+ modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
644
+ modified_color = colorsys.hls_to_rgb(
645
+ polygon_color[0], modified_lightness, polygon_color[2]
646
+ )
647
+ return modified_color
648
+
649
+ def _draw_text_in_mask(self, binary_mask, text, color):
650
+ """
651
+ Find proper places to draw text given a binary mask.
652
+ """
653
+ # TODO sometimes drawn on wrong objects. the heuristics here can improve.
654
+ _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(
655
+ binary_mask, 8
656
+ )
657
+ if stats[1:, -1].size == 0:
658
+ return
659
+ largest_component_id = np.argmax(stats[1:, -1]) + 1
660
+
661
+ # draw text on the largest component, as well as other very large components.
662
+ for cid in range(1, _num_cc):
663
+ if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
664
+ # median is more stable than centroid
665
+ # center = centroids[largest_component_id]
666
+ center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
667
+ self.draw_text(text, center, color=color)
668
+
669
+ def get_output(self):
670
+ """
671
+ Returns:
672
+ output (VisImage): the image output containing the visualizations added
673
+ to the image.
674
+ """
675
+ return self.output
676
+
677
+
678
+ def apply_threshold(pred: np.ndarray) -> np.ndarray:
679
+ """Apply threshold to a salient map
680
+
681
+ Args:
682
+ pred (np.ndarray): each pixel is in range [0, 255]
683
+
684
+ Returns:
685
+ np.ndarray: each pixel is only 0.0 or 1.0
686
+ """
687
+ binary_mask = pred / 255.0
688
+ binary_mask[binary_mask >= 0.5] = 1.0
689
+ binary_mask[binary_mask < 0.5] = 0.0
690
+ return binary_mask
691
+
692
+
693
+ def normalize(data: np.ndarray) -> np.ndarray:
694
+ return (data - data.min()) / (data.max() - data.min() + 1e-8)
695
+
696
+
697
+ def post_processing_depth(depth: np.ndarray) -> np.ndarray:
698
+ depth = (normalize(depth) * 255).astype(np.uint8)
699
+ return cv2.applyColorMap(depth, cv2.COLORMAP_OCEAN)
700
+
701
+
702
+ def apply_vis_to_image(
703
+ rgb: np.ndarray, binary_mask: np.ndarray, color: np.ndarray
704
+ ) -> np.ndarray:
705
+ if rgb.shape[:2] != binary_mask.shape[:2]:
706
+ print(rgb.shape, binary_mask.shape)
707
+ binary_mask = cv2.resize(binary_mask, [rgb.shape[1], rgb.shape[0]])
708
+ visualizer = Visualizer(rgb)
709
+ vis_image: VisImage = visualizer.draw_binary_mask(binary_mask, color)
710
+ vis_image = vis_image.get_image()[:, :, ::-1]
711
+ return vis_image
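A minimal usage sketch of the overlay utilities defined above, assuming the classes and functions in this file are in scope; the file names are hypothetical:

    import cv2
    import numpy as np

    rgb = cv2.imread("example_rgb.png")                          # hypothetical image read with OpenCV
    pred = cv2.imread("example_pred.png", cv2.IMREAD_GRAYSCALE)  # saliency map with values in [0, 255]

    binary_mask = apply_threshold(pred)                          # {0.0, 1.0} mask
    overlay = apply_vis_to_image(rgb, binary_mask, np.array([0.0, 1.0, 0.0]))  # green overlay
    cv2.imwrite("example_overlay.png", overlay)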
streamlit_apps/__init__.py ADDED
File without changes
streamlit_apps/app.py ADDED
@@ -0,0 +1,91 @@
1
+ import os, sys
2
+
3
+ sys.path.append(os.getcwd())
4
+
5
+ import multiprocessing
6
+
7
+ import streamlit as st
8
+
9
+ from app_utils.color_selection_ui import color_selection_ui
10
+ from app_utils.depth_selection_ui import depth_selection_ui
11
+ from app_utils.device import device
12
+ from app_utils.sod_selection_ui import sod_selection_ui
13
+
14
+
15
+ class MODE:
16
+ IMAGE = "image"
17
+ VIDEO = "video"
18
+ WEBRTC = "webrtc"
19
+ DEMO = "demo"
20
+
21
+
22
+ TITLE = "S-MultiMAE: A Multi-Ground Truth approach for RGB-D Saliency Detection"
23
+
24
+ st.set_page_config(
25
+ page_title=TITLE,
26
+ page_icon="🧊",
27
+ layout="wide",
28
+ # initial_sidebar_state="expanded",
29
+ # menu_items={
30
+ # 'Get Help': 'https://www.extremelycoolapp.com/help',
31
+ # 'Report a bug': "https://www.extremelycoolapp.com/bug",
32
+ # 'About': "# This is a header. This is an *extremely* cool app!"
33
+ # }
34
+ )
35
+ st.title(TITLE)
36
+
37
+ with st.expander("INTRODUCTION"):
38
+ st.text(
39
+ f"""Demo for S-MultiMAE.
40
+ Device: {device.type}
41
+ Number of CPU(s): {multiprocessing.cpu_count()}"""
42
+ )
43
+ st.image("docs/figures/proposed_method_v5.drawio.png", use_column_width="always")
44
+
45
+ with st.expander("SETTINGS", expanded=True):
46
+ col1, col2 = st.columns(2)
47
+
48
+ with col1:
49
+ mode = st.radio(
50
+ "Mode",
51
+ (
52
+ MODE.IMAGE,
53
+ # MODE.VIDEO,
54
+ # MODE.WEBRTC,
55
+ # MODE.DEMO,
56
+ ),
57
+ )
58
+ st.markdown("---")
59
+ color = color_selection_ui()
60
+
61
+ with col2:
62
+ depth_model = depth_selection_ui()
63
+ st.markdown("---")
64
+ sod_model, da = sod_selection_ui()
65
+
66
+ with st.expander("HOW TO USE", expanded=True):
67
+ st.text(
68
+ "(1) You can change the model type (using different backbones) in the settings."
69
+ )
70
+ st.text("(2) Upload an RGB image.")
71
+ st.text(
72
+ "(3) (Optional) Provide its corresponding depth. If not present, a pseudo-depth will be inferred by a rgb2depth model."
73
+ )
74
+ st.text(
75
+ "(4) You may try a different number of sets of salient objects the model can produce."
76
+ )
77
+ st.text("""(5) Click "Predict Salient Objects".""")
78
+
79
+ if mode == MODE.IMAGE:
80
+ from app_utils.image_inference import image_inference
81
+
82
+ image_inference(depth_model, sod_model, da, color)
83
+ # elif mode == MODE.VIDEO:
84
+ # from video_inference import video_inference
85
+ # video_inference(depth_model, sod_model, color)
86
+ # elif mode == MODE.WEBRTC:
87
+ # from webrtc_app import webrtc_app
88
+ # webrtc_app(depth_model, sod_model, color)
89
+ # elif mode == MODE.DEMO:
90
+ # from demo import demo
91
+ # demo()
streamlit_apps/app_utils/__init__.py ADDED
File without changes
streamlit_apps/app_utils/app_env.py ADDED
@@ -0,0 +1,16 @@
1
+ import os
2
+
3
+ app_env = os.environ.get("APP_ENVIRONMENT", "HUGGINGFACE")
4
+
5
+ IMAGE_SIZE = 224
6
+
7
+
8
+ class DEPTH_MODEL_TYPE:
9
+ DPT_DEPTH = "DPTDepth"
10
+ REL_DEPTH = "RelDepth"
11
+
12
+
13
+ class SOD_MODEL_TYPE:
14
+ S_MULTIMAE = "S-MultiMAE"
15
+ SPNET = "SPNet"
16
+ BBSNET = "BBSNet"
streamlit_apps/app_utils/app_utils.py ADDED
@@ -0,0 +1,83 @@
1
+ import os
2
+ import random
3
+ import time
4
+ from typing import Tuple, Union
5
+ import cv2
6
+ import numpy as np
7
+ import streamlit as st
8
+ from PIL import Image
9
+ from torch import nn
10
+
11
+ num_format = "{:,}".format
12
+
13
+
14
+ def count_parameters(model: nn.Module) -> str:
15
+ """Count the number of parameters of a model"""
16
+ return num_format(sum(p.numel() for p in model.parameters() if p.requires_grad))
17
+
18
+
19
+ class FrameRate:
20
+ def __init__(self) -> None:
21
+ self.c: int = 0
22
+ self.start_time: float = None
23
+ self.NO_FRAMES = 100
24
+ self.fps: float = -1
25
+
26
+ def reset(self) -> None:
27
+ self.start_time = time.time()
28
+ self.c = 0
29
+ self.fps = -1
30
+
31
+ def count(self) -> None:
32
+ self.c += 1
33
+ if self.c % self.NO_FRAMES == 0:
34
+ self.c = 0
35
+ end_time = time.time()
36
+ self.fps = self.NO_FRAMES / (end_time - self.start_time)
37
+ self.start_time = end_time
38
+
39
+ def show_fps(self, image: np.ndarray) -> np.ndarray:
40
+ if self.fps != -1:
41
+ return cv2.putText(
42
+ image,
43
+ f"FPS {self.fps:.0f}",
44
+ (50, 50),
45
+ cv2.FONT_HERSHEY_SIMPLEX,
46
+ fontScale=1,
47
+ color=(255, 0, 0),
48
+ thickness=2,
49
+ )
50
+ else:
51
+ return image
52
+
53
+
54
+ class ImgContainer:
55
+ img: np.ndarray = None # raw image
56
+ frame_rate: FrameRate = FrameRate()
57
+
58
+
59
+ def load_video(video_path: str) -> bytes:
60
+ if not os.path.isfile(video_path):
61
+ return
62
+ with st.spinner(f"Loading video {video_path} ..."):
63
+ video_bytes = open(video_path, "rb").read()
64
+ st.video(video_bytes, format="video/mp4")
65
+
66
+
67
+ def normalize(data: np.ndarray) -> np.ndarray:
68
+ return (data - data.min()) / (data.max() - data.min() + 1e-8)
69
+
70
+
71
+ def get_size(image: Union[Image.Image, np.ndarray]) -> Tuple[int, int]:
72
+ """Get resolution (w, h) of an image
73
+ An input image can be Pillow Image or CV2 Image
74
+ """
75
+ if type(image) == np.ndarray:
76
+ return (image.shape[1], image.shape[0])
77
+ else:
78
+ return image.size
79
+
80
+
81
+ def random_choice(p: float) -> bool:
82
+ """Return True if random float <= p"""
83
+ return random.random() <= p
streamlit_apps/app_utils/base_model.py ADDED
@@ -0,0 +1,54 @@
1
+ from typing import List
2
+ import numpy as np
3
+ from torch import Tensor, nn
4
+
5
+
6
+ class BaseRGBDModel(nn.Module):
7
+ def __init__(self):
8
+ super(BaseRGBDModel, self).__init__()
9
+ """
10
+ Requirements:
11
+ 1. Construct a model
12
+ 2. Load pretrained weights
13
+ 3. Load model into device
14
+ 4. Construct preprocessing
15
+ """
16
+
17
+ def inference(
18
+ self,
19
+ image: Tensor,
20
+ depth: Tensor,
21
+ origin_shape: np.array,
22
+ ) -> List[np.ndarray]:
23
+ """
24
+ Given:
25
+ - An image (Tensor) with original shape [c, h, w]
26
+ - A depth image (Tensor) with a shape of [c, h, w]; it does not need to be the same shape as the image
27
+
28
+ Requirements:
29
+ 1. Preprocessing
30
+ 2. Inference
31
+ 3. Return saliency maps np.float32 between 0.0 and 1.0,
32
+ with the same size as original size
33
+
34
+ """
35
+ raise NotImplementedError()
36
+
37
+ def batch_inference(
38
+ self,
39
+ images: Tensor,
40
+ depths: Tensor,
41
+ ) -> List[np.ndarray]:
42
+ """
43
+ Given:
44
+ - A batch of images (Tensor) with original shape [b, c, h, w]
45
+ - A batch of depths (Tensor) with a shape of [b, c, h, w]; they do not need to be the same shape as the images
46
+
47
+ Requirements:
48
+ 1. Preprocessing
49
+ 2. Inference
50
+ 3. Return saliency maps np.float32 between 0.0 and 1.0,
51
+ with the same size as original size
52
+
53
+ """
54
+ raise NotImplementedError()
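A skeleton of what a concrete subclass of BaseRGBDModel is expected to provide. The dummy implementation below only illustrates the expected return types; the real models in this repository wrap actual networks and preprocessing:

    import numpy as np
    from torch import Tensor

    class DummyRGBDModel(BaseRGBDModel):
        """Illustrative only: returns one all-zero saliency map per sample."""

        def inference(self, image: Tensor, depth: Tensor, origin_shape: np.ndarray):
            h, w = int(origin_shape[0]), int(origin_shape[1])  # assumed (h, w) ordering
            return [np.zeros((h, w), dtype=np.float32)]

        def batch_inference(self, images: Tensor, depths: Tensor):
            h, w = images.shape[-2:]
            return [np.zeros((h, w), dtype=np.float32) for _ in range(images.shape[0])]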
streamlit_apps/app_utils/color_selection_ui.py ADDED
@@ -0,0 +1,10 @@
1
+ import numpy as np
2
+ import streamlit as st
3
+
4
+ from s_multimae.utils import hex_to_rgb
5
+
6
+
7
+ def color_selection_ui() -> np.ndarray:
8
+ color = st.color_picker("Pick A Color", value="#00f900", key="color")
9
+ color = hex_to_rgb(color)
10
+ return color
streamlit_apps/app_utils/depth_model.py ADDED
@@ -0,0 +1,77 @@
1
+ import os
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import torchvision.transforms.functional as TF
5
+ from torch import Tensor, nn
6
+
7
+ from .app_utils import count_parameters
8
+ from .device import device
9
+ from .dpt.models import DPTDepthModel
10
+
11
+
12
+ class BaseDepthModel:
13
+ def __init__(self, image_size: int) -> None:
14
+ self.image_size = image_size
15
+ self.model: nn.Module = None
16
+
17
+ def forward(self, image: Tensor) -> Tensor:
18
+ """Perform forward inference for an image
19
+ Input image of shape [c, h, w]
20
+ Return of shape [c, h, w]
21
+ """
22
+ raise NotImplementedError()
23
+
24
+ def batch_forward(self, images: Tensor) -> Tensor:
25
+ """Perform forward inference for a batch of images
26
+ Input images of shape [b, c, h, w]
27
+ Return of shape [b, c, h, w]"""
28
+ raise NotImplementedError()
29
+
30
+ def get_number_of_parameters(self) -> int:
31
+ return count_parameters(self.model)
32
+
33
+
34
+ class DPTDepth(BaseDepthModel):
35
+ def __init__(self, image_size: int) -> None:
36
+ super().__init__(image_size)
37
+ print("DPTDepthconstructor")
38
+ weights_fname = "omnidata_rgb2depth_dpt_hybrid.pth"
39
+ weights_path = os.path.join("weights", weights_fname)
40
+ if not os.path.isfile(weights_path):
41
+ from huggingface_hub import hf_hub_download
42
+ weights_path = hf_hub_download(repo_id="RGBD-SOD/S-MultiMAE", filename=weights_fname)
43
+ os.system(f"mv {weights_fname} weights")
44
+ omnidata_ckpt = torch.load(
45
+ weights_path,
46
+ map_location="cpu",
47
+ )
48
+
49
+ self.model = DPTDepthModel()
50
+ self.model.load_state_dict(omnidata_ckpt)
51
+ self.model: DPTDepthModel = self.model.to(device).eval()
52
+
53
+ self.transform = transforms.Compose(
54
+ [
55
+ transforms.Resize(
56
+ (self.image_size, self.image_size),
57
+ interpolation=TF.InterpolationMode.BICUBIC,
58
+ ),
59
+ transforms.Normalize(
60
+ (0.5, 0.5, 0.5),
61
+ (0.5, 0.5, 0.5),
62
+ ),
63
+ ]
64
+ )
65
+
66
+ def forward(self, image: Tensor) -> Tensor:
67
+ depth_model_input = self.transform(image.unsqueeze(0))
68
+ return self.model.forward(depth_model_input.to(device)).squeeze(0)
69
+
70
+ def batch_forward(self, images: Tensor) -> Tensor:
71
+ images: Tensor = TF.resize(
72
+ images,
73
+ (self.image_size, self.image_size),
74
+ interpolation=TF.InterpolationMode.BICUBIC,
75
+ )
76
+ depth_model_input = (images - 0.5) / 0.5
77
+ return self.model(depth_model_input.to(device))
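A minimal sketch of using DPTDepth above to obtain a pseudo-depth map for a single RGB image; the image path is hypothetical and the weights are fetched from the Hub on first use:

    import torch
    import torchvision.transforms.functional as TF
    from PIL import Image

    depth_model = DPTDepth(image_size=224)
    rgb = TF.to_tensor(Image.open("example_rgb.png").convert("RGB"))  # [3, h, w], values in [0, 1]
    with torch.no_grad():
        depth = depth_model.forward(rgb)  # expected roughly [1, 224, 224] for image_size=224
    print(depth.shape)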
streamlit_apps/app_utils/depth_selection_ui.py ADDED
@@ -0,0 +1,27 @@
1
+ import streamlit as st
2
+
3
+ from .app_env import DEPTH_MODEL_TYPE, IMAGE_SIZE
4
+ from .depth_model import BaseDepthModel, DPTDepth
5
+
6
+
7
+ @st.cache_resource
8
+ def load_depth_model(depth_model_type: DEPTH_MODEL_TYPE) -> DPTDepth:
9
+ if depth_model_type == DEPTH_MODEL_TYPE.DPT_DEPTH:
10
+ return DPTDepth(IMAGE_SIZE)
11
+ else:
12
+ return DPTDepth(IMAGE_SIZE)
13
+
14
+
15
+ def depth_selection_ui() -> BaseDepthModel:
16
+ depth_model: BaseDepthModel = None
17
+ depth_model_type = st.selectbox(
18
+ "Choose depth model",
19
+ (
20
+ DEPTH_MODEL_TYPE.DPT_DEPTH,
21
+ # DEPTH_MODEL_TYPE.REL_DEPTH,
22
+ ),
23
+ key="depth_model_type",
24
+ )
25
+ depth_model = load_depth_model(depth_model_type)
26
+ st.text(f"Number of parameters {depth_model.get_number_of_parameters()}")
27
+ return depth_model
streamlit_apps/app_utils/device.py ADDED
@@ -0,0 +1,5 @@
1
+ import torch
2
+
3
+ cpu_device = torch.device("cpu")
4
+ # device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
5
+ device = cpu_device
streamlit_apps/app_utils/dpt/__init__.py ADDED
File without changes
streamlit_apps/app_utils/dpt/base_model.py ADDED
@@ -0,0 +1,16 @@
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device("cpu"))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
streamlit_apps/app_utils/dpt/blocks.py ADDED
@@ -0,0 +1,383 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ )
10
+
11
+
12
+ def _make_encoder(
13
+ backbone,
14
+ features,
15
+ use_pretrained,
16
+ groups=1,
17
+ expand=False,
18
+ exportable=True,
19
+ hooks=None,
20
+ use_vit_only=False,
21
+ use_readout="ignore",
22
+ enable_attention_hooks=False,
23
+ ):
24
+ if backbone == "vitl16_384":
25
+ pretrained = _make_pretrained_vitl16_384(
26
+ use_pretrained,
27
+ hooks=hooks,
28
+ use_readout=use_readout,
29
+ enable_attention_hooks=enable_attention_hooks,
30
+ )
31
+ scratch = _make_scratch(
32
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
33
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
34
+ elif backbone == "vitb_rn50_384":
35
+ pretrained = _make_pretrained_vitb_rn50_384(
36
+ use_pretrained,
37
+ hooks=hooks,
38
+ use_vit_only=use_vit_only,
39
+ use_readout=use_readout,
40
+ enable_attention_hooks=enable_attention_hooks,
41
+ )
42
+ scratch = _make_scratch(
43
+ [256, 512, 768, 768], features, groups=groups, expand=expand
44
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
45
+ elif backbone == "vitb16_384":
46
+ pretrained = _make_pretrained_vitb16_384(
47
+ use_pretrained,
48
+ hooks=hooks,
49
+ use_readout=use_readout,
50
+ enable_attention_hooks=enable_attention_hooks,
51
+ )
52
+ scratch = _make_scratch(
53
+ [96, 192, 384, 768], features, groups=groups, expand=expand
54
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
55
+ elif backbone == "resnext101_wsl":
56
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
57
+ scratch = _make_scratch(
58
+ [256, 512, 1024, 2048], features, groups=groups, expand=expand
59
+ ) # efficientnet_lite3
60
+ else:
61
+ print(f"Backbone '{backbone}' not implemented")
62
+ assert False
63
+
64
+ return pretrained, scratch
65
+
66
+
67
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
68
+ scratch = nn.Module()
69
+
70
+ out_shape1 = out_shape
71
+ out_shape2 = out_shape
72
+ out_shape3 = out_shape
73
+ out_shape4 = out_shape
74
+ if expand == True:
75
+ out_shape1 = out_shape
76
+ out_shape2 = out_shape * 2
77
+ out_shape3 = out_shape * 4
78
+ out_shape4 = out_shape * 8
79
+
80
+ scratch.layer1_rn = nn.Conv2d(
81
+ in_shape[0],
82
+ out_shape1,
83
+ kernel_size=3,
84
+ stride=1,
85
+ padding=1,
86
+ bias=False,
87
+ groups=groups,
88
+ )
89
+ scratch.layer2_rn = nn.Conv2d(
90
+ in_shape[1],
91
+ out_shape2,
92
+ kernel_size=3,
93
+ stride=1,
94
+ padding=1,
95
+ bias=False,
96
+ groups=groups,
97
+ )
98
+ scratch.layer3_rn = nn.Conv2d(
99
+ in_shape[2],
100
+ out_shape3,
101
+ kernel_size=3,
102
+ stride=1,
103
+ padding=1,
104
+ bias=False,
105
+ groups=groups,
106
+ )
107
+ scratch.layer4_rn = nn.Conv2d(
108
+ in_shape[3],
109
+ out_shape4,
110
+ kernel_size=3,
111
+ stride=1,
112
+ padding=1,
113
+ bias=False,
114
+ groups=groups,
115
+ )
116
+
117
+ return scratch
118
+
119
+
120
+ def _make_resnet_backbone(resnet):
121
+ pretrained = nn.Module()
122
+ pretrained.layer1 = nn.Sequential(
123
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
124
+ )
125
+
126
+ pretrained.layer2 = resnet.layer2
127
+ pretrained.layer3 = resnet.layer3
128
+ pretrained.layer4 = resnet.layer4
129
+
130
+ return pretrained
131
+
132
+
133
+ def _make_pretrained_resnext101_wsl(use_pretrained):
134
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
135
+ return _make_resnet_backbone(resnet)
136
+
137
+
138
+ class Interpolate(nn.Module):
139
+ """Interpolation module."""
140
+
141
+ def __init__(self, scale_factor, mode, align_corners=False):
142
+ """Init.
143
+
144
+ Args:
145
+ scale_factor (float): scaling
146
+ mode (str): interpolation mode
147
+ """
148
+ super(Interpolate, self).__init__()
149
+
150
+ self.interp = nn.functional.interpolate
151
+ self.scale_factor = scale_factor
152
+ self.mode = mode
153
+ self.align_corners = align_corners
154
+
155
+ def forward(self, x):
156
+ """Forward pass.
157
+
158
+ Args:
159
+ x (tensor): input
160
+
161
+ Returns:
162
+ tensor: interpolated data
163
+ """
164
+
165
+ x = self.interp(
166
+ x,
167
+ scale_factor=self.scale_factor,
168
+ mode=self.mode,
169
+ align_corners=self.align_corners,
170
+ )
171
+
172
+ return x
173
+
174
+
175
+ class ResidualConvUnit(nn.Module):
176
+ """Residual convolution module."""
177
+
178
+ def __init__(self, features):
179
+ """Init.
180
+
181
+ Args:
182
+ features (int): number of features
183
+ """
184
+ super().__init__()
185
+
186
+ self.conv1 = nn.Conv2d(
187
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
188
+ )
189
+
190
+ self.conv2 = nn.Conv2d(
191
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
192
+ )
193
+
194
+ self.relu = nn.ReLU(inplace=True)
195
+
196
+ def forward(self, x):
197
+ """Forward pass.
198
+
199
+ Args:
200
+ x (tensor): input
201
+
202
+ Returns:
203
+ tensor: output
204
+ """
205
+ out = self.relu(x)
206
+ out = self.conv1(out)
207
+ out = self.relu(out)
208
+ out = self.conv2(out)
209
+
210
+ return out + x
211
+
212
+
213
+ class FeatureFusionBlock(nn.Module):
214
+ """Feature fusion block."""
215
+
216
+ def __init__(self, features):
217
+ """Init.
218
+
219
+ Args:
220
+ features (int): number of features
221
+ """
222
+ super(FeatureFusionBlock, self).__init__()
223
+
224
+ self.resConfUnit1 = ResidualConvUnit(features)
225
+ self.resConfUnit2 = ResidualConvUnit(features)
226
+
227
+ def forward(self, *xs):
228
+ """Forward pass.
229
+
230
+ Returns:
231
+ tensor: output
232
+ """
233
+ output = xs[0]
234
+
235
+ if len(xs) == 2:
236
+ output += self.resConfUnit1(xs[1])
237
+
238
+ output = self.resConfUnit2(output)
239
+
240
+ output = nn.functional.interpolate(
241
+ output, scale_factor=2, mode="bilinear", align_corners=True
242
+ )
243
+
244
+ return output
245
+
246
+
247
+ class ResidualConvUnit_custom(nn.Module):
248
+ """Residual convolution module."""
249
+
250
+ def __init__(self, features, activation, bn):
251
+ """Init.
252
+
253
+ Args:
254
+ features (int): number of features
255
+ """
256
+ super().__init__()
257
+
258
+ self.bn = bn
259
+
260
+ self.groups = 1
261
+
262
+ self.conv1 = nn.Conv2d(
263
+ features,
264
+ features,
265
+ kernel_size=3,
266
+ stride=1,
267
+ padding=1,
268
+ bias=not self.bn,
269
+ groups=self.groups,
270
+ )
271
+
272
+ self.conv2 = nn.Conv2d(
273
+ features,
274
+ features,
275
+ kernel_size=3,
276
+ stride=1,
277
+ padding=1,
278
+ bias=not self.bn,
279
+ groups=self.groups,
280
+ )
281
+
282
+ if self.bn == True:
283
+ self.bn1 = nn.BatchNorm2d(features)
284
+ self.bn2 = nn.BatchNorm2d(features)
285
+
286
+ self.activation = activation
287
+
288
+ self.skip_add = nn.quantized.FloatFunctional()
289
+
290
+ def forward(self, x):
291
+ """Forward pass.
292
+
293
+ Args:
294
+ x (tensor): input
295
+
296
+ Returns:
297
+ tensor: output
298
+ """
299
+
300
+ out = self.activation(x)
301
+ out = self.conv1(out)
302
+ if self.bn == True:
303
+ out = self.bn1(out)
304
+
305
+ out = self.activation(out)
306
+ out = self.conv2(out)
307
+ if self.bn == True:
308
+ out = self.bn2(out)
309
+
310
+ if self.groups > 1:
311
+ out = self.conv_merge(out)
312
+
313
+ return self.skip_add.add(out, x)
314
+
315
+ # return out + x
316
+
317
+
318
+ class FeatureFusionBlock_custom(nn.Module):
319
+ """Feature fusion block."""
320
+
321
+ def __init__(
322
+ self,
323
+ features,
324
+ activation,
325
+ deconv=False,
326
+ bn=False,
327
+ expand=False,
328
+ align_corners=True,
329
+ ):
330
+ """Init.
331
+
332
+ Args:
333
+ features (int): number of features
334
+ """
335
+ super(FeatureFusionBlock_custom, self).__init__()
336
+
337
+ self.deconv = deconv
338
+ self.align_corners = align_corners
339
+
340
+ self.groups = 1
341
+
342
+ self.expand = expand
343
+ out_features = features
344
+ if self.expand == True:
345
+ out_features = features // 2
346
+
347
+ self.out_conv = nn.Conv2d(
348
+ features,
349
+ out_features,
350
+ kernel_size=1,
351
+ stride=1,
352
+ padding=0,
353
+ bias=True,
354
+ groups=1,
355
+ )
356
+
357
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
358
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
359
+
360
+ self.skip_add = nn.quantized.FloatFunctional()
361
+
362
+ def forward(self, *xs):
363
+ """Forward pass.
364
+
365
+ Returns:
366
+ tensor: output
367
+ """
368
+ output = xs[0]
369
+
370
+ if len(xs) == 2:
371
+ res = self.resConfUnit1(xs[1])
372
+ output = self.skip_add.add(output, res)
373
+ # output += res
374
+
375
+ output = self.resConfUnit2(output)
376
+
377
+ output = nn.functional.interpolate(
378
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
379
+ )
380
+
381
+ output = self.out_conv(output)
382
+
383
+ return output
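A small sanity-check sketch of the custom fusion block above: it adds an optional lateral branch, refines the sum with two residual units, upsamples by a factor of two, and projects with a 1x1 convolution. Shapes are illustrative:

    import torch
    from torch import nn

    block = FeatureFusionBlock_custom(features=256, activation=nn.ReLU(False), bn=False)
    coarse = torch.randn(1, 256, 14, 14)   # deeper feature map
    lateral = torch.randn(1, 256, 14, 14)  # skip feature map at the same resolution
    with torch.no_grad():
        fused = block(coarse, lateral)
    print(fused.shape)  # torch.Size([1, 256, 28, 28]) - spatial size doubled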
streamlit_apps/app_utils/dpt/midas_net.py ADDED
@@ -0,0 +1,78 @@
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .base_model import BaseModel
10
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
11
+
12
+
13
+ class MidasNet_large(BaseModel):
14
+ """Network for monocular depth estimation."""
15
+
16
+ def __init__(self, path=None, features=256, non_negative=True):
17
+ """Init.
18
+
19
+ Args:
20
+ path (str, optional): Path to saved model. Defaults to None.
21
+ features (int, optional): Number of features. Defaults to 256.
22
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
23
+ """
24
+ print("Loading weights: ", path)
25
+
26
+ super(MidasNet_large, self).__init__()
27
+
28
+ use_pretrained = False if path is None else True
29
+
30
+ self.pretrained, self.scratch = _make_encoder(
31
+ backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained
32
+ )
33
+
34
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
35
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
36
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
37
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
38
+
39
+ self.scratch.output_conv = nn.Sequential(
40
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
41
+ Interpolate(scale_factor=2, mode="bilinear"),
42
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
43
+ nn.ReLU(True),
44
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
45
+ nn.ReLU(True) if non_negative else nn.Identity(),
46
+ )
47
+
48
+ if path:
49
+ self.load(path)
50
+
51
+ def forward(self, x):
52
+ """Forward pass.
53
+
54
+ Args:
55
+ x (tensor): input data (image)
56
+
57
+ Returns:
58
+ tensor: depth
59
+ """
60
+
61
+ layer_1 = self.pretrained.layer1(x)
62
+ layer_2 = self.pretrained.layer2(layer_1)
63
+ layer_3 = self.pretrained.layer3(layer_2)
64
+ layer_4 = self.pretrained.layer4(layer_3)
65
+
66
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
67
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
68
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
69
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
70
+
71
+ path_4 = self.scratch.refinenet4(layer_4_rn)
72
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
73
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
74
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
75
+
76
+ out = self.scratch.output_conv(path_1)
77
+
78
+ return torch.squeeze(out, dim=1)
streamlit_apps/app_utils/dpt/models.py ADDED
@@ -0,0 +1,124 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import Tensor
4
+
5
+ from .base_model import BaseModel
6
+ from .blocks import (
7
+ FeatureFusionBlock_custom,
8
+ Interpolate,
9
+ _make_encoder,
10
+ forward_vit,
11
+ )
12
+
13
+
14
+ def _make_fusion_block(features, use_bn):
15
+ return FeatureFusionBlock_custom(
16
+ features,
17
+ nn.ReLU(False),
18
+ deconv=False,
19
+ bn=use_bn,
20
+ expand=False,
21
+ align_corners=True,
22
+ )
23
+
24
+
25
+ class DPT(BaseModel):
26
+ def __init__(
27
+ self,
28
+ head,
29
+ features=256,
30
+ backbone="vitb_rn50_384",
31
+ readout="project",
32
+ channels_last=False,
33
+ use_bn=False,
34
+ enable_attention_hooks=False,
35
+ ):
36
+ super(DPT, self).__init__()
37
+
38
+ self.channels_last = channels_last
39
+
40
+ hooks = {
41
+ "vitb_rn50_384": [0, 1, 8, 11],
42
+ "vitb16_384": [2, 5, 8, 11],
43
+ "vitl16_384": [5, 11, 17, 23],
44
+ }
45
+
46
+ # Instantiate backbone and reassemble blocks
47
+ self.pretrained, self.scratch = _make_encoder(
48
+ backbone,
49
+ features,
50
+ False, # Set to True if you want to train from scratch (uses ImageNet weights)
51
+ groups=1,
52
+ expand=False,
53
+ exportable=False,
54
+ hooks=hooks[backbone],
55
+ use_readout=readout,
56
+ enable_attention_hooks=enable_attention_hooks,
57
+ )
58
+
59
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63
+
64
+ self.scratch.output_conv = head
65
+
66
+ def forward(self, x: Tensor) -> Tensor:
67
+ if self.channels_last == True:
68
+ x.contiguous(memory_format=torch.channels_last)
69
+
70
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
71
+
72
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
73
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
74
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
75
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
76
+
77
+ path_4 = self.scratch.refinenet4(layer_4_rn)
78
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
79
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
80
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
81
+
82
+ out = self.scratch.output_conv(path_1)
83
+
84
+ return out
85
+
86
+
87
+ class DPTDepthModel(DPT):
88
+ def __init__(
89
+ self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs
90
+ ):
91
+ features = kwargs["features"] if "features" in kwargs else 256
92
+
93
+ self.scale = scale
94
+ self.shift = shift
95
+ self.invert = invert
96
+
97
+ head = nn.Sequential(
98
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
99
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
100
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
101
+ nn.ReLU(True),
102
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
103
+ nn.ReLU(True) if non_negative else nn.Identity(),
104
+ nn.Identity(),
105
+ )
106
+
107
+ super().__init__(head, **kwargs)
108
+
109
+ if path is not None:
110
+ self.load(path)
111
+
112
+ def forward(self, x: Tensor) -> Tensor:
113
+ """Input x of shape [b, c, h, w]
114
+ Return tensor of shape [b, c, h, w]
115
+ """
116
+ inv_depth = super().forward(x)
117
+
118
+ if self.invert:
119
+ depth = self.scale * inv_depth + self.shift
120
+ depth[depth < 1e-8] = 1e-8
121
+ depth = 1.0 / depth
122
+ return depth
123
+ else:
124
+ return inv_depth
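A minimal sketch of running the depth model above on a random, already-normalized input without pretrained weights (pass path= to load a checkpoint, as DPTDepth does). It assumes timm is available for the ViT backbone; shapes are illustrative:

    import torch

    model = DPTDepthModel(backbone="vitb_rn50_384", non_negative=True).eval()
    x = torch.randn(1, 3, 224, 224)  # batch of normalized RGB images
    with torch.no_grad():
        depth = model(x)
    print(depth.shape)  # one depth channel per image, upsampled by the output head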
streamlit_apps/app_utils/dpt/transforms.py ADDED
@@ -0,0 +1,231 @@
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height)."""
50
+
51
+ def __init__(
52
+ self,
53
+ width,
54
+ height,
55
+ resize_target=True,
56
+ keep_aspect_ratio=False,
57
+ ensure_multiple_of=1,
58
+ resize_method="lower_bound",
59
+ image_interpolation_method=cv2.INTER_AREA,
60
+ ):
61
+ """Init.
62
+
63
+ Args:
64
+ width (int): desired output width
65
+ height (int): desired output height
66
+ resize_target (bool, optional):
67
+ True: Resize the full sample (image, mask, target).
68
+ False: Resize image only.
69
+ Defaults to True.
70
+ keep_aspect_ratio (bool, optional):
71
+ True: Keep the aspect ratio of the input sample.
72
+ Output sample might not have the given width and height, and
73
+ resize behaviour depends on the parameter 'resize_method'.
74
+ Defaults to False.
75
+ ensure_multiple_of (int, optional):
76
+ Output width and height is constrained to be multiple of this parameter.
77
+ Defaults to 1.
78
+ resize_method (str, optional):
79
+ "lower_bound": Output will be at least as large as the given size.
80
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
81
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
82
+ Defaults to "lower_bound".
83
+ """
84
+ self.__width = width
85
+ self.__height = height
86
+
87
+ self.__resize_target = resize_target
88
+ self.__keep_aspect_ratio = keep_aspect_ratio
89
+ self.__multiple_of = ensure_multiple_of
90
+ self.__resize_method = resize_method
91
+ self.__image_interpolation_method = image_interpolation_method
92
+
93
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
94
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
95
+
96
+ if max_val is not None and y > max_val:
97
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
98
+
99
+ if y < min_val:
100
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
101
+
102
+ return y
103
+
104
+ def get_size(self, width, height):
105
+ # determine new height and width
106
+ scale_height = self.__height / height
107
+ scale_width = self.__width / width
108
+
109
+ if self.__keep_aspect_ratio:
110
+ if self.__resize_method == "lower_bound":
111
+ # scale such that output size is lower bound
112
+ if scale_width > scale_height:
113
+ # fit width
114
+ scale_height = scale_width
115
+ else:
116
+ # fit height
117
+ scale_width = scale_height
118
+ elif self.__resize_method == "upper_bound":
119
+ # scale such that output size is upper bound
120
+ if scale_width < scale_height:
121
+ # fit width
122
+ scale_height = scale_width
123
+ else:
124
+ # fit height
125
+ scale_width = scale_height
126
+ elif self.__resize_method == "minimal":
127
+ # scale as little as possible
128
+ if abs(1 - scale_width) < abs(1 - scale_height):
129
+ # fit width
130
+ scale_height = scale_width
131
+ else:
132
+ # fit height
133
+ scale_width = scale_height
134
+ else:
135
+ raise ValueError(
136
+ f"resize_method {self.__resize_method} not implemented"
137
+ )
138
+
139
+ if self.__resize_method == "lower_bound":
140
+ new_height = self.constrain_to_multiple_of(
141
+ scale_height * height, min_val=self.__height
142
+ )
143
+ new_width = self.constrain_to_multiple_of(
144
+ scale_width * width, min_val=self.__width
145
+ )
146
+ elif self.__resize_method == "upper_bound":
147
+ new_height = self.constrain_to_multiple_of(
148
+ scale_height * height, max_val=self.__height
149
+ )
150
+ new_width = self.constrain_to_multiple_of(
151
+ scale_width * width, max_val=self.__width
152
+ )
153
+ elif self.__resize_method == "minimal":
154
+ new_height = self.constrain_to_multiple_of(scale_height * height)
155
+ new_width = self.constrain_to_multiple_of(scale_width * width)
156
+ else:
157
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
158
+
159
+ return (new_width, new_height)
160
+
161
+ def __call__(self, sample):
162
+ width, height = self.get_size(
163
+ sample["image"].shape[1], sample["image"].shape[0]
164
+ )
165
+
166
+ # resize sample
167
+ sample["image"] = cv2.resize(
168
+ sample["image"],
169
+ (width, height),
170
+ interpolation=self.__image_interpolation_method,
171
+ )
172
+
173
+ if self.__resize_target:
174
+ if "disparity" in sample:
175
+ sample["disparity"] = cv2.resize(
176
+ sample["disparity"],
177
+ (width, height),
178
+ interpolation=cv2.INTER_NEAREST,
179
+ )
180
+
181
+ if "depth" in sample:
182
+ sample["depth"] = cv2.resize(
183
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
184
+ )
185
+
186
+ sample["mask"] = cv2.resize(
187
+ sample["mask"].astype(np.float32),
188
+ (width, height),
189
+ interpolation=cv2.INTER_NEAREST,
190
+ )
191
+ sample["mask"] = sample["mask"].astype(bool)
192
+
193
+ return sample
194
+
195
+
196
+ class NormalizeImage(object):
197
+ """Normlize image by given mean and std."""
198
+
199
+ def __init__(self, mean, std):
200
+ self.__mean = mean
201
+ self.__std = std
202
+
203
+ def __call__(self, sample):
204
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
205
+
206
+ return sample
207
+
208
+
209
+ class PrepareForNet(object):
210
+ """Prepare sample for usage as network input."""
211
+
212
+ def __init__(self):
213
+ pass
214
+
215
+ def __call__(self, sample):
216
+ image = np.transpose(sample["image"], (2, 0, 1))
217
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
218
+
219
+ if "mask" in sample:
220
+ sample["mask"] = sample["mask"].astype(np.float32)
221
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
222
+
223
+ if "disparity" in sample:
224
+ disparity = sample["disparity"].astype(np.float32)
225
+ sample["disparity"] = np.ascontiguousarray(disparity)
226
+
227
+ if "depth" in sample:
228
+ depth = sample["depth"].astype(np.float32)
229
+ sample["depth"] = np.ascontiguousarray(depth)
230
+
231
+ return sample
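A minimal sketch of composing the transforms above into a typical DPT/MiDaS-style preprocessing pipeline; the input file name is hypothetical:

    import cv2
    from torchvision.transforms import Compose

    transform = Compose(
        [
            Resize(
                384,
                384,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
        ]
    )

    img = cv2.cvtColor(cv2.imread("example_rgb.png"), cv2.COLOR_BGR2RGB) / 255.0
    sample = transform({"image": img})
    print(sample["image"].shape)  # (3, H, W) float32, ready to be batched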
streamlit_apps/app_utils/dpt/vit.py ADDED
@@ -0,0 +1,576 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
+ activations = {}
10
+
11
+
12
+ def get_activation(name):
13
+ def hook(model, input, output):
14
+ activations[name] = output
15
+
16
+ return hook
17
+
18
+
19
+ attention = {}
20
+
21
+
22
+ def get_attention(name):
23
+ def hook(module, input, output):
24
+ x = input[0]
25
+ B, N, C = x.shape
26
+ qkv = (
27
+ module.qkv(x)
28
+ .reshape(B, N, 3, module.num_heads, C // module.num_heads)
29
+ .permute(2, 0, 3, 1, 4)
30
+ )
31
+ q, k, v = (
32
+ qkv[0],
33
+ qkv[1],
34
+ qkv[2],
35
+ ) # make torchscript happy (cannot use tensor as tuple)
36
+
37
+ attn = (q @ k.transpose(-2, -1)) * module.scale
38
+
39
+ attn = attn.softmax(dim=-1) # [:,:,1,1:]
40
+ attention[name] = attn
41
+
42
+ return hook
43
+
44
+
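+ # Takes the attention paid by `token` to the patch tokens, reshapes it to the patch grid,
+ # upsamples it to the input resolution, and averages it over the attention heads.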
45
+ def get_mean_attention_map(attn, token, shape):
46
+ attn = attn[:, :, token, 1:]
47
+ attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()
48
+ attn = torch.nn.functional.interpolate(
49
+ attn, size=shape[2:], mode="bicubic", align_corners=False
50
+ ).squeeze(0)
51
+
52
+ all_attn = torch.mean(attn, 0)
53
+
54
+ return all_attn
55
+
56
+
57
+ class Slice(nn.Module):
58
+ def __init__(self, start_index=1):
59
+ super(Slice, self).__init__()
60
+ self.start_index = start_index
61
+
62
+ def forward(self, x):
63
+ return x[:, self.start_index :]
64
+
65
+
66
+ class AddReadout(nn.Module):
67
+ def __init__(self, start_index=1):
68
+ super(AddReadout, self).__init__()
69
+ self.start_index = start_index
70
+
71
+ def forward(self, x):
72
+ if self.start_index == 2:
73
+ readout = (x[:, 0] + x[:, 1]) / 2
74
+ else:
75
+ readout = x[:, 0]
76
+ return x[:, self.start_index :] + readout.unsqueeze(1)
77
+
78
+
79
+ class ProjectReadout(nn.Module):
80
+ def __init__(self, in_features, start_index=1):
81
+ super(ProjectReadout, self).__init__()
82
+ self.start_index = start_index
83
+
84
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
85
+
86
+ def forward(self, x):
87
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
88
+ features = torch.cat((x[:, self.start_index :], readout), -1)
89
+
90
+ return self.project(features)
91
+
92
+
93
+ class Transpose(nn.Module):
94
+ def __init__(self, dim0, dim1):
95
+ super(Transpose, self).__init__()
96
+ self.dim0 = dim0
97
+ self.dim1 = dim1
98
+
99
+ def forward(self, x):
100
+ x = x.transpose(self.dim0, self.dim1)
101
+ return x
102
+
103
+
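+ # Runs the patched forward pass and converts the four hooked block activations into
+ # spatial feature maps via the act_postprocess heads attached by the backbone factories below.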
104
+ def forward_vit(pretrained, x):
105
+ b, c, h, w = x.shape
106
+
107
+ glob = pretrained.model.forward_flex(x)
108
+
109
+ layer_1 = pretrained.activations["1"]
110
+ layer_2 = pretrained.activations["2"]
111
+ layer_3 = pretrained.activations["3"]
112
+ layer_4 = pretrained.activations["4"]
113
+
114
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
115
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
116
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
117
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
118
+
119
+ unflatten = nn.Sequential(
120
+ nn.Unflatten(
121
+ 2,
122
+ torch.Size(
123
+ [
124
+ h // pretrained.model.patch_size[1],
125
+ w // pretrained.model.patch_size[0],
126
+ ]
127
+ ),
128
+ )
129
+ )
130
+
131
+ if layer_1.ndim == 3:
132
+ layer_1 = unflatten(layer_1)
133
+ if layer_2.ndim == 3:
134
+ layer_2 = unflatten(layer_2)
135
+ if layer_3.ndim == 3:
136
+ layer_3 = unflatten(layer_3)
137
+ if layer_4.ndim == 3:
138
+ layer_4 = unflatten(layer_4)
139
+
140
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
141
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
142
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
143
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
144
+
145
+ return layer_1, layer_2, layer_3, layer_4
146
+
147
+
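+ # Bilinearly interpolates the learned positional-embedding grid to gs_h x gs_w patches while
+ # keeping the class/readout token embeddings, so inputs need not match the pretraining resolution.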
148
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
149
+ posemb_tok, posemb_grid = (
150
+ posemb[:, : self.start_index],
151
+ posemb[0, self.start_index :],
152
+ )
153
+
154
+ gs_old = int(math.sqrt(len(posemb_grid)))
155
+
156
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
157
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
158
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
159
+
160
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
161
+
162
+ return posemb
163
+
164
+
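+ # Replacement forward pass injected into the timm VisionTransformer: it resizes the positional
+ # embeddings to the actual input size and handles both plain and hybrid (ResNet) patch embeddings,
+ # as well as distilled models that carry an extra dist_token.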
165
+ def forward_flex(self, x):
166
+ b, c, h, w = x.shape
167
+
168
+ pos_embed = self._resize_pos_embed(
169
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
170
+ )
171
+
172
+ B = x.shape[0]
173
+
174
+ if hasattr(self.patch_embed, "backbone"):
175
+ x = self.patch_embed.backbone(x)
176
+ if isinstance(x, (list, tuple)):
177
+ x = x[-1] # last feature if backbone outputs list/tuple of features
178
+
179
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
180
+
181
+ if getattr(self, "dist_token", None) is not None:
182
+ cls_tokens = self.cls_token.expand(
183
+ B, -1, -1
184
+ ) # stole cls_tokens impl from Phil Wang, thanks
185
+ dist_token = self.dist_token.expand(B, -1, -1)
186
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
187
+ else:
188
+ cls_tokens = self.cls_token.expand(
189
+ B, -1, -1
190
+ ) # stole cls_tokens impl from Phil Wang, thanks
191
+ x = torch.cat((cls_tokens, x), dim=1)
192
+
193
+ x = x + pos_embed
194
+ x = self.pos_drop(x)
195
+
196
+ for blk in self.blocks:
197
+ x = blk(x)
198
+
199
+ x = self.norm(x)
200
+
201
+ return x
202
+
203
+
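+ # Builds one readout operation per hooked layer: 'ignore' drops the readout token, 'add' adds it
+ # to every patch token, and 'project' concatenates it with each patch token and projects back to
+ # the embedding width with a Linear + GELU.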
204
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
205
+ if use_readout == "ignore":
206
+ readout_oper = [Slice(start_index)] * len(features)
207
+ elif use_readout == "add":
208
+ readout_oper = [AddReadout(start_index)] * len(features)
209
+ elif use_readout == "project":
210
+ readout_oper = [
211
+ ProjectReadout(vit_features, start_index) for out_feat in features
212
+ ]
213
+ else:
214
+ raise ValueError(
+ f"Invalid use_readout value {use_readout!r}; expected 'ignore', 'add' or 'project'."
+ )
217
+
218
+ return readout_oper
219
+
220
+
221
+ def _make_vit_b16_backbone(
222
+ model,
223
+ features=[96, 192, 384, 768],
224
+ size=[384, 384],
225
+ hooks=[2, 5, 8, 11],
226
+ vit_features=768,
227
+ use_readout="ignore",
228
+ start_index=1,
229
+ enable_attention_hooks=False,
230
+ ):
231
+ pretrained = nn.Module()
232
+
233
+ pretrained.model = model
234
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
235
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
236
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
237
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
238
+
239
+ pretrained.activations = activations
240
+
241
+ if enable_attention_hooks:
242
+ pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
243
+ get_attention("attn_1")
244
+ )
245
+ pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
246
+ get_attention("attn_2")
247
+ )
248
+ pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
249
+ get_attention("attn_3")
250
+ )
251
+ pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
252
+ get_attention("attn_4")
253
+ )
254
+ pretrained.attention = attention
255
+
256
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
257
+
258
+ # Post-processing heads: project each hooked activation to its target channel width and rescale to strides 4, 8, 16 and 32.
259
+ pretrained.act_postprocess1 = nn.Sequential(
260
+ readout_oper[0],
261
+ Transpose(1, 2),
262
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
263
+ nn.Conv2d(
264
+ in_channels=vit_features,
265
+ out_channels=features[0],
266
+ kernel_size=1,
267
+ stride=1,
268
+ padding=0,
269
+ ),
270
+ nn.ConvTranspose2d(
271
+ in_channels=features[0],
272
+ out_channels=features[0],
273
+ kernel_size=4,
274
+ stride=4,
275
+ padding=0,
276
+ bias=True,
277
+ dilation=1,
278
+ groups=1,
279
+ ),
280
+ )
281
+
282
+ pretrained.act_postprocess2 = nn.Sequential(
283
+ readout_oper[1],
284
+ Transpose(1, 2),
285
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
286
+ nn.Conv2d(
287
+ in_channels=vit_features,
288
+ out_channels=features[1],
289
+ kernel_size=1,
290
+ stride=1,
291
+ padding=0,
292
+ ),
293
+ nn.ConvTranspose2d(
294
+ in_channels=features[1],
295
+ out_channels=features[1],
296
+ kernel_size=2,
297
+ stride=2,
298
+ padding=0,
299
+ bias=True,
300
+ dilation=1,
301
+ groups=1,
302
+ ),
303
+ )
304
+
305
+ pretrained.act_postprocess3 = nn.Sequential(
306
+ readout_oper[2],
307
+ Transpose(1, 2),
308
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
309
+ nn.Conv2d(
310
+ in_channels=vit_features,
311
+ out_channels=features[2],
312
+ kernel_size=1,
313
+ stride=1,
314
+ padding=0,
315
+ ),
316
+ )
317
+
318
+ pretrained.act_postprocess4 = nn.Sequential(
319
+ readout_oper[3],
320
+ Transpose(1, 2),
321
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
322
+ nn.Conv2d(
323
+ in_channels=vit_features,
324
+ out_channels=features[3],
325
+ kernel_size=1,
326
+ stride=1,
327
+ padding=0,
328
+ ),
329
+ nn.Conv2d(
330
+ in_channels=features[3],
331
+ out_channels=features[3],
332
+ kernel_size=3,
333
+ stride=2,
334
+ padding=1,
335
+ ),
336
+ )
337
+
338
+ pretrained.model.start_index = start_index
339
+ pretrained.model.patch_size = [16, 16]
340
+
341
+ # We inject this function into the VisionTransformer instances so that
342
+ # we can use it with interpolated position embeddings without modifying the library source.
343
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
344
+ pretrained.model._resize_pos_embed = types.MethodType(
345
+ _resize_pos_embed, pretrained.model
346
+ )
347
+
348
+ return pretrained
349
+
350
+
351
+ def _make_vit_b_rn50_backbone(
352
+ model,
353
+ features=[256, 512, 768, 768],
354
+ size=[384, 384],
355
+ hooks=[0, 1, 8, 11],
356
+ vit_features=768,
357
+ use_vit_only=False,
358
+ use_readout="ignore",
359
+ start_index=1,
360
+ enable_attention_hooks=False,
361
+ ):
362
+ pretrained = nn.Module()
363
+
364
+ pretrained.model = model
365
+
366
+ if use_vit_only:
367
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
368
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
369
+ else:
370
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
371
+ get_activation("1")
372
+ )
373
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
374
+ get_activation("2")
375
+ )
376
+
377
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
378
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
379
+
380
+ if enable_attention_hooks:
381
+ pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1"))
382
+ pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2"))
383
+ pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3"))
384
+ pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4"))
385
+ pretrained.attention = attention
386
+
387
+ pretrained.activations = activations
388
+
389
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
390
+
391
+ if use_vit_only:
392
+ pretrained.act_postprocess1 = nn.Sequential(
393
+ readout_oper[0],
394
+ Transpose(1, 2),
395
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
396
+ nn.Conv2d(
397
+ in_channels=vit_features,
398
+ out_channels=features[0],
399
+ kernel_size=1,
400
+ stride=1,
401
+ padding=0,
402
+ ),
403
+ nn.ConvTranspose2d(
404
+ in_channels=features[0],
405
+ out_channels=features[0],
406
+ kernel_size=4,
407
+ stride=4,
408
+ padding=0,
409
+ bias=True,
410
+ dilation=1,
411
+ groups=1,
412
+ ),
413
+ )
414
+
415
+ pretrained.act_postprocess2 = nn.Sequential(
416
+ readout_oper[1],
417
+ Transpose(1, 2),
418
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
419
+ nn.Conv2d(
420
+ in_channels=vit_features,
421
+ out_channels=features[1],
422
+ kernel_size=1,
423
+ stride=1,
424
+ padding=0,
425
+ ),
426
+ nn.ConvTranspose2d(
427
+ in_channels=features[1],
428
+ out_channels=features[1],
429
+ kernel_size=2,
430
+ stride=2,
431
+ padding=0,
432
+ bias=True,
433
+ dilation=1,
434
+ groups=1,
435
+ ),
436
+ )
437
+ else:
438
+ pretrained.act_postprocess1 = nn.Sequential(
439
+ nn.Identity(), nn.Identity(), nn.Identity()
440
+ )
441
+ pretrained.act_postprocess2 = nn.Sequential(
442
+ nn.Identity(), nn.Identity(), nn.Identity()
443
+ )
444
+
445
+ pretrained.act_postprocess3 = nn.Sequential(
446
+ readout_oper[2],
447
+ Transpose(1, 2),
448
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
449
+ nn.Conv2d(
450
+ in_channels=vit_features,
451
+ out_channels=features[2],
452
+ kernel_size=1,
453
+ stride=1,
454
+ padding=0,
455
+ ),
456
+ )
457
+
458
+ pretrained.act_postprocess4 = nn.Sequential(
459
+ readout_oper[3],
460
+ Transpose(1, 2),
461
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
462
+ nn.Conv2d(
463
+ in_channels=vit_features,
464
+ out_channels=features[3],
465
+ kernel_size=1,
466
+ stride=1,
467
+ padding=0,
468
+ ),
469
+ nn.Conv2d(
470
+ in_channels=features[3],
471
+ out_channels=features[3],
472
+ kernel_size=3,
473
+ stride=2,
474
+ padding=1,
475
+ ),
476
+ )
477
+
478
+ pretrained.model.start_index = start_index
479
+ pretrained.model.patch_size = [16, 16]
480
+
481
+ # We inject this function into the VisionTransformer instances so that
482
+ # we can use it with interpolated position embeddings without modifying the library source.
483
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
484
+
485
+ # We inject this function into the VisionTransformer instances so that
486
+ # we can use it with interpolated position embeddings without modifying the library source.
487
+ pretrained.model._resize_pos_embed = types.MethodType(
488
+ _resize_pos_embed, pretrained.model
489
+ )
490
+
491
+ return pretrained
492
+
493
+
494
+ def _make_pretrained_vitb_rn50_384(
495
+ pretrained,
496
+ use_readout="ignore",
497
+ hooks=None,
498
+ use_vit_only=False,
499
+ enable_attention_hooks=False,
500
+ ):
501
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
502
+
503
+ hooks = [0, 1, 8, 11] if hooks is None else hooks
504
+ return _make_vit_b_rn50_backbone(
505
+ model,
506
+ features=[256, 512, 768, 768],
507
+ size=[384, 384],
508
+ hooks=hooks,
509
+ use_vit_only=use_vit_only,
510
+ use_readout=use_readout,
511
+ enable_attention_hooks=enable_attention_hooks,
512
+ )
513
+
514
+
515
+ def _make_pretrained_vitl16_384(
516
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
517
+ ):
518
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
519
+
520
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
521
+ return _make_vit_b16_backbone(
522
+ model,
523
+ features=[256, 512, 1024, 1024],
524
+ hooks=hooks,
525
+ vit_features=1024,
526
+ use_readout=use_readout,
527
+ enable_attention_hooks=enable_attention_hooks,
528
+ )
529
+
530
+
531
+ def _make_pretrained_vitb16_384(
532
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
533
+ ):
534
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
535
+
536
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
537
+ return _make_vit_b16_backbone(
538
+ model,
539
+ features=[96, 192, 384, 768],
540
+ hooks=hooks,
541
+ use_readout=use_readout,
542
+ enable_attention_hooks=enable_attention_hooks,
543
+ )
544
+
545
+
546
+ def _make_pretrained_deitb16_384(
547
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
548
+ ):
549
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
550
+
551
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
552
+ return _make_vit_b16_backbone(
553
+ model,
554
+ features=[96, 192, 384, 768],
555
+ hooks=hooks,
556
+ use_readout=use_readout,
557
+ enable_attention_hooks=enable_attention_hooks,
558
+ )
559
+
560
+
561
+ def _make_pretrained_deitb16_distil_384(
562
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
563
+ ):
564
+ model = timm.create_model(
565
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
566
+ )
567
+
568
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
569
+ return _make_vit_b16_backbone(
570
+ model,
571
+ features=[96, 192, 384, 768],
572
+ hooks=hooks,
573
+ use_readout=use_readout,
574
+ start_index=2,
575
+ enable_attention_hooks=enable_attention_hooks,
576
+ )
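For orientation, a hedged sketch of how these factory functions are consumed; pretrained=False is used only to keep the example light, and the DPT models defined in models.py are what actually wire the resulting backbone into a decoder.

import torch

# Build a ViT-Base/16 backbone with forward hooks and post-processing heads attached.
backbone = _make_pretrained_vitb16_384(pretrained=False, use_readout="project")

x = torch.randn(1, 3, 384, 384)
# forward_vit returns four feature maps at (approximately) strides 4, 8, 16 and 32,
# which a DPT-style decoder then fuses into a dense prediction.
layer_1, layer_2, layer_3, layer_4 = forward_vit(backbone, x)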
streamlit_apps/app_utils/image_inference.py ADDED
@@ -0,0 +1,88 @@
1
+ import time
2
+ import numpy as np
3
+ import streamlit as st
4
+ from PIL import Image
5
+
6
+ from s_multimae.da.base_da import BaseDataAugmentation
7
+ from .base_model import BaseRGBDModel
8
+ from .depth_model import BaseDepthModel
9
+ from .model import base_inference
10
+
11
+ if "depth" not in st.session_state:
12
+ st.session_state.depth = None
13
+
14
+
15
+ def image_inference(
16
+ depth_model: BaseDepthModel,
17
+ sod_model: BaseRGBDModel,
18
+ da: BaseDataAugmentation,
19
+ color: np.ndarray,
20
+ ) -> None:
21
+ col1, col2 = st.columns(2)
22
+ image: Image.Image = None
23
+ # depth: Image = None
24
+
25
+ with col1:
26
+ img_file_buffer = st.file_uploader(
27
+ "Upload an RGB image", key="img_file_buffer", type=["png", "jpg", "jpeg"]
28
+ )
29
+ if img_file_buffer is not None:
30
+ image = Image.open(img_file_buffer).convert("RGB")
31
+ st.image(image, caption="RGB")
32
+
33
+ with col2:
34
+ depth_file_buffer = st.file_uploader(
35
+ "Upload a depth image (Optional)",
36
+ key="depth_file_buffer",
37
+ type=["png", "jpg", "jpeg"],
38
+ )
39
+ if depth_file_buffer is not None:
40
+ st.session_state.depth = Image.open(depth_file_buffer).convert("L")
41
+ if st.session_state.depth is not None:
42
+ st.image(st.session_state.depth, caption="Depth")
43
+
44
+ if sod_model.cfg.ground_truth_version == 6:
45
+ num_sets_of_salient_objects = st.number_input(
46
+ "Number of sets of salient objects", value=1, min_value=1, max_value=10
47
+ )
48
+ else:
49
+ num_sets_of_salient_objects = 1
50
+
51
+ is_predict = st.button(
52
+ "Predict Salient Objects",
53
+ key="predict_salient_objects",
54
+ disabled=img_file_buffer is None,
55
+ )
56
+ if is_predict:
57
+ with st.spinner("Processing..."):
58
+ start_time = time.time()
59
+ pred_depth, pred_sods, pred_sms = base_inference(
60
+ depth_model,
61
+ sod_model,
62
+ da,
63
+ image,
64
+ st.session_state.depth,
65
+ color,
66
+ num_sets_of_salient_objects,
67
+ )
68
+ if st.session_state.depth is None:
69
+ st.session_state.depth = Image.fromarray(pred_depth).convert("L")
70
+ col2.image(st.session_state.depth, "Pseudo-depth")
71
+
72
+ if num_sets_of_salient_objects == 1:
73
+ st.warning(
74
+ "HINT: To view a wider variety of sets of salient objects, try to increase the number of sets the model can produce."
75
+ )
76
+ elif num_sets_of_salient_objects > 1:
77
+ st.warning(
78
+ "NOTE: As single-GT accounts for 77.61% of training samples, the model may not consistently yield different sets. The best approach is to gradually increase the number of sets of salient objects until you achieve the desired result."
79
+ )
80
+
81
+ st.info(f"Inference time: {time.time() - start_time:.4f} seconds")
82
+
83
+ sod_cols = st.columns(len(pred_sods))
84
+
85
+ for i, (pred_sod, pred_sm) in enumerate(zip(pred_sods, pred_sms)):
86
+ with sod_cols[i]:
87
+ st.image(pred_sod, "Salient Objects (Otsu threshold)")
88
+ st.image(pred_sm, "Salient Map")
streamlit_apps/app_utils/model.py ADDED
@@ -0,0 +1,84 @@
1
+ from typing import List, Optional, Tuple, Union
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torchvision.transforms.functional as TF
6
+ from PIL import Image
7
+ from torch import Tensor, nn
9
+ from skimage.filters import threshold_otsu
10
+
11
+ from s_multimae.da.base_da import BaseDataAugmentation
12
+ from s_multimae.model_pl import ModelPL
13
+ from s_multimae.visualizer import apply_vis_to_image
14
+
15
+ from .base_model import BaseRGBDModel
16
+ from .app_utils import get_size, normalize
17
+ from .depth_model import BaseDepthModel
18
+
19
+
20
+ # Environment
21
+ torch.set_grad_enabled(False)
22
+ from .device import device
23
+
24
+ print(f"device: {device}")
25
+
26
+
27
+ def post_processing_depth(depth: np.ndarray) -> np.ndarray:
28
+ depth = (normalize(depth) * 255).astype(np.uint8)
29
+ return cv2.applyColorMap(depth, cv2.COLORMAP_OCEAN)
30
+
31
+
32
+ def base_inference(
33
+ depth_model: BaseDepthModel,
34
+ sod_model: BaseRGBDModel,
35
+ da: BaseDataAugmentation,
36
+ raw_image: Union[Image.Image, np.ndarray],
37
+ raw_depth: Optional[Union[Image.Image, np.ndarray]] = None,
38
+ color: Optional[np.ndarray] = None,
39
+ num_sets_of_salient_objects: int = 1,
40
+ ) -> Tuple[np.ndarray, List[np.ndarray], List[np.ndarray]]:
41
+ """Inference a pair of rgb image and depth image
42
+ if depth image is not provided, the depth_model will predict a depth image based on image
43
+ """
44
+ origin_size = get_size(raw_image)
45
+
46
+ # Predict depth
47
+ image = TF.to_tensor(raw_image)
48
+ origin_shape = image.shape
49
+ if raw_depth is None:
50
+ depth: Tensor = depth_model.forward(image)
51
+ else:
52
+ depth = TF.to_tensor(raw_depth)
53
+
54
+ # Preprocessing
55
+ image, depth = da.forward(
56
+ raw_image, depth.cpu().detach().squeeze(0).numpy(), is_transform=False
57
+ )
58
+
59
+ # Inference
60
+ sms = sod_model.inference(image, depth, origin_shape, num_sets_of_salient_objects)
61
+
62
+ # Postprocessing
63
+ sods = []
64
+
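+ # Binarize each saliency map with Otsu's threshold and overlay the resulting mask on the
+ # original RGB image for visualization.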
65
+ for sm in sms:
66
+ binary_mask = np.array(sm)
67
+ t = threshold_otsu(binary_mask)
68
+ binary_mask[binary_mask < t] = 0.0
69
+ binary_mask[binary_mask >= t] = 1.0
70
+
71
+ sod = apply_vis_to_image(np.array(raw_image), binary_mask, color)
72
+ sods.append(sod)
73
+
74
+ depth = depth.permute(1, 2, 0).detach().cpu().numpy()
75
+ depth = cv2.resize(depth, origin_size)
76
+ depth = post_processing_depth(depth)
77
+
78
+ return depth, sods, [e / 255.0 for e in sms]
79
+
80
+
81
+ def transform_images(inputs: List[Image.Image], transform: nn.Module) -> Tensor:
82
+ if len(inputs) == 1:
83
+ return transform(inputs[0]).unsqueeze(0)
84
+ return torch.cat([transform(input).unsqueeze(0) for input in inputs])
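A hedged end-to-end sketch of base_inference. It assumes depth_model, sod_model and da have already been constructed (for example via the loaders in sod_selection_ui.py and depth_model.py); the input path and overlay color are arbitrary examples.

import numpy as np
from PIL import Image

rgb = Image.open("example_rgb.jpg").convert("RGB")      # hypothetical input image
overlay_color = np.array([255, 0, 0], dtype=np.uint8)   # assumed RGB color for the salient-object overlay

pred_depth, pred_sods, pred_sms = base_inference(
    depth_model,                    # assumed: a BaseDepthModel instance
    sod_model,                      # assumed: a BaseRGBDModel instance (e.g. RGBDSMultiMAEModel)
    da,                             # assumed: the matching BaseDataAugmentation
    rgb,
    raw_depth=None,                 # no depth given, so a pseudo-depth is predicted first
    color=overlay_color,
    num_sets_of_salient_objects=2,
)
# pred_depth: colorized (pseudo-)depth map, pred_sods: RGB overlays of the Otsu-binarized masks,
# pred_sms: saliency maps rescaled to [0, 1].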
streamlit_apps/app_utils/smultimae_model.py ADDED
@@ -0,0 +1,43 @@
1
+ import numpy as np
2
+ from torch import Tensor
3
+ from torchvision.transforms import Resize
4
+
5
+ from s_multimae.model_pl import ModelPL
6
+ from s_multimae.configs.base_config import base_cfg
7
+
8
+ from .base_model import BaseRGBDModel
9
+
10
+
11
+ class RGBDSMultiMAEModel(BaseRGBDModel):
12
+ def __init__(self, cfg: base_cfg, model: ModelPL):
13
+ """Wrapper of RGBDModel"""
14
+ super(RGBDSMultiMAEModel, self).__init__()
15
+ self.model: ModelPL = model
16
+ self.cfg = cfg
17
+ self.resize = Resize([self.cfg.image_size, self.cfg.image_size])
18
+
19
+ def inference(
20
+ self,
21
+ image: Tensor,
22
+ depth: Tensor,
23
+ origin_shape: np.ndarray,
24
+ num_sets_of_salient_objects: int = 1,
25
+ ) -> np.ndarray:
26
+ # 1. Preprocessing
27
+ images = image.unsqueeze(0)
28
+ depths = depth.unsqueeze(0)
29
+
30
+ # images = self.resize(images)
31
+ # depths = self.resize(depths)
32
+
33
+ # 2. Inference
34
+ images, depths = images.to(self.model.device), depths.to(self.model.device)
35
+ if self.cfg.ground_truth_version == 6:
36
+ self.cfg.num_classes = num_sets_of_salient_objects
37
+ res = self.model.inference(
38
+ [[origin_shape[2], origin_shape[1]]],
39
+ images,
40
+ depths,
41
+ [num_sets_of_salient_objects],
42
+ )
43
+ return res[0]
streamlit_apps/app_utils/sod_selection_ui.py ADDED
@@ -0,0 +1,111 @@
1
+ from typing import Tuple
2
+ import streamlit as st
3
+ import os
4
+ import torch
5
+
6
+ from .app_env import SOD_MODEL_TYPE
7
+ from .app_utils import count_parameters
8
+ from .smultimae_model import RGBDSMultiMAEModel
9
+ from .base_model import BaseRGBDModel
10
+ from .device import device
11
+
12
+ from s_multimae.da.dav6 import DataAugmentationV6
13
+ from s_multimae.configs.base_config import base_cfg
14
+ from s_multimae.configs.experiment_config import arg_cfg
15
+ from s_multimae.model_pl import ModelPL
16
+
17
+ # from spnet_model import SPNetModel
18
+
19
+
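+ # st.cache_resource caches the constructed model across Streamlit reruns, keyed by
+ # (sod_model_config_key, top), so the checkpoint is downloaded and loaded only once.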
20
+ @st.cache_resource
21
+ def load_smultimae_model(
22
+ sod_model_config_key: str, top: int
23
+ ) -> Tuple[BaseRGBDModel, base_cfg]:
24
+ """
25
+ 1. Construct model
26
+ 2. Load pretrained weights
27
+ 3. Load model into device
28
+ """
29
+ cfg = arg_cfg[sod_model_config_key]()
30
+
31
+ weights_fname = f"s-multimae-{cfg.experiment_name}-top{top}.pth"
32
+ ckpt_path = os.path.join(
33
+ "weights", weights_fname
34
+ )
35
+ print(ckpt_path)
36
+ if not os.path.isfile(ckpt_path):
37
+ from huggingface_hub import hf_hub_download
38
+ downloaded_path = hf_hub_download(repo_id="RGBD-SOD/S-MultiMAE", filename=weights_fname)
+ os.system(f"cp {downloaded_path} {ckpt_path}")  # hf_hub_download returns the cached file path
40
+ assert os.path.isfile(ckpt_path)
41
+
42
+ # sod_model = ModelPL.load_from_checkpoint(
43
+ # ckpt_path,
44
+ # cfg=cfg,
45
+ # map_location=device,
46
+ # )
47
+ sod_model = ModelPL(cfg)
48
+ sod_model.model.load_state_dict(
49
+ torch.load(ckpt_path, map_location="cpu"), strict=False
50
+ )
51
+ da = DataAugmentationV6(cfg)
52
+ return RGBDSMultiMAEModel(cfg, sod_model), cfg, da
53
+
54
+
55
+ # @st.cache_resource
56
+ # def load_spnet_model() -> BaseRGBDModel:
57
+ # """
58
+ # 1. Construct model
59
+ # 2. Load pretrained weights
60
+ # 3. Load model into device
61
+ # """
62
+ # sod_model = SPNetModel()
63
+ # return sod_model
64
+
65
+
66
+ # @st.cache_resource
67
+ # def load_bbsnet_model() -> BaseRGBDModel:
68
+ # """
69
+ # 1. Construct model
70
+ # 2. Load pretrained weights
71
+ # 3. Load model into device
72
+ # """
73
+ # sod_model = BBSNetModel()
74
+ # return sod_model
75
+
76
+
77
+ def sod_selection_ui() -> BaseRGBDModel:
78
+ sod_model_type = st.selectbox(
79
+ "Choose SOD model",
80
+ (
81
+ SOD_MODEL_TYPE.S_MULTIMAE,
82
+ # SOD_MODEL_TYPE.SPNET,
83
+ # SOD_MODEL_TYPE.BBSNET,
84
+ ),
85
+ key="sod_model_type",
86
+ )
87
+
88
+ if sod_model_type == SOD_MODEL_TYPE.S_MULTIMAE:
89
+ d = {
90
+ "S-MultiMAE [ViT-L] Multi-GT": {"top": 1, "cfg": "cfgv4_0_2006"},
91
+ "S-MultiMAE [ViT-B] Multi-GT": {"top": 1, "cfg": "cfgv4_0_2007"},
92
+ }
93
+
94
+ sod_model_config_key = st.selectbox(
95
+ "Choose config",
96
+ list(d.keys()),
97
+ key="sod_model_config_key",
98
+ )
99
+ sod_model, cfg, da = load_smultimae_model(
100
+ d[sod_model_config_key]["cfg"], d[sod_model_config_key]["top"]
101
+ )
102
+ # st.text(f"Model description: {cfg.description}")
103
+ # elif sod_model_type == SOD_MODEL_TYPE.SPNET:
104
+ # sod_model = load_spnet_model()
105
+ # st.text(f"Model description: SPNet (https://github.com/taozh2017/SPNet)")
106
+ # elif sod_model_type == SOD_MODEL_TYPE.BBSNET:
107
+ # sod_model = load_bbsnet_model()
108
+ # st.text(f"Model description: BBSNet (https://github.com/DengPingFan/BBS-Net)")
109
+ st.text(f"Number of parameters {count_parameters(sod_model)}")
110
+
111
+ return sod_model, da