Spaces:
Sleeping
Sleeping
Commit
·
6e9c433
1
Parent(s):
2cadd70
Init
Browse files- .gitignore +28 -0
- README.md +72 -1
- definition.py +21 -0
- docs/figures/proposed_method_v5.drawio.png +0 -0
- docs/references/Dataset.bib +124 -0
- docs/references/References.bib +190 -0
- docs/references/SOTAs.bib +355 -0
- requirements-lock.txt +103 -0
- requirements.txt +18 -0
- s_multimae/__init__.py +0 -0
- s_multimae/configs/__init__.py +0 -0
- s_multimae/configs/base_config.py +164 -0
- s_multimae/configs/data_augmentation_config.py +19 -0
- s_multimae/configs/experiment_config.py +31 -0
- s_multimae/configs/experiment_configs/__init__.py +0 -0
- s_multimae/configs/experiment_configs/expv1_dynamic.py +277 -0
- s_multimae/da/__init__.py +0 -0
- s_multimae/da/base_da.py +33 -0
- s_multimae/da/dav6.py +147 -0
- s_multimae/data_augmentation.py +19 -0
- s_multimae/model/__init__.py +0 -0
- s_multimae/model/components.py +117 -0
- s_multimae/model/multimae.py +938 -0
- s_multimae/model_pl.py +105 -0
- s_multimae/rgbd_model.py +60 -0
- s_multimae/utils.py +236 -0
- s_multimae/visualize_2d_posemb.py +58 -0
- s_multimae/visualizer.py +711 -0
- streamlit_apps/__init__.py +0 -0
- streamlit_apps/app.py +91 -0
- streamlit_apps/app_utils/__init__.py +0 -0
- streamlit_apps/app_utils/app_env.py +16 -0
- streamlit_apps/app_utils/app_utils.py +83 -0
- streamlit_apps/app_utils/base_model.py +54 -0
- streamlit_apps/app_utils/color_selection_ui.py +10 -0
- streamlit_apps/app_utils/depth_model.py +77 -0
- streamlit_apps/app_utils/depth_selection_ui.py +27 -0
- streamlit_apps/app_utils/device.py +5 -0
- streamlit_apps/app_utils/dpt/__init__.py +0 -0
- streamlit_apps/app_utils/dpt/base_model.py +16 -0
- streamlit_apps/app_utils/dpt/blocks.py +383 -0
- streamlit_apps/app_utils/dpt/midas_net.py +78 -0
- streamlit_apps/app_utils/dpt/models.py +124 -0
- streamlit_apps/app_utils/dpt/transforms.py +231 -0
- streamlit_apps/app_utils/dpt/vit.py +576 -0
- streamlit_apps/app_utils/image_inference.py +88 -0
- streamlit_apps/app_utils/model.py +84 -0
- streamlit_apps/app_utils/smultimae_model.py +43 -0
- streamlit_apps/app_utils/sod_selection_ui.py +111 -0
.gitignore
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
env
|
2 |
+
__pycache__/
|
3 |
+
*.pth
|
4 |
+
*.pt
|
5 |
+
|
6 |
+
datasets/**/benchmark/*
|
7 |
+
datasets/**/test/*
|
8 |
+
datasets/**/train/*
|
9 |
+
datasets/**/dev/*
|
10 |
+
datasets/**/*.zip
|
11 |
+
|
12 |
+
sources/deployment/*
|
13 |
+
sources/experiment/*
|
14 |
+
sources/pickle/*
|
15 |
+
sources/csv/*/*
|
16 |
+
sources/json/*
|
17 |
+
sotas/*
|
18 |
+
continue_training/*
|
19 |
+
|
20 |
+
weights/*
|
21 |
+
|
22 |
+
!*.gitkeep
|
23 |
+
logs
|
24 |
+
wandb
|
25 |
+
tmp/
|
26 |
+
|
27 |
+
wandb_cache
|
28 |
+
script.md
|
README.md
CHANGED
@@ -9,4 +9,75 @@ app_file: app.py
|
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
# S-MultiMAE
|
13 |
+
|
14 |
+
This repository provides the official implementation of `S-MultiMAE A Multi-Ground Truth approach for RGB-D Saliency Detection`
|
15 |
+
|
16 |
+
_Nguyen Truong Thinh Huynh, Van Linh Pham, Xuan Toan Mai and Tuan Anh Tran_
|
17 |
+
|
18 |
+

|
19 |
+
|
20 |
+
## Model weights
|
21 |
+
|
22 |
+
| Backbone | #params | Training paradigm | Weights | Input size |
|
23 |
+
| -------- | ----------- | ----------------- | ---------------------------------------------------------------------------------------------- | ---------- |
|
24 |
+
| ViT-L | 328,318,529 | Multi-GT | [Download](https://drive.google.com/file/d/1YhAuu3DI2adPLQgbgoSt74ilZbpuKihh/view?usp=sharing) | 224x224 |
|
25 |
+
| ViT-B | 107,654,977 | Multi-GT | [Download](https://drive.google.com/file/d/13Omafif3pvPKgg3Isp_srkHf8CSPx33d/view?usp=sharing) | 224x224 |
|
26 |
+
|
27 |
+
## How to run
|
28 |
+
|
29 |
+
### Create a virtual environment
|
30 |
+
|
31 |
+
We recommend using python 3.10 or higher.
|
32 |
+
|
33 |
+
```bash
|
34 |
+
python3.10 -m venv env
|
35 |
+
source env/bin/activate
|
36 |
+
pip install -r requirements.txt
|
37 |
+
```
|
38 |
+
|
39 |
+
### Download trained weights
|
40 |
+
|
41 |
+
- Download model weights and put it in the folder `weights`. You may also need to download the weights of [DPT model]() (a rgb2depth model). The `weights` folder will look like this:
|
42 |
+
|
43 |
+
```bash
|
44 |
+
├── weights
|
45 |
+
│ ├── omnidata_rgb2depth_dpt_hybrid.pth
|
46 |
+
│ ├── s-multimae-cfgv4_0_2006-top1.pth
|
47 |
+
│ ├── s-multimae-cfgv4_0_2007-top1.pth
|
48 |
+
```
|
49 |
+
|
50 |
+
### Run
|
51 |
+
|
52 |
+
- Run streamlit app
|
53 |
+
|
54 |
+
```
|
55 |
+
streamlit run streamlit_apps/app.py --server.port 9113 --browser.gatherUsageStats False --server.fileWatcherType none
|
56 |
+
```
|
57 |
+
|
58 |
+
## Datasets
|
59 |
+
|
60 |
+
### COME15K dataset
|
61 |
+
|
62 |
+
| | 1 GT | 2 GTs | 3 GTs | 4 GTs | 5 GTs |
|
63 |
+
| --------------------- | ------ | ----- | ------ | ----- | ----- |
|
64 |
+
| COME8K (8025 samples) | 77.61% | 1.71% | 18.28% | 2.24% | 0.16% |
|
65 |
+
| COME-E (4600 samples) | 70.5% | 1.87% | 21.15% | 5.70% | 0.78% |
|
66 |
+
| COME8K (3000 samples) | 62.3% | 2.00% | 25.63% | 8.37% | 1.70% |
|
67 |
+
|
68 |
+
```
|
69 |
+
@inproceedings{cascaded_rgbd_sod,
|
70 |
+
title={RGB-D Saliency Detection via Cascaded Mutual Information Minimization},
|
71 |
+
author={Zhang, Jing and Fan, Deng-Ping and Dai, Yuchao and Yu, Xin and Zhong, Yiran and Barnes, Nick and Shao, Ling},
|
72 |
+
booktitle={International Conference on Computer Vision (ICCV)},
|
73 |
+
year={2021}
|
74 |
+
}
|
75 |
+
```
|
76 |
+
|
77 |
+
## References
|
78 |
+
|
79 |
+
All references are cited in these files:
|
80 |
+
|
81 |
+
- [Datasets](./docs/references/Dataset.bib)
|
82 |
+
- [SOTAs](./docs/references/SOTAs.bib)
|
83 |
+
- [Others](./docs/references/References.bib)
|
definition.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Do not import other modules!
|
3 |
+
"""
|
4 |
+
|
5 |
+
|
6 |
+
class PRETRAINED_BACKBONE:
|
7 |
+
MULTIMAE = "multimae"
|
8 |
+
|
9 |
+
S_MULTIMAE = "s-multimae"
|
10 |
+
LARGE_S_MULTIMAE = "large-s-multimae"
|
11 |
+
|
12 |
+
MAE = "mae"
|
13 |
+
LARGE_MAE = "large-mae"
|
14 |
+
HUGE_MAE = "huge-mae"
|
15 |
+
|
16 |
+
FINETUNE_LARGE_S_MULTIMAE = "finetune-large-s-multimae"
|
17 |
+
FINETUNE_S_MULTIMAE = "finetune-s-multimae"
|
18 |
+
|
19 |
+
VIT = "vit" # train from supervised model
|
20 |
+
|
21 |
+
NONE = None # train from scratch
|
docs/figures/proposed_method_v5.drawio.png
ADDED
![]() |
docs/references/Dataset.bib
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
% Encoding: UTF-8
|
2 |
+
|
3 |
+
% DES
|
4 |
+
@inproceedings{cheng2014depth,
|
5 |
+
title={Depth enhanced saliency detection method},
|
6 |
+
author={Cheng, Yupeng and Fu, Huazhu and Wei, Xingxing and Xiao, Jiangjian and Cao, Xiaochun},
|
7 |
+
booktitle={Proceedings of international conference on internet multimedia computing and service},
|
8 |
+
pages={23--27},
|
9 |
+
year={2014}
|
10 |
+
}
|
11 |
+
|
12 |
+
% DUT-RGBD
|
13 |
+
@inproceedings{piao2019depth,
|
14 |
+
title={Depth-induced multi-scale recurrent attention network for saliency detection},
|
15 |
+
author={Piao, Yongri and Ji, Wei and Li, Jingjing and Zhang, Miao and Lu, Huchuan},
|
16 |
+
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
|
17 |
+
pages={7254--7263},
|
18 |
+
year={2019}
|
19 |
+
}
|
20 |
+
|
21 |
+
% LFSD
|
22 |
+
@inproceedings{li2014saliency,
|
23 |
+
title={Saliency detection on light field},
|
24 |
+
author={Li, Nianyi and Ye, Jinwei and Ji, Yu and Ling, Haibin and Yu, Jingyi},
|
25 |
+
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
|
26 |
+
pages={2806--2813},
|
27 |
+
year={2014}
|
28 |
+
}
|
29 |
+
|
30 |
+
% NJU2K
|
31 |
+
@inproceedings{ju2014depth,
|
32 |
+
title={Depth saliency based on anisotropic center-surround difference},
|
33 |
+
author={Ju, Ran and Ge, Ling and Geng, Wenjing and Ren, Tongwei and Wu, Gangshan},
|
34 |
+
booktitle={2014 IEEE international conference on image processing (ICIP)},
|
35 |
+
pages={1115--1119},
|
36 |
+
year={2014},
|
37 |
+
organization={IEEE}
|
38 |
+
}
|
39 |
+
|
40 |
+
% SSD
|
41 |
+
@inproceedings{zhu2017three,
|
42 |
+
title={A three-pathway psychobiological framework of salient object detection using stereoscopic technology},
|
43 |
+
author={Zhu, Chunbiao and Li, Ge},
|
44 |
+
booktitle={Proceedings of the IEEE international conference on computer vision workshops},
|
45 |
+
pages={3008--3014},
|
46 |
+
year={2017}
|
47 |
+
}
|
48 |
+
|
49 |
+
% Holo50K
|
50 |
+
@article{hua2020holopix50k,
|
51 |
+
title={Holopix50k: A large-scale in-the-wild stereo image dataset},
|
52 |
+
author={Hua, Yiwen and Kohli, Puneet and Uplavikar, Pritish and Ravi, Anand and Gunaseelan, Saravana and Orozco, Jason and Li, Edward},
|
53 |
+
journal={arXiv preprint arXiv:2003.11172},
|
54 |
+
year={2020}
|
55 |
+
}
|
56 |
+
|
57 |
+
% NLPR
|
58 |
+
@inproceedings{peng2014rgbd,
|
59 |
+
title={RGBD salient object detection: A benchmark and algorithms},
|
60 |
+
author={Peng, Houwen and Li, Bing and Xiong, Weihua and Hu, Weiming and Ji, Rongrong},
|
61 |
+
booktitle={European conference on computer vision},
|
62 |
+
pages={92--109},
|
63 |
+
year={2014},
|
64 |
+
organization={Springer}
|
65 |
+
}
|
66 |
+
|
67 |
+
|
68 |
+
% SIP
|
69 |
+
@article{fan2020rethinking,
|
70 |
+
title={Rethinking RGB-D salient object detection: Models, data sets, and large-scale benchmarks},
|
71 |
+
author={/project/634a6386039ac5d46d8c6ab0Fan, Deng-Ping and Lin, Zheng and Zhang, Zhao and Zhu, Menglong and Cheng, Ming-Ming},
|
72 |
+
journal={IEEE Transactions on neural networks and learning systems},
|
73 |
+
volume={32},
|
74 |
+
number={5},
|
75 |
+
pages={2075--2089},
|
76 |
+
year={2020},
|
77 |
+
publisher={IEEE}
|
78 |
+
}
|
79 |
+
|
80 |
+
% STERE
|
81 |
+
@inproceedings{niu2012leveraging,
|
82 |
+
title={Leveraging stereopsis for saliency analysis},
|
83 |
+
author={Niu, Yuzhen and Geng, Yujie and Li, Xueqing and Liu, Feng},
|
84 |
+
booktitle={2012 IEEE Conference on Computer Vision and Pattern Recognition},
|
85 |
+
pages={454--461},
|
86 |
+
year={2012},
|
87 |
+
organization={IEEE}
|
88 |
+
}
|
89 |
+
|
90 |
+
% RGB-Thermal
|
91 |
+
@inproceedings{ha2017mfnet,
|
92 |
+
title={MFNet: Towards real-time semantic segmentation for autonomous vehicles with multi-spectral scenes},
|
93 |
+
author={Ha, Qishen and Watanabe, Kohei and Karasawa, Takumi and Ushiku, Yoshitaka and Harada, Tatsuya},
|
94 |
+
booktitle={2017 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
|
95 |
+
pages={5108--5115},
|
96 |
+
year={2017},
|
97 |
+
organization={IEEE}
|
98 |
+
}
|
99 |
+
|
100 |
+
%RGB-Polarization
|
101 |
+
@article{xiang2021polarization,
|
102 |
+
title={Polarization-driven semantic segmentation via efficient attention-bridged fusion},
|
103 |
+
author={Xiang, Kaite and Yang, Kailun and Wang, Kaiwei},
|
104 |
+
journal={Optics Express},
|
105 |
+
volume={29},
|
106 |
+
number={4},
|
107 |
+
pages={4802--4820},
|
108 |
+
year={2021},
|
109 |
+
publisher={Optica Publishing Group}
|
110 |
+
}
|
111 |
+
|
112 |
+
% ImageNet
|
113 |
+
@article{russakovsky2015imagenet,
|
114 |
+
title={Imagenet large scale visual recognition challenge},
|
115 |
+
author={Russakovsky, Olga and Deng, Jia and Su, Hao and Krause, Jonathan and Satheesh, Sanjeev and Ma, Sean and Huang, Zhiheng and Karpathy, Andrej and Khosla, Aditya and Bernstein, Michael and others},
|
116 |
+
journal={International journal of computer vision},
|
117 |
+
volume={115},
|
118 |
+
number={3},
|
119 |
+
pages={211--252},
|
120 |
+
year={2015},
|
121 |
+
publisher={Springer}
|
122 |
+
}
|
123 |
+
|
124 |
+
@Comment{jabref-meta: databaseType:bibtex;}
|
docs/references/References.bib
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
% Encoding: UTF-8
|
2 |
+
|
3 |
+
% An Empirical Study of Training Self-Supervised Vision Transformers
|
4 |
+
@inproceedings{chen2021empirical,
|
5 |
+
title={An empirical study of training self-supervised vision transformers},
|
6 |
+
author={Chen, Xinlei and Xie, Saining and He, Kaiming},
|
7 |
+
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
|
8 |
+
pages={9640--9649},
|
9 |
+
year={2021}
|
10 |
+
}
|
11 |
+
|
12 |
+
% 2D positional embedding
|
13 |
+
@article{raisi20202d,
|
14 |
+
title={2D positional embedding-based transformer for scene text recognition},
|
15 |
+
author={Raisi, Zobeir and Naiel, Mohamed A and Fieguth, Paul and Wardell, Steven and Zelek, John},
|
16 |
+
journal={Journal of Computational Vision and Imaging Systems},
|
17 |
+
volume={6},
|
18 |
+
number={1},
|
19 |
+
pages={1--4},
|
20 |
+
year={2020}
|
21 |
+
}
|
22 |
+
|
23 |
+
% Layer Normalization
|
24 |
+
@article{ba2016layer,
|
25 |
+
title={Layer normalization},
|
26 |
+
author={Ba, Jimmy Lei and Kiros, Jamie Ryan and Hinton, Geoffrey E},
|
27 |
+
journal={arXiv preprint arXiv:1607.06450},
|
28 |
+
year={2016}
|
29 |
+
}
|
30 |
+
|
31 |
+
% Batch Normalization
|
32 |
+
@inproceedings{ioffe2015batch,
|
33 |
+
title={Batch normalization: Accelerating deep network training by reducing internal covariate shift},
|
34 |
+
author={Ioffe, Sergey and Szegedy, Christian},
|
35 |
+
booktitle={International conference on machine learning},
|
36 |
+
pages={448--456},
|
37 |
+
year={2015},
|
38 |
+
organization={PMLR}
|
39 |
+
}
|
40 |
+
|
41 |
+
% ReLU
|
42 |
+
@article{fukushima1975cognitron,
|
43 |
+
title={Cognitron: A self-organizing multilayered neural network},
|
44 |
+
author={Fukushima, Kunihiko},
|
45 |
+
journal={Biological cybernetics},
|
46 |
+
volume={20},
|
47 |
+
number={3},
|
48 |
+
pages={121--136},
|
49 |
+
year={1975},
|
50 |
+
publisher={Springer}
|
51 |
+
}
|
52 |
+
|
53 |
+
% Weight Normalization
|
54 |
+
@article{salimans2016weight,
|
55 |
+
title={Weight normalization: A simple reparameterization to accelerate training of deep neural networks},
|
56 |
+
author={Salimans, Tim and Kingma, Durk P},
|
57 |
+
journal={Advances in neural information processing systems},
|
58 |
+
volume={29},
|
59 |
+
year={2016}
|
60 |
+
}
|
61 |
+
|
62 |
+
% Stochastic depth
|
63 |
+
@inproceedings{huang2016deep,
|
64 |
+
title={Deep networks with stochastic depth},
|
65 |
+
author={Huang, Gao and Sun, Yu and Liu, Zhuang and Sedra, Daniel and Weinberger, Kilian Q},
|
66 |
+
booktitle={European conference on computer vision},
|
67 |
+
pages={646--661},
|
68 |
+
year={2016},
|
69 |
+
organization={Springer}
|
70 |
+
}
|
71 |
+
|
72 |
+
% Stereo Matching Algorithm
|
73 |
+
@article{zhong2020displacement,
|
74 |
+
title={Displacement-invariant cost computation for efficient stereo matching},
|
75 |
+
author={Zhong, Yiran and Loop, Charles and Byeon, Wonmin and Birchfield, Stan and Dai, Yuchao and Zhang, Kaihao and Kamenev, Alexey and Breuel, Thomas and Li, Hongdong and Kautz, Jan},
|
76 |
+
journal={arXiv preprint arXiv:2012.00899},
|
77 |
+
year={2020}
|
78 |
+
}
|
79 |
+
|
80 |
+
% wandb
|
81 |
+
@misc{wandb,
|
82 |
+
title = {Experiment Tracking with Weights and Biases},
|
83 |
+
year = {2020},
|
84 |
+
note = {Software available from wandb.com},
|
85 |
+
url={https://www.wandb.com/},
|
86 |
+
author = {Biewald, Lukas},
|
87 |
+
}
|
88 |
+
|
89 |
+
%
|
90 |
+
@article{borji2015salient,
|
91 |
+
title={Salient object detection: A benchmark},
|
92 |
+
author={Borji, Ali and Cheng, Ming-Ming and Jiang, Huaizu and Li, Jia},
|
93 |
+
journal={IEEE transactions on image processing},
|
94 |
+
volume={24},
|
95 |
+
number={12},
|
96 |
+
pages={5706--5722},
|
97 |
+
year={2015},
|
98 |
+
publisher={IEEE}
|
99 |
+
}
|
100 |
+
|
101 |
+
% SOD metrics
|
102 |
+
@misc{sodmetrics,
|
103 |
+
title = {PySODMetrics: A simple and efficient implementation of SOD metrics},
|
104 |
+
howpublished = {\url{https://github.com/lartpang/PySODMetrics}},
|
105 |
+
note = {Accessed: 2022-10-31}
|
106 |
+
}
|
107 |
+
|
108 |
+
% MAE
|
109 |
+
@inproceedings{perazzi2012saliency,
|
110 |
+
title={Saliency filters: Contrast based filtering for salient region detection},
|
111 |
+
author={Perazzi, Federico and Kr{\"a}henb{\"u}hl, Philipp and Pritch, Yael and Hornung, Alexander},
|
112 |
+
booktitle={2012 IEEE conference on computer vision and pattern recognition},
|
113 |
+
pages={733--740},
|
114 |
+
year={2012},
|
115 |
+
organization={IEEE}
|
116 |
+
}
|
117 |
+
|
118 |
+
% F-measure
|
119 |
+
@inproceedings{achanta2009frequency,
|
120 |
+
title={Frequency-tuned salient region detection},
|
121 |
+
author={Achanta, Radhakrishna and Hemami, Sheila and Estrada, Francisco and Susstrunk, Sabine},
|
122 |
+
booktitle={2009 IEEE conference on computer vision and pattern recognition},
|
123 |
+
pages={1597--1604},
|
124 |
+
year={2009},
|
125 |
+
organization={IEEE}
|
126 |
+
}
|
127 |
+
|
128 |
+
% E-measure
|
129 |
+
@article{fan2018enhanced,
|
130 |
+
title={Enhanced-alignment measure for binary foreground map evaluation},
|
131 |
+
author={Fan, Deng-Ping and Gong, Cheng and Cao, Yang and Ren, Bo and Cheng, Ming-Ming and Borji, Ali},
|
132 |
+
journal={arXiv preprint arXiv:1805.10421},
|
133 |
+
year={2018}
|
134 |
+
}
|
135 |
+
|
136 |
+
% S-measure
|
137 |
+
@inproceedings{fan2017structure,
|
138 |
+
title={Structure-measure: A new way to evaluate foreground maps},
|
139 |
+
author={Fan, Deng-Ping and Cheng, Ming-Ming and Liu, Yun and Li, Tao and Borji, Ali},
|
140 |
+
booktitle={Proceedings of the IEEE international conference on computer vision},
|
141 |
+
pages={4548--4557},
|
142 |
+
year={2017}
|
143 |
+
}
|
144 |
+
|
145 |
+
% GELU
|
146 |
+
@article{hendrycks2016gaussian,
|
147 |
+
title={Gaussian error linear units (gelus)},
|
148 |
+
author={Hendrycks, Dan and Gimpel, Kevin},
|
149 |
+
journal={arXiv preprint arXiv:1606.08415},
|
150 |
+
year={2016}
|
151 |
+
}
|
152 |
+
|
153 |
+
% Instance normalization
|
154 |
+
@article{ulyanov2016instance,
|
155 |
+
title={Instance normalization: The missing ingredient for fast stylization},
|
156 |
+
author={Ulyanov, Dmitry and Vedaldi, Andrea and Lempitsky, Victor},
|
157 |
+
journal={arXiv preprint arXiv:1607.08022},
|
158 |
+
year={2016}
|
159 |
+
}
|
160 |
+
|
161 |
+
% Group normalization
|
162 |
+
@inproceedings{wu2018group,
|
163 |
+
title={Group normalization},
|
164 |
+
author={Wu, Yuxin and He, Kaiming},
|
165 |
+
booktitle={Proceedings of the European conference on computer vision (ECCV)},
|
166 |
+
pages={3--19},
|
167 |
+
year={2018}
|
168 |
+
}
|
169 |
+
|
170 |
+
% timm
|
171 |
+
@misc{rw2019timm,
|
172 |
+
author = {Ross Wightman},
|
173 |
+
title = {PyTorch Image Models},
|
174 |
+
year = {2019},
|
175 |
+
publisher = {GitHub},
|
176 |
+
journal = {GitHub repository},
|
177 |
+
doi = {10.5281/zenodo.4414861},
|
178 |
+
howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
|
179 |
+
}
|
180 |
+
|
181 |
+
% taskonomy
|
182 |
+
@inproceedings{zamir2018taskonomy,
|
183 |
+
title={Taskonomy: Disentangling task transfer learning},
|
184 |
+
author={Zamir, Amir R and Sax, Alexander and Shen, William and Guibas, Leonidas J and Malik, Jitendra and Savarese, Silvio},
|
185 |
+
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
|
186 |
+
pages={3712--3722},
|
187 |
+
year={2018}
|
188 |
+
}
|
189 |
+
|
190 |
+
@Comment{jabref-meta: databaseType:bibtex;}
|
docs/references/SOTAs.bib
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
% Encoding: UTF-8
|
2 |
+
|
3 |
+
% COME15K CMINet
|
4 |
+
@Article{cascaded_rgbd_sod,
|
5 |
+
title={RGB-D Saliency Detection via Cascaded Mutual Information Minimization},
|
6 |
+
author={Zhang, Jing and Fan, Deng-Ping and Dai, Yuchao and Yu, Xin and Zhong, Yiran and Barnes, Nick and Shao, Ling},
|
7 |
+
booktitle={International Conference on Computer Vision (ICCV)},
|
8 |
+
year={2021}
|
9 |
+
}
|
10 |
+
|
11 |
+
% A2dele
|
12 |
+
@Article{piao2020a2dele,
|
13 |
+
title={A2dele: Adaptive and attentive depth distiller for efficient RGB-D salient object detection},
|
14 |
+
author={Piao, Yongri and Rong, Zhengkun and Zhang, Miao and Ren, Weisong and Lu, Huchuan},
|
15 |
+
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
|
16 |
+
pages={9060--9069},
|
17 |
+
year={2020}
|
18 |
+
}
|
19 |
+
|
20 |
+
% BBS-Net
|
21 |
+
@Article{fan2020bbs,
|
22 |
+
title={BBS-Net: RGB-D salient object detection with a bifurcated backbone strategy network},
|
23 |
+
author={Fan, Deng-Ping and Zhai, Yingjie and Borji, Ali and Yang, Jufeng and Shao, Ling},
|
24 |
+
booktitle={European conference on computer vision},
|
25 |
+
pages={275--292},
|
26 |
+
year={2020},
|
27 |
+
organization={Springer}
|
28 |
+
}
|
29 |
+
|
30 |
+
% MobileSal
|
31 |
+
@article{wu2021mobilesal,
|
32 |
+
title={MobileSal: Extremely efficient RGB-D salient object detection},
|
33 |
+
author={Wu, Yu-Huan and Liu, Yun and Xu, Jun and Bian, Jia-Wang and Gu, Yu-Chao and Cheng, Ming-Ming},
|
34 |
+
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
|
35 |
+
year={2021},
|
36 |
+
publisher={IEEE}
|
37 |
+
}
|
38 |
+
|
39 |
+
% ATSA
|
40 |
+
@Article{zhang2020asymmetric,
|
41 |
+
title={Asymmetric two-stream architecture for accurate RGB-D saliency detection},
|
42 |
+
author={Zhang, Miao and Fei, Sun Xiao and Liu, Jie and Xu, Shuang and Piao, Yongri and Lu, Huchuan},
|
43 |
+
booktitle={European Conference on Computer Vision},
|
44 |
+
pages={374--390},
|
45 |
+
year={2020},
|
46 |
+
organization={Springer}
|
47 |
+
}
|
48 |
+
|
49 |
+
% CDNet
|
50 |
+
@article{jin2021cdnet,
|
51 |
+
title={CDNet: Complementary depth network for RGB-D salient object detection},
|
52 |
+
author={Jin, Wen-Da and Xu, Jun and Han, Qi and Zhang, Yi and Cheng, Ming-Ming},
|
53 |
+
journal={IEEE Transactions on Image Processing},
|
54 |
+
volume={30},
|
55 |
+
pages={3376--3390},
|
56 |
+
year={2021},
|
57 |
+
publisher={IEEE}
|
58 |
+
}
|
59 |
+
|
60 |
+
% CoNet
|
61 |
+
@Article{ji2020accurate,
|
62 |
+
title={Accurate RGB-D salient object detection via collaborative learning},
|
63 |
+
author={Ji, Wei and Li, Jingjing and Zhang, Miao and Piao, Yongri and Lu, Huchuan},
|
64 |
+
booktitle={European Conference on Computer Vision},
|
65 |
+
pages={52--69},
|
66 |
+
year={2020},
|
67 |
+
organization={Springer}
|
68 |
+
}
|
69 |
+
|
70 |
+
% SPNet
|
71 |
+
@inproceedings{zhou2021specificity,
|
72 |
+
title={Specificity-preserving rgb-d saliency detection},
|
73 |
+
author={Zhou, Tao and Fu, Huazhu and Chen, Geng and Zhou, Yi and Fan, Deng-Ping and Shao, Ling},
|
74 |
+
booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
|
75 |
+
pages={4681--4691},
|
76 |
+
year={2021}
|
77 |
+
}
|
78 |
+
|
79 |
+
% C2DFNet
|
80 |
+
@article{zhang2022c,
|
81 |
+
title={C2DFNet: Criss-Cross Dynamic Filter Network for RGB-D Salient Object Detection},
|
82 |
+
author={Zhang, Miao and Yao, Shunyu and Hu, Beiqi and Piao, Yongri and Ji, Wei},
|
83 |
+
journal={IEEE Transactions on Multimedia},
|
84 |
+
year={2022},
|
85 |
+
publisher={IEEE}
|
86 |
+
}
|
87 |
+
|
88 |
+
% SPSN
|
89 |
+
@inproceedings{lee2022spsn,
|
90 |
+
title={SPSN: Superpixel Prototype Sampling Network for RGB-D Salient Object Detection},
|
91 |
+
author={Lee, Minhyeok and Park, Chaewon and Cho, Suhwan and Lee, Sangyoun},
|
92 |
+
booktitle={European Conference on Computer Vision},
|
93 |
+
pages={630--647},
|
94 |
+
year={2022},
|
95 |
+
organization={Springer}
|
96 |
+
}
|
97 |
+
|
98 |
+
% ConvNeXt
|
99 |
+
@inproceedings{liu2022convnet,
|
100 |
+
title={A convnet for the 2020s},
|
101 |
+
author={Liu, Zhuang and Mao, Hanzi and Wu, Chao-Yuan and Feichtenhofer, Christoph and Darrell, Trevor and Xie, Saining},
|
102 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
103 |
+
pages={11976--11986},
|
104 |
+
year={2022}
|
105 |
+
}
|
106 |
+
|
107 |
+
% GPT-2
|
108 |
+
@article{radford2019language,
|
109 |
+
title={Language models are unsupervised multitask learners},
|
110 |
+
author={Radford, Alec and Wu, Jeffrey and Child, Rewon and Luan, David and Amodei, Dario and Sutskever, Ilya and others},
|
111 |
+
journal={OpenAI blog},
|
112 |
+
volume={1},
|
113 |
+
number={8},
|
114 |
+
pages={9},
|
115 |
+
year={2019}
|
116 |
+
}
|
117 |
+
|
118 |
+
% BERT
|
119 |
+
@article{devlin2018bert,
|
120 |
+
title={Bert: Pre-training of deep bidirectional transformers for language understanding},
|
121 |
+
author={Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina},
|
122 |
+
journal={arXiv preprint arXiv:1810.04805},
|
123 |
+
year={2018}
|
124 |
+
}
|
125 |
+
|
126 |
+
% UNet
|
127 |
+
@inproceedings{ronneberger2015u,
|
128 |
+
title={U-net: Convolutional networks for biomedical image segmentation},
|
129 |
+
author={Ronneberger, Olaf and Fischer, Philipp and Brox, Thomas},
|
130 |
+
booktitle={International Conference on Medical image computing and computer-assisted intervention},
|
131 |
+
pages={234--241},
|
132 |
+
year={2015},
|
133 |
+
organization={Springer}
|
134 |
+
}
|
135 |
+
|
136 |
+
% MobileNetV2
|
137 |
+
@inproceedings{sandler2018mobilenetv2,
|
138 |
+
title={Mobilenetv2: Inverted residuals and linear bottlenecks},
|
139 |
+
author={Sandler, Mark and Howard, Andrew and Zhu, Menglong and Zhmoginov, Andrey and Chen, Liang-Chieh},
|
140 |
+
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
|
141 |
+
pages={4510--4520},
|
142 |
+
year={2018}
|
143 |
+
}
|
144 |
+
|
145 |
+
% Xception
|
146 |
+
@inproceedings{chollet2017xception,
|
147 |
+
title={Xception: Deep learning with depthwise separable convolutions},
|
148 |
+
author={Chollet, Fran{\c{c}}ois},
|
149 |
+
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
|
150 |
+
pages={1251--1258},
|
151 |
+
year={2017}
|
152 |
+
}
|
153 |
+
|
154 |
+
% MobileNets
|
155 |
+
@article{howard2017mobilenets,
|
156 |
+
title={Mobilenets: Efficient convolutional neural networks for mobile vision applications},
|
157 |
+
author={Howard, Andrew G and Zhu, Menglong and Chen, Bo and Kalenichenko, Dmitry and Wang, Weijun and Weyand, Tobias and Andreetto, Marco and Adam, Hartwig},
|
158 |
+
journal={arXiv preprint arXiv:1704.04861},
|
159 |
+
year={2017}
|
160 |
+
}
|
161 |
+
|
162 |
+
% ResNet
|
163 |
+
@inproceedings{he2016deep,
|
164 |
+
title={Deep residual learning for image recognition},
|
165 |
+
author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian},
|
166 |
+
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
|
167 |
+
pages={770--778},
|
168 |
+
year={2016}
|
169 |
+
}
|
170 |
+
|
171 |
+
% ResNeXt
|
172 |
+
@inproceedings{xie2017aggregated,
|
173 |
+
title={Aggregated residual transformations for deep neural networks},
|
174 |
+
author={Xie, Saining and Girshick, Ross and Doll{\'a}r, Piotr and Tu, Zhuowen and He, Kaiming},
|
175 |
+
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
|
176 |
+
pages={1492--1500},
|
177 |
+
year={2017}
|
178 |
+
}
|
179 |
+
|
180 |
+
% MultiMAE
|
181 |
+
@article{bachmann2022multimae,
|
182 |
+
title={MultiMAE: Multi-modal Multi-task Masked Autoencoders},
|
183 |
+
author={Bachmann, Roman and Mizrahi, David and Atanov, Andrei and Zamir, Amir},
|
184 |
+
journal={arXiv preprint arXiv:2204.01678},
|
185 |
+
year={2022}
|
186 |
+
}
|
187 |
+
|
188 |
+
% MAE
|
189 |
+
@inproceedings{he2022masked,
|
190 |
+
title={Masked autoencoders are scalable vision learners},
|
191 |
+
author={He, Kaiming and Chen, Xinlei and Xie, Saining and Li, Yanghao and Doll{\'a}r, Piotr and Girshick, Ross},
|
192 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
193 |
+
pages={16000--16009},
|
194 |
+
year={2022}
|
195 |
+
}
|
196 |
+
|
197 |
+
% VisionTransformer, ViT
|
198 |
+
@article{dosovitskiy2020image,
|
199 |
+
title={An image is worth 16x16 words: Transformers for image recognition at scale},
|
200 |
+
author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},
|
201 |
+
journal={arXiv preprint arXiv:2010.11929},
|
202 |
+
year={2020}
|
203 |
+
}
|
204 |
+
|
205 |
+
% DANet
|
206 |
+
@Article{zhao2020single,
|
207 |
+
title={A single stream network for robust and real-time RGB-D salient object detection},
|
208 |
+
author={Zhao, Xiaoqi and Zhang, Lihe and Pang, Youwei and Lu, Huchuan and Zhang, Lei},
|
209 |
+
booktitle={European Conference on Computer Vision},
|
210 |
+
pages={646--662},
|
211 |
+
year={2020},
|
212 |
+
organization={Springer}
|
213 |
+
}
|
214 |
+
|
215 |
+
% DCF
|
216 |
+
@Article{Ji_2021_DCF,
|
217 |
+
author = {Ji, Wei and Li, Jingjing and Yu, Shuang and Zhang, Miao and Piao, Yongri and Yao, Shunyu and Bi, Qi and Ma, Kai and Zheng, Yefeng and Lu, Huchuan and Cheng, Li},
|
218 |
+
title = {Calibrated RGB-D Salient Object Detection},
|
219 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
220 |
+
year = {2021},
|
221 |
+
pages = {9471-9481}
|
222 |
+
}
|
223 |
+
|
224 |
+
% MVSalNet
|
225 |
+
@inproceedings{zhou2022mvsalnet,
|
226 |
+
title={MVSalNet: Multi-view Augmentation for RGB-D Salient Object Detection},
|
227 |
+
author={Zhou, Jiayuan and Wang, Lijun and Lu, Huchuan and Huang, Kaining and Shi, Xinchu and Liu, Bocong},
|
228 |
+
booktitle={European Conference on Computer Vision},
|
229 |
+
pages={270--287},
|
230 |
+
year={2022},
|
231 |
+
organization={Springer}
|
232 |
+
}
|
233 |
+
|
234 |
+
% DSA2F
|
235 |
+
@Article{Sun2021DeepRS,
|
236 |
+
title={Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion},
|
237 |
+
author={P. Sun and Wenhu Zhang and Huanyu Wang and Songyuan Li and Xi Li},
|
238 |
+
journal={IEEE Conf. Comput. Vis. Pattern Recog.},
|
239 |
+
year={2021}
|
240 |
+
}
|
241 |
+
|
242 |
+
% FRDT
|
243 |
+
@Article{zhang2020feature,
|
244 |
+
title={Feature reintegration over differential treatment: A top-down and adaptive fusion network for RGB-D salient object detection},
|
245 |
+
author={Zhang, Miao and Zhang, Yu and Piao, Yongri and Hu, Beiqi and Lu, Huchuan},
|
246 |
+
booktitle={Proceedings of the 28th ACM international conference on multimedia},
|
247 |
+
pages={4107--4115},
|
248 |
+
year={2020}
|
249 |
+
}
|
250 |
+
|
251 |
+
% HAINet
|
252 |
+
@article{li2021hierarchical,
|
253 |
+
title={Hierarchical alternate interaction network for RGB-D salient object detection},
|
254 |
+
author={Li, Gongyang and Liu, Zhi and Chen, Minyu and Bai, Zhen and Lin, Weisi and Ling, Haibin},
|
255 |
+
journal={IEEE Transactions on Image Processing},
|
256 |
+
volume={30},
|
257 |
+
pages={3528--3542},
|
258 |
+
year={2021},
|
259 |
+
publisher={IEEE}
|
260 |
+
}
|
261 |
+
|
262 |
+
% JLDCF
|
263 |
+
@Article{fu2020jl,
|
264 |
+
title={JL-DCF: Joint learning and densely-cooperative fusion framework for RGB-D salient object detection},
|
265 |
+
author={Fu, Keren and Fan, Deng-Ping and Ji, Ge-Peng and Zhao, Qijun},
|
266 |
+
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
|
267 |
+
pages={3052--3062},
|
268 |
+
year={2020}
|
269 |
+
}
|
270 |
+
|
271 |
+
% SSLSOD
|
272 |
+
@inproceedings{zhao2022self,
|
273 |
+
title={Self-supervised pretraining for rgb-d salient object detection},
|
274 |
+
author={Zhao, Xiaoqi and Pang, Youwei and Zhang, Lihe and Lu, Huchuan and Ruan, Xiang},
|
275 |
+
booktitle={AAAI Conference on Artificial Intelligence},
|
276 |
+
volume={3},
|
277 |
+
year={2022}
|
278 |
+
}
|
279 |
+
|
280 |
+
% DFTR
|
281 |
+
@article{zhudftr,
|
282 |
+
title={DFTR: Depth-supervised Fusion Transformer for Salient Object Detection},
|
283 |
+
author={Zhu, Heqin and Sun, Xu and Li, Yuexiang and Ma, Kai and Zhou, S Kevin and Zheng, Yefeng}
|
284 |
+
}
|
285 |
+
|
286 |
+
% PGAR
|
287 |
+
@Article{chen2020progressively,
|
288 |
+
title={Progressively guided alternate refinement network for RGB-D salient object detection},
|
289 |
+
author={Chen, Shuhan and Fu, Yun},
|
290 |
+
booktitle={European Conference on Computer Vision},
|
291 |
+
pages={520--538},
|
292 |
+
year={2020},
|
293 |
+
organization={Springer}
|
294 |
+
}
|
295 |
+
|
296 |
+
% DCMF
|
297 |
+
@article{wang2022learning,
|
298 |
+
title={Learning Discriminative Cross-Modality Features for RGB-D Saliency Detection},
|
299 |
+
author={Wang, Fengyun and Pan, Jinshan and Xu, Shoukun and Tang, Jinhui},
|
300 |
+
journal={IEEE Transactions on Image Processing},
|
301 |
+
volume={31},
|
302 |
+
pages={1285--1297},
|
303 |
+
year={2022},
|
304 |
+
publisher={IEEE}
|
305 |
+
}
|
306 |
+
|
307 |
+
% RD3D
|
308 |
+
@Article{chen2021rgb,
|
309 |
+
title={RGB-D salient object detection via 3D convolutional neural networks},
|
310 |
+
author={Chen, Qian and Liu, Ze and Zhang, Yi and Fu, Keren and Zhao, Qijun and Du, Hongwei},
|
311 |
+
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
|
312 |
+
volume={35},
|
313 |
+
number={2},
|
314 |
+
pages={1063--1071},
|
315 |
+
year={2021}
|
316 |
+
}
|
317 |
+
|
318 |
+
% ReDWeb-S
|
319 |
+
% S2MA
|
320 |
+
@Article{liu2020learning,
|
321 |
+
title={Learning selective self-mutual attention for RGB-D saliency detection},
|
322 |
+
author={Liu, Nian and Zhang, Ni and Han, Junwei},
|
323 |
+
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
|
324 |
+
pages={13756--13765},
|
325 |
+
year={2020}
|
326 |
+
}
|
327 |
+
|
328 |
+
% SSF
|
329 |
+
@Article{zhang2020select,
|
330 |
+
title={Select, supplement and focus for RGB-D saliency detection},
|
331 |
+
author={Zhang, Miao and Ren, Weisong and Piao, Yongri and Rong, Zhengkun and Lu, Huchuan},
|
332 |
+
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
|
333 |
+
pages={3472--3481},
|
334 |
+
year={2020}
|
335 |
+
}
|
336 |
+
|
337 |
+
% UCNet
|
338 |
+
@Article{zhang2020uc,
|
339 |
+
title={UC-Net: Uncertainty inspired RGB-D saliency detection via conditional variational autoencoders},
|
340 |
+
author={Zhang, Jing and Fan, Deng-Ping and Dai, Yuchao and Anwar, Saeed and Saleh, Fatemeh Sadat and Zhang, Tong and Barnes, Nick},
|
341 |
+
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
|
342 |
+
pages={8582--8591},
|
343 |
+
year={2020}
|
344 |
+
}
|
345 |
+
|
346 |
+
% TriTransNet
|
347 |
+
@inproceedings{liu2021tritransnet,
|
348 |
+
title={TriTransNet: RGB-D salient object detection with a triplet transformer embedding network},
|
349 |
+
author={Liu, Zhengyi and Wang, Yuan and Tu, Zhengzheng and Xiao, Yun and Tang, Bin},
|
350 |
+
booktitle={Proceedings of the 29th ACM international conference on multimedia},
|
351 |
+
pages={4481--4490},
|
352 |
+
year={2021}
|
353 |
+
}
|
354 |
+
|
355 |
+
@Comment{jabref-meta: databaseType:bibtex;}
|
requirements-lock.txt
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiohttp==3.9.3
|
2 |
+
aiosignal==1.3.1
|
3 |
+
albumentations==1.4.3
|
4 |
+
altair==5.3.0
|
5 |
+
async-timeout==4.0.3
|
6 |
+
attrs==23.2.0
|
7 |
+
black==24.3.0
|
8 |
+
blinker==1.7.0
|
9 |
+
cachetools==5.3.3
|
10 |
+
certifi==2024.2.2
|
11 |
+
charset-normalizer==3.3.2
|
12 |
+
click==8.1.7
|
13 |
+
contourpy==1.2.1
|
14 |
+
cycler==0.12.1
|
15 |
+
docstring_parser==0.16
|
16 |
+
einops==0.7.0
|
17 |
+
filelock==3.13.3
|
18 |
+
fonttools==4.51.0
|
19 |
+
frozenlist==1.4.1
|
20 |
+
fsspec==2024.3.1
|
21 |
+
gitdb==4.0.11
|
22 |
+
GitPython==3.1.43
|
23 |
+
huggingface-hub==0.22.2
|
24 |
+
idna==3.6
|
25 |
+
imageio==2.34.0
|
26 |
+
Jinja2==3.1.3
|
27 |
+
joblib==1.3.2
|
28 |
+
jsonschema==4.21.1
|
29 |
+
jsonschema-specifications==2023.12.1
|
30 |
+
kiwisolver==1.4.5
|
31 |
+
lazy_loader==0.4
|
32 |
+
lightning-utilities==0.11.2
|
33 |
+
markdown-it-py==3.0.0
|
34 |
+
MarkupSafe==2.1.5
|
35 |
+
matplotlib==3.8.4
|
36 |
+
mdurl==0.1.2
|
37 |
+
mpmath==1.3.0
|
38 |
+
multidict==6.0.5
|
39 |
+
mypy-extensions==1.0.0
|
40 |
+
networkx==3.3
|
41 |
+
numpy==1.26.4
|
42 |
+
nvidia-cublas-cu12==12.1.3.1
|
43 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
44 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
45 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
46 |
+
nvidia-cudnn-cu12==8.9.2.26
|
47 |
+
nvidia-cufft-cu12==11.0.2.54
|
48 |
+
nvidia-curand-cu12==10.3.2.106
|
49 |
+
nvidia-cusolver-cu12==11.4.5.107
|
50 |
+
nvidia-cusparse-cu12==12.1.0.106
|
51 |
+
nvidia-nccl-cu12==2.18.1
|
52 |
+
nvidia-nvjitlink-cu12==12.4.127
|
53 |
+
nvidia-nvtx-cu12==12.1.105
|
54 |
+
opencv-python==4.9.0.80
|
55 |
+
opencv-python-headless==4.9.0.80
|
56 |
+
packaging==24.0
|
57 |
+
pandas==2.2.1
|
58 |
+
pathspec==0.12.1
|
59 |
+
pillow==10.3.0
|
60 |
+
platformdirs==4.2.0
|
61 |
+
protobuf==4.25.3
|
62 |
+
pyarrow==15.0.2
|
63 |
+
pycocotools==2.0.7
|
64 |
+
pydeck==0.8.1b0
|
65 |
+
Pygments==2.17.2
|
66 |
+
pyparsing==3.1.2
|
67 |
+
python-dateutil==2.9.0.post0
|
68 |
+
pytorch-lightning==2.2.1
|
69 |
+
pytz==2024.1
|
70 |
+
PyYAML==6.0.1
|
71 |
+
referencing==0.34.0
|
72 |
+
requests==2.31.0
|
73 |
+
rich==13.7.1
|
74 |
+
rpds-py==0.18.0
|
75 |
+
safetensors==0.4.2
|
76 |
+
scikit-image==0.22.0
|
77 |
+
scikit-learn==1.4.1.post1
|
78 |
+
scipy==1.13.0
|
79 |
+
six==1.16.0
|
80 |
+
smmap==5.0.1
|
81 |
+
streamlit==1.33.0
|
82 |
+
sympy==1.12
|
83 |
+
tenacity==8.2.3
|
84 |
+
termcolor==2.4.0
|
85 |
+
threadpoolctl==3.4.0
|
86 |
+
tifffile==2024.2.12
|
87 |
+
timm==0.9.16
|
88 |
+
toml==0.10.2
|
89 |
+
tomli==2.0.1
|
90 |
+
toolz==0.12.1
|
91 |
+
torch==2.1.0
|
92 |
+
torchmetrics==1.3.2
|
93 |
+
torchvision==0.16.0
|
94 |
+
tornado==6.4
|
95 |
+
tqdm==4.66.2
|
96 |
+
triton==2.1.0
|
97 |
+
typed-argument-parser==1.9.0
|
98 |
+
typing-inspect==0.9.0
|
99 |
+
typing_extensions==4.11.0
|
100 |
+
tzdata==2024.1
|
101 |
+
urllib3==2.2.1
|
102 |
+
watchdog==4.0.0
|
103 |
+
yarl==1.9.4
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.1.0
|
2 |
+
torchvision
|
3 |
+
opencv-python
|
4 |
+
pycocotools
|
5 |
+
matplotlib
|
6 |
+
Pillow
|
7 |
+
numpy
|
8 |
+
einops
|
9 |
+
timm
|
10 |
+
albumentations
|
11 |
+
termcolor
|
12 |
+
tqdm
|
13 |
+
pandas
|
14 |
+
typed-argument-parser
|
15 |
+
pytorch-lightning
|
16 |
+
streamlit
|
17 |
+
black
|
18 |
+
huggingface-hub
|
s_multimae/__init__.py
ADDED
File without changes
|
s_multimae/configs/__init__.py
ADDED
File without changes
|
s_multimae/configs/base_config.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import os
|
3 |
+
from typing import Dict, Optional, Tuple, List
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
import math
|
7 |
+
import albumentations as A
|
8 |
+
|
9 |
+
from definition import PRETRAINED_BACKBONE
|
10 |
+
from .data_augmentation_config import DataAugmentationConfig
|
11 |
+
|
12 |
+
|
13 |
+
class base_cfg:
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
epoch: int,
|
17 |
+
datasets_set: int,
|
18 |
+
experiment_name: Optional[str] = None,
|
19 |
+
):
|
20 |
+
self.experiment_name = experiment_name = (
|
21 |
+
self.__class__.__name__ if experiment_name is None else experiment_name
|
22 |
+
)
|
23 |
+
self.datasets_set = datasets_set
|
24 |
+
|
25 |
+
# Trainv3
|
26 |
+
self.devices: List[int] = [0, 1]
|
27 |
+
# How often to check the validation set. Pass a float in the range [0.0, 1.0] to check
|
28 |
+
self.val_check_interval: float = 1.0
|
29 |
+
|
30 |
+
# Perform a validation loop every after every N training epochs.
|
31 |
+
self.check_val_every_n_epoch: int = 2
|
32 |
+
|
33 |
+
self.precision = 16
|
34 |
+
self.transform1 = [A.HorizontalFlip(p=0.5)]
|
35 |
+
|
36 |
+
self.save_top_k = 2
|
37 |
+
|
38 |
+
# ConvNeXtAdapter
|
39 |
+
self.dec_kernel = 1 # decoder kernel size
|
40 |
+
|
41 |
+
# Version 1: as usual
|
42 |
+
# Version 2: mean, std
|
43 |
+
self.model_version = 1
|
44 |
+
|
45 |
+
self.visualized_num_dev_samples = 0
|
46 |
+
|
47 |
+
# PytorchLightning Trainer
|
48 |
+
self.sync_batchnorm = True
|
49 |
+
|
50 |
+
self.normalized_depth: bool = True
|
51 |
+
|
52 |
+
self.test_image_size: int = 224
|
53 |
+
self.image_size: int = 224
|
54 |
+
|
55 |
+
"""Whether using fp16 instead of fp32 (default)"""
|
56 |
+
self.is_fp16: bool = True
|
57 |
+
|
58 |
+
self.is_padding: bool = (
|
59 |
+
False # deprecated due to randomly switch between padding and non-padding
|
60 |
+
)
|
61 |
+
|
62 |
+
# """For debug only"""
|
63 |
+
# self.max_train_samples: Optional[int] = None
|
64 |
+
# self.max_dev_samples: Optional[int] = None
|
65 |
+
|
66 |
+
"""Whether using padding for test"""
|
67 |
+
self.is_padding_for_test: bool = False
|
68 |
+
|
69 |
+
"""Seed"""
|
70 |
+
self.seed: int = 2022
|
71 |
+
|
72 |
+
""" MultiMAE """
|
73 |
+
self.decoder_depth: int = 4
|
74 |
+
self.encoder_depth: int = 12
|
75 |
+
self.is_inference_with_no_depth: bool = False
|
76 |
+
self.inputs = ["rgb", "depth"]
|
77 |
+
self.outputs = ["sod"]
|
78 |
+
self.decoder_main_tasks: List[List[str]] = [["rgb"]]
|
79 |
+
self.learnable_pos_emb: bool = False
|
80 |
+
self.learnable_additional_gt_tokens: bool = False
|
81 |
+
self.decoder_interpolate_mode: str = "bilinear" # ['bilinear', 'nearest']
|
82 |
+
self.dim_tokens: int = 768
|
83 |
+
self.act_fn = partial(nn.ReLU, inplace=True)
|
84 |
+
self.num_heads: int = 12
|
85 |
+
self.freeze_encoder: bool = False
|
86 |
+
|
87 |
+
"""Data Augmentation"""
|
88 |
+
self.data_augmentation_version: int = 2
|
89 |
+
self.data_augmentation_config = DataAugmentationConfig()
|
90 |
+
|
91 |
+
self.ckpt_path: Optional[str] = None
|
92 |
+
self.description: str = "" # Override this
|
93 |
+
self.embed_dim: int = 6144
|
94 |
+
|
95 |
+
"""Pretrained Backbone"""
|
96 |
+
self.pretrained_backbone: Optional[PRETRAINED_BACKBONE] = (
|
97 |
+
PRETRAINED_BACKBONE.MULTIMAE
|
98 |
+
)
|
99 |
+
|
100 |
+
"""
|
101 |
+
Required only when self.pretrained_backbone in [PRETRAINED_BACKBONE.S_MULTIMAE, PRETRAINED_BACKBONE.LARGE_S_MULTIMAE].
|
102 |
+
Example: 'v1.0.4_e499' stands for version 1.0.4, epoch 499, trained 500 epochs
|
103 |
+
"""
|
104 |
+
self.pretrained_backbone_version: Optional[str] = None
|
105 |
+
|
106 |
+
"""Ground truth
|
107 |
+
V1: 1 head, each head has 1 class, BCE
|
108 |
+
V2: 1 head, each head has 5 classes, CE
|
109 |
+
V3: 5 heads, each head has 1 class, BCE
|
110 |
+
V4: 1 head, each head has 5 classes, BCE
|
111 |
+
V5: additional global token indicates individual thinker
|
112 |
+
"""
|
113 |
+
self.ground_truth_version = 1
|
114 |
+
self.additional_gt_tokens_mlp_channels = []
|
115 |
+
self.num_classes = 1
|
116 |
+
self.actual_num_classes = 1
|
117 |
+
|
118 |
+
self.is_cache = False
|
119 |
+
|
120 |
+
"""Learning rate
|
121 |
+
LR strategy:
|
122 |
+
V1: The ratio of unpretrained and pretrained is also 1:lr_scale
|
123 |
+
V2: The ratio of unpretrained and pretrained is changed gradually from 1:lr_scale -> 1:1
|
124 |
+
"""
|
125 |
+
self.lr_strategy_version = 1
|
126 |
+
self.lr: float
|
127 |
+
self.end_lr: float = 1e-11
|
128 |
+
self.lr_scale: int
|
129 |
+
self.lr_power: float = 0.9
|
130 |
+
|
131 |
+
# Deprecated from v3
|
132 |
+
self.save_checkpoints_after_each_n_epochs: int = 10 # Not used in trainv3
|
133 |
+
|
134 |
+
self.weight_decay = 0.05
|
135 |
+
self.num_workers = 2
|
136 |
+
self.num_epochs_every_restart = 100
|
137 |
+
|
138 |
+
self.betas: Tuple[float, float] = (0.9, 0.999)
|
139 |
+
|
140 |
+
self.input_patch_size: int = 16
|
141 |
+
self.output_patch_size: int = 16 # must be a square of number
|
142 |
+
|
143 |
+
"""Warmup batchsize"""
|
144 |
+
self.warmup_min_batch_size: Optional[int] = None
|
145 |
+
self.warmup_epoch_batch_size: Optional[int] = None
|
146 |
+
|
147 |
+
self.batch_size: int
|
148 |
+
self.val_batch_size: int
|
149 |
+
self.test_batch_size: int = 100
|
150 |
+
self.nepochs: int
|
151 |
+
|
152 |
+
def todict(self):
|
153 |
+
d = dict()
|
154 |
+
for k, v in self.__dict__.items():
|
155 |
+
if not k.startswith("_"):
|
156 |
+
d[k] = v
|
157 |
+
return d
|
158 |
+
|
159 |
+
@property
|
160 |
+
def total_iters_per_epoch(self):
|
161 |
+
return math.ceil(
|
162 |
+
(self.num_training_samples_per_epoch)
|
163 |
+
/ (self.batch_size * len(self.devices))
|
164 |
+
)
|
s_multimae/configs/data_augmentation_config.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class RandomGaussianBlurConfig:
|
2 |
+
def __init__(self, p=0.5, max_gaussian_kernel=19) -> None:
|
3 |
+
self.p = p
|
4 |
+
self.max_gaussian_kernel = max_gaussian_kernel
|
5 |
+
|
6 |
+
|
7 |
+
class DataAugmentationConfig:
|
8 |
+
def __init__(self) -> None:
|
9 |
+
self.mean_normalization = [0.5, 0.5, 0.5]
|
10 |
+
self.std_normalization = [0.5, 0.5, 0.5]
|
11 |
+
self.image_gaussian_config = RandomGaussianBlurConfig(
|
12 |
+
p=0.5,
|
13 |
+
max_gaussian_kernel=19,
|
14 |
+
)
|
15 |
+
self.depth_gaussian_config = RandomGaussianBlurConfig(
|
16 |
+
p=0.5,
|
17 |
+
max_gaussian_kernel=36,
|
18 |
+
)
|
19 |
+
self.random_horizontal_flip_prob = 0.5
|
s_multimae/configs/experiment_config.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from typing import Dict, Optional, Type
|
3 |
+
|
4 |
+
from .base_config import base_cfg
|
5 |
+
import importlib, inspect, os
|
6 |
+
from glob import glob
|
7 |
+
|
8 |
+
arg_cfg: Dict[str, Type[base_cfg]] = dict()
|
9 |
+
|
10 |
+
modules = []
|
11 |
+
for p in glob("s_multimae/configs/experiment_configs/*.py"):
|
12 |
+
if not p.startswith("__"):
|
13 |
+
module_name = os.path.splitext(os.path.basename(p))[0]
|
14 |
+
modules.append(f"s_multimae.configs.experiment_configs.{module_name}")
|
15 |
+
|
16 |
+
for module in modules:
|
17 |
+
for name, cls in inspect.getmembers(
|
18 |
+
importlib.import_module(module), inspect.isclass
|
19 |
+
):
|
20 |
+
if name.startswith("cfg"):
|
21 |
+
arg_cfg[name] = cls
|
22 |
+
|
23 |
+
|
24 |
+
def get_config_by_set_version(set_version: int) -> base_cfg:
|
25 |
+
if set_version not in [1, 2, 3, 4]:
|
26 |
+
raise Exception(f"Unsupported set version {set_version}")
|
27 |
+
return arg_cfg[f"cfg_set_{set_version}"]()
|
28 |
+
|
29 |
+
|
30 |
+
def get_config(cfg_name: str, epoch: Optional[int] = None) -> base_cfg:
|
31 |
+
return arg_cfg[cfg_name](epoch)
|
s_multimae/configs/experiment_configs/__init__.py
ADDED
File without changes
|
s_multimae/configs/experiment_configs/expv1_dynamic.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import albumentations as A
|
2 |
+
import cv2
|
3 |
+
|
4 |
+
from typing import Optional
|
5 |
+
from definition import PRETRAINED_BACKBONE
|
6 |
+
from ..base_config import base_cfg
|
7 |
+
|
8 |
+
|
9 |
+
class cfgv4_0_2006(base_cfg):
|
10 |
+
def __init__(self, epoch: Optional[int] = None):
|
11 |
+
super().__init__(epoch, datasets_set=1)
|
12 |
+
|
13 |
+
self.check_val_every_n_epoch = 1
|
14 |
+
self.num_workers = 4
|
15 |
+
self.devices = [0]
|
16 |
+
|
17 |
+
self.description = "ViT Large [DoT]"
|
18 |
+
|
19 |
+
"""MultiMAE"""
|
20 |
+
self.pretrained_backbone = PRETRAINED_BACKBONE.LARGE_S_MULTIMAE
|
21 |
+
self.pretrained_backbone_version = "v2.0.5-pr"
|
22 |
+
|
23 |
+
# Large MAE
|
24 |
+
self.dim_tokens = 1024
|
25 |
+
self.encoder_depth = 24
|
26 |
+
self.num_heads = 16
|
27 |
+
|
28 |
+
self.clip_grad = None
|
29 |
+
self.normalized_depth = False
|
30 |
+
|
31 |
+
"""Decoders"""
|
32 |
+
self.decoder_main_tasks = [["rgb", "depth"]]
|
33 |
+
self.decoder_depth = 10
|
34 |
+
|
35 |
+
# ConvNeXtAdapter
|
36 |
+
self.dec_kernel = 3
|
37 |
+
|
38 |
+
# debug
|
39 |
+
# self.max_train_samples = 80
|
40 |
+
# self.max_dev_samples = 80
|
41 |
+
|
42 |
+
# Diversity of thought
|
43 |
+
self.ground_truth_version = 6
|
44 |
+
self.num_classes = 5 # ignored
|
45 |
+
self.actual_num_classes = 5 # ignored
|
46 |
+
self.additional_gt_tokens_mlp_channels = [768 * 2]
|
47 |
+
|
48 |
+
"""Learning rate"""
|
49 |
+
self.lr = 1e-5
|
50 |
+
self.end_lr = 1e-8
|
51 |
+
self.lr_scale = 100
|
52 |
+
|
53 |
+
self.batch_size = 20
|
54 |
+
self.val_batch_size = 200
|
55 |
+
self.nepochs = 400
|
56 |
+
self.num_epochs_every_restart = 100
|
57 |
+
|
58 |
+
self.data_augmentation_version = 6
|
59 |
+
self.train_function_version = 3
|
60 |
+
self.weight_decay = 5e-2
|
61 |
+
self.transform1 = [
|
62 |
+
A.HorizontalFlip(p=0.5),
|
63 |
+
]
|
64 |
+
|
65 |
+
|
66 |
+
class cfgv4_0_2007(base_cfg):
|
67 |
+
def __init__(self, epoch: Optional[int] = None):
|
68 |
+
super().__init__(epoch, datasets_set=1)
|
69 |
+
|
70 |
+
self.check_val_every_n_epoch = 1
|
71 |
+
self.num_workers = 4
|
72 |
+
self.devices = [0]
|
73 |
+
|
74 |
+
self.description = "ViT Base [DoT]"
|
75 |
+
|
76 |
+
"""MultiMAE"""
|
77 |
+
self.pretrained_backbone = PRETRAINED_BACKBONE.S_MULTIMAE
|
78 |
+
self.pretrained_backbone_version = "v2.0.1-pr"
|
79 |
+
|
80 |
+
self.clip_grad = None
|
81 |
+
self.normalized_depth = False
|
82 |
+
|
83 |
+
"""Decoders"""
|
84 |
+
self.decoder_main_tasks = [["rgb", "depth"]]
|
85 |
+
self.decoder_depth = 10
|
86 |
+
|
87 |
+
# ConvNeXtAdapter
|
88 |
+
self.dec_kernel = 3
|
89 |
+
|
90 |
+
# debug
|
91 |
+
# self.max_train_samples = 80
|
92 |
+
# self.max_dev_samples = 80
|
93 |
+
|
94 |
+
# Diversity of thought
|
95 |
+
self.ground_truth_version = 6
|
96 |
+
self.num_classes = 5 # ignored
|
97 |
+
self.actual_num_classes = 5 # ignored
|
98 |
+
self.additional_gt_tokens_mlp_channels = [768 * 2]
|
99 |
+
|
100 |
+
"""Learning rate"""
|
101 |
+
self.lr = 1e-5
|
102 |
+
self.end_lr = 1e-8
|
103 |
+
self.lr_scale = 100
|
104 |
+
|
105 |
+
self.batch_size = 40
|
106 |
+
self.val_batch_size = 200
|
107 |
+
self.nepochs = 400
|
108 |
+
self.num_epochs_every_restart = 100
|
109 |
+
|
110 |
+
self.data_augmentation_version = 6
|
111 |
+
self.train_function_version = 3
|
112 |
+
self.weight_decay = 5e-2
|
113 |
+
self.transform1 = [
|
114 |
+
A.HorizontalFlip(p=0.5),
|
115 |
+
A.OneOf(
|
116 |
+
[
|
117 |
+
A.Compose(
|
118 |
+
[
|
119 |
+
A.RandomCropFromBorders(
|
120 |
+
crop_left=0.3,
|
121 |
+
crop_right=0.3,
|
122 |
+
crop_top=0.3,
|
123 |
+
crop_bottom=0.3,
|
124 |
+
p=0.2,
|
125 |
+
),
|
126 |
+
A.ShiftScaleRotate(
|
127 |
+
shift_limit=0.0625,
|
128 |
+
scale_limit=0.1,
|
129 |
+
rotate_limit=45,
|
130 |
+
p=0.1,
|
131 |
+
),
|
132 |
+
A.Perspective(
|
133 |
+
p=0.2,
|
134 |
+
scale=(0.05, 0.1),
|
135 |
+
),
|
136 |
+
]
|
137 |
+
),
|
138 |
+
A.Compose(
|
139 |
+
[
|
140 |
+
A.RandomCropFromBorders(
|
141 |
+
crop_left=0.3,
|
142 |
+
crop_right=0.3,
|
143 |
+
crop_top=0.3,
|
144 |
+
crop_bottom=0.3,
|
145 |
+
p=0.2,
|
146 |
+
),
|
147 |
+
A.ShiftScaleRotate(
|
148 |
+
shift_limit=0.0625,
|
149 |
+
scale_limit=0.1,
|
150 |
+
rotate_limit=45,
|
151 |
+
p=0.1,
|
152 |
+
border_mode=cv2.BORDER_CONSTANT,
|
153 |
+
value=(255, 255, 255),
|
154 |
+
mask_value=0,
|
155 |
+
),
|
156 |
+
A.Perspective(
|
157 |
+
p=0.2,
|
158 |
+
scale=(0.05, 0.1),
|
159 |
+
pad_mode=cv2.BORDER_CONSTANT,
|
160 |
+
pad_val=(255, 255, 255),
|
161 |
+
mask_pad_val=0,
|
162 |
+
),
|
163 |
+
]
|
164 |
+
),
|
165 |
+
]
|
166 |
+
),
|
167 |
+
]
|
168 |
+
|
169 |
+
|
170 |
+
class cfgv4_0_2002(base_cfg):
|
171 |
+
def __init__(self, epoch: Optional[int] = None):
|
172 |
+
super().__init__(epoch, datasets_set=1)
|
173 |
+
|
174 |
+
self.check_val_every_n_epoch = 1
|
175 |
+
self.num_workers = 4
|
176 |
+
self.devices = [3]
|
177 |
+
|
178 |
+
# self.description = "Trainv3-DAv6-DiversityOfThought-NotMuchAug"
|
179 |
+
self.description = "DEBUG"
|
180 |
+
|
181 |
+
"""MultiMAE"""
|
182 |
+
self.pretrained_backbone = PRETRAINED_BACKBONE.S_MULTIMAE
|
183 |
+
self.pretrained_backbone_version = "v2.0.1-pr"
|
184 |
+
|
185 |
+
# Large MAE
|
186 |
+
# self.dim_tokens = 1024
|
187 |
+
# self.encoder_depth = 24
|
188 |
+
# self.num_heads = 16
|
189 |
+
|
190 |
+
self.clip_grad = None
|
191 |
+
self.normalized_depth = False
|
192 |
+
|
193 |
+
"""Decoders"""
|
194 |
+
self.decoder_main_tasks = [["rgb", "depth"]]
|
195 |
+
self.decoder_depth = 10
|
196 |
+
|
197 |
+
# ConvNeXtAdapter
|
198 |
+
self.dec_kernel = 3
|
199 |
+
|
200 |
+
# debug
|
201 |
+
self.max_train_samples = 20
|
202 |
+
self.max_dev_samples = 20
|
203 |
+
|
204 |
+
# Diversity of thought
|
205 |
+
self.ground_truth_version = 6
|
206 |
+
self.num_classes = 5 # ignored
|
207 |
+
self.actual_num_classes = 5 # ignored
|
208 |
+
self.additional_gt_tokens_mlp_channels = [768 * 2]
|
209 |
+
|
210 |
+
"""Learning rate"""
|
211 |
+
self.lr = 1e-5
|
212 |
+
self.end_lr = 1e-8
|
213 |
+
self.lr_scale = 100
|
214 |
+
|
215 |
+
self.batch_size = 5
|
216 |
+
self.val_batch_size = 5
|
217 |
+
self.nepochs = 400
|
218 |
+
self.num_epochs_every_restart = 100
|
219 |
+
|
220 |
+
self.data_augmentation_version = 6
|
221 |
+
self.train_function_version = 3
|
222 |
+
self.weight_decay = 5e-2
|
223 |
+
self.transform1 = [
|
224 |
+
A.HorizontalFlip(p=0.5),
|
225 |
+
A.OneOf(
|
226 |
+
[
|
227 |
+
A.Compose(
|
228 |
+
[
|
229 |
+
A.RandomCropFromBorders(
|
230 |
+
crop_left=0.3,
|
231 |
+
crop_right=0.3,
|
232 |
+
crop_top=0.3,
|
233 |
+
crop_bottom=0.3,
|
234 |
+
p=0.2,
|
235 |
+
),
|
236 |
+
A.ShiftScaleRotate(
|
237 |
+
shift_limit=0.0625,
|
238 |
+
scale_limit=0.1,
|
239 |
+
rotate_limit=45,
|
240 |
+
p=0.1,
|
241 |
+
),
|
242 |
+
A.Perspective(
|
243 |
+
p=0.2,
|
244 |
+
scale=(0.05, 0.1),
|
245 |
+
),
|
246 |
+
]
|
247 |
+
),
|
248 |
+
A.Compose(
|
249 |
+
[
|
250 |
+
A.RandomCropFromBorders(
|
251 |
+
crop_left=0.3,
|
252 |
+
crop_right=0.3,
|
253 |
+
crop_top=0.3,
|
254 |
+
crop_bottom=0.3,
|
255 |
+
p=0.2,
|
256 |
+
),
|
257 |
+
A.ShiftScaleRotate(
|
258 |
+
shift_limit=0.0625,
|
259 |
+
scale_limit=0.1,
|
260 |
+
rotate_limit=45,
|
261 |
+
p=0.1,
|
262 |
+
border_mode=cv2.BORDER_CONSTANT,
|
263 |
+
value=(255, 255, 255),
|
264 |
+
mask_value=0,
|
265 |
+
),
|
266 |
+
A.Perspective(
|
267 |
+
p=0.2,
|
268 |
+
scale=(0.05, 0.1),
|
269 |
+
pad_mode=cv2.BORDER_CONSTANT,
|
270 |
+
pad_val=(255, 255, 255),
|
271 |
+
mask_pad_val=0,
|
272 |
+
),
|
273 |
+
]
|
274 |
+
),
|
275 |
+
]
|
276 |
+
),
|
277 |
+
]
|
s_multimae/da/__init__.py
ADDED
File without changes
|
s_multimae/da/base_da.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
from typing import List, Optional, Tuple
|
3 |
+
from torch import nn, Tensor
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
|
7 |
+
class BaseDataAugmentation(nn.Module):
|
8 |
+
def __init__(self):
|
9 |
+
super(BaseDataAugmentation, self).__init__()
|
10 |
+
|
11 |
+
@abc.abstractmethod
|
12 |
+
def forward(
|
13 |
+
self,
|
14 |
+
image: Image.Image,
|
15 |
+
depth: Image.Image,
|
16 |
+
gt: Optional[Image.Image] = None,
|
17 |
+
ranking_gt: Optional[Image.Image] = None,
|
18 |
+
multi_gts: Optional[List[Image.Image]] = None,
|
19 |
+
is_transform: bool = True, # is augmented?
|
20 |
+
is_debug: bool = False,
|
21 |
+
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
|
22 |
+
"""
|
23 |
+
Usual case:
|
24 |
+
If gt is provided, return [image, depth, gt]
|
25 |
+
Otherwise, return [image, depth]
|
26 |
+
|
27 |
+
When ranking_gt is provided, gt will be ignored
|
28 |
+
Return [image, depth, ranking_gt]
|
29 |
+
|
30 |
+
For debugging:
|
31 |
+
Return [image, depth, gt|ranking_gt, unnormalized, Optional[ranking_gts]]
|
32 |
+
"""
|
33 |
+
pass
|
s_multimae/da/dav6.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
from torchvision import transforms
|
5 |
+
import albumentations as A
|
6 |
+
import torch
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
from ..configs.base_config import base_cfg
|
10 |
+
from .base_da import BaseDataAugmentation
|
11 |
+
|
12 |
+
|
13 |
+
class DataAugmentationV6(BaseDataAugmentation):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
cfg: base_cfg,
|
17 |
+
is_padding=True,
|
18 |
+
):
|
19 |
+
super().__init__()
|
20 |
+
self.image_size = cfg.image_size
|
21 |
+
self.is_padding = is_padding
|
22 |
+
self.cfg = cfg
|
23 |
+
|
24 |
+
self.to_tensor = transforms.ToTensor()
|
25 |
+
|
26 |
+
self.additional_targets = {
|
27 |
+
"depth": "image",
|
28 |
+
"gt": "mask",
|
29 |
+
"ranking_gt": "mask",
|
30 |
+
"multi_gts": "mask",
|
31 |
+
}
|
32 |
+
|
33 |
+
# For rgb+depth+gt
|
34 |
+
self.transform1 = A.Compose(
|
35 |
+
cfg.transform1,
|
36 |
+
additional_targets=self.additional_targets,
|
37 |
+
)
|
38 |
+
|
39 |
+
# For rgb only
|
40 |
+
self.transform2 = A.Compose(
|
41 |
+
[
|
42 |
+
A.GaussianBlur(p=0.5, blur_limit=(3, 19)),
|
43 |
+
A.RandomBrightnessContrast(p=0.5),
|
44 |
+
A.ColorJitter(p=0.5),
|
45 |
+
]
|
46 |
+
)
|
47 |
+
|
48 |
+
# For depth only
|
49 |
+
self.transform3 = A.Compose([A.GaussianBlur(p=0.5, blur_limit=(3, 37))])
|
50 |
+
|
51 |
+
# For rgb+depth+gt
|
52 |
+
self.transform4 = A.Compose(
|
53 |
+
[A.Resize(self.image_size, self.image_size)],
|
54 |
+
additional_targets=self.additional_targets,
|
55 |
+
is_check_shapes=False,
|
56 |
+
)
|
57 |
+
|
58 |
+
# For rgb only
|
59 |
+
self.transform5 = A.Compose([A.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
60 |
+
|
61 |
+
# For depth only
|
62 |
+
self.transform6 = A.Compose([A.Normalize(0.5, 0.5)])
|
63 |
+
|
64 |
+
def forward(
|
65 |
+
self,
|
66 |
+
image: Image.Image,
|
67 |
+
depth: Image.Image,
|
68 |
+
gt: Optional[Image.Image] = None,
|
69 |
+
ranking_gt: Optional[Image.Image] = None,
|
70 |
+
multi_gts: Optional[List[Image.Image]] = None,
|
71 |
+
is_transform: bool = True, # is augmented?
|
72 |
+
is_debug: bool = False,
|
73 |
+
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
|
74 |
+
## 1. Convert to numpy array: image, depth, gt, ranking_gts
|
75 |
+
image = np.array(image)
|
76 |
+
depth = np.array(depth)
|
77 |
+
d = dict(image=image, depth=depth)
|
78 |
+
if gt is not None:
|
79 |
+
gt = np.array(gt)
|
80 |
+
d["gt"] = gt
|
81 |
+
|
82 |
+
if not is_transform:
|
83 |
+
# Dev or Test
|
84 |
+
d = self.transform4(**d)
|
85 |
+
d["image"] = self.transform5(image=d["image"])["image"]
|
86 |
+
# d["depth"] = self.transform6(image=depth)["image"]
|
87 |
+
if gt is not None:
|
88 |
+
return self.to_tensors([d["image"], d["depth"], d["gt"]])
|
89 |
+
else:
|
90 |
+
return self.to_tensors([d["image"], d["depth"]])
|
91 |
+
|
92 |
+
d["depth"] = 255 - d["depth"] # inverse depth
|
93 |
+
|
94 |
+
# if ranking_gt is not None and multi_gts is not None:
|
95 |
+
# print('[WARN] Both ranking_gt and multi_gts are not none, but we prioritize multi_gts')
|
96 |
+
|
97 |
+
if ranking_gt is not None:
|
98 |
+
ranking_gt = np.array(ranking_gt)
|
99 |
+
|
100 |
+
if multi_gts is not None:
|
101 |
+
multi_gts = np.stack(multi_gts, axis=2)
|
102 |
+
d["multi_gts"] = multi_gts
|
103 |
+
|
104 |
+
## 2. First transformation for image (Contrast, GaussianBlur,...), depth (GaussianBlur,...)
|
105 |
+
d["image"] = self.transform2(image=d["image"])["image"]
|
106 |
+
d["depth"] = self.transform3(image=d["depth"])["image"]
|
107 |
+
|
108 |
+
## 3. Transformation defined in config: change perspective, rotation, size, ...
|
109 |
+
d = self.transform1(**d)
|
110 |
+
|
111 |
+
## 4. Resize
|
112 |
+
d = self.transform4(**d)
|
113 |
+
|
114 |
+
## Just backup image before normalizing it
|
115 |
+
if is_debug:
|
116 |
+
unnormalized_image = d["image"]
|
117 |
+
|
118 |
+
## 6. Construct response
|
119 |
+
d["depth"] = 255 - d["depth"] # inverse depth
|
120 |
+
d["image"] = self.transform5(image=d["image"])["image"]
|
121 |
+
# d["depth"] = self.transform6(image=depth)["image"]
|
122 |
+
rs = self.to_tensors([d["image"], d["depth"]])
|
123 |
+
if multi_gts is not None:
|
124 |
+
rs += self.to_tensors([d["multi_gts"]])
|
125 |
+
elif ranking_gt is not None:
|
126 |
+
rs += [torch.from_numpy(d["ranking_gt"]).to(torch.long)]
|
127 |
+
else:
|
128 |
+
rs += self.to_tensors([d["gt"]])
|
129 |
+
|
130 |
+
## 7. For debug only
|
131 |
+
if is_debug:
|
132 |
+
rs.append(unnormalized_image)
|
133 |
+
|
134 |
+
if ranking_gt is not None:
|
135 |
+
ranking_gts = []
|
136 |
+
for i in range(self.cfg.num_classes):
|
137 |
+
ranking_gts.append(
|
138 |
+
np.array(d["ranking_gt"] == i).astype(np.uint8) * 255
|
139 |
+
)
|
140 |
+
rs.append(ranking_gts)
|
141 |
+
if multi_gts is not None:
|
142 |
+
rs.append(d["multi_gts"])
|
143 |
+
|
144 |
+
return rs
|
145 |
+
|
146 |
+
def to_tensors(self, lst: List[Tensor]) -> List[Tensor]:
|
147 |
+
return [self.to_tensor(e) for e in lst]
|
s_multimae/data_augmentation.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
from .configs.base_config import base_cfg
|
4 |
+
from .da.dav6 import DataAugmentationV6
|
5 |
+
from .da.base_da import BaseDataAugmentation
|
6 |
+
|
7 |
+
|
8 |
+
def get_data_augmentation(
|
9 |
+
cfg: base_cfg,
|
10 |
+
image_size: int,
|
11 |
+
is_padding: bool,
|
12 |
+
) -> BaseDataAugmentation:
|
13 |
+
if cfg.data_augmentation_version == 6:
|
14 |
+
print("Using DataAugmentationV6")
|
15 |
+
return DataAugmentationV6(cfg)
|
16 |
+
else:
|
17 |
+
raise NotImplementedError(
|
18 |
+
f"Unsupported DataAugmentation version {cfg.data_augmentation_version}"
|
19 |
+
)
|
s_multimae/model/__init__.py
ADDED
File without changes
|
s_multimae/model/components.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import Tensor
|
3 |
+
import math
|
4 |
+
import warnings
|
5 |
+
from typing import Tuple, Union
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
import os
|
9 |
+
|
10 |
+
from definition import PRETRAINED_BACKBONE
|
11 |
+
from ..configs.base_config import base_cfg
|
12 |
+
|
13 |
+
|
14 |
+
def pair(t: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
|
15 |
+
return t if isinstance(t, tuple) else (t, t)
|
16 |
+
|
17 |
+
|
18 |
+
def build_2d_sincos_posemb(h: int, w: int, embed_dim=1024, temperature=10000.0):
|
19 |
+
"""Sine-cosine positional embeddings from MoCo-v3
|
20 |
+
|
21 |
+
Source: https://github.com/facebookresearch/moco-v3/blob/main/vits.py
|
22 |
+
"""
|
23 |
+
grid_w = torch.arange(w, dtype=torch.float32)
|
24 |
+
grid_h = torch.arange(h, dtype=torch.float32)
|
25 |
+
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
26 |
+
assert (
|
27 |
+
embed_dim % 4 == 0
|
28 |
+
), "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
|
29 |
+
pos_dim = embed_dim // 4
|
30 |
+
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
|
31 |
+
omega = 1.0 / (temperature**omega)
|
32 |
+
out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
|
33 |
+
out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
|
34 |
+
pos_emb = torch.cat(
|
35 |
+
[torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1
|
36 |
+
)[None, :, :]
|
37 |
+
pos_emb = rearrange(pos_emb, "b (h w) d -> b d h w", h=h, w=w, d=embed_dim)
|
38 |
+
return pos_emb
|
39 |
+
|
40 |
+
|
41 |
+
def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float, b: float):
|
42 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
43 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
44 |
+
def norm_cdf(x):
|
45 |
+
# Computes standard normal cumulative distribution function
|
46 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
47 |
+
|
48 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
49 |
+
warnings.warn(
|
50 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
51 |
+
"The distribution of values may be incorrect.",
|
52 |
+
stacklevel=2,
|
53 |
+
)
|
54 |
+
|
55 |
+
with torch.no_grad():
|
56 |
+
# Values are generated by using a truncated uniform distribution and
|
57 |
+
# then using the inverse CDF for the normal distribution.
|
58 |
+
# Get upper and lower cdf values
|
59 |
+
l = norm_cdf((a - mean) / std)
|
60 |
+
u = norm_cdf((b - mean) / std)
|
61 |
+
|
62 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
63 |
+
# [2l-1, 2u-1].
|
64 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
65 |
+
|
66 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
67 |
+
# standard normal
|
68 |
+
tensor.erfinv_()
|
69 |
+
|
70 |
+
# Transform to proper mean, std
|
71 |
+
tensor.mul_(std * math.sqrt(2.0))
|
72 |
+
tensor.add_(mean)
|
73 |
+
|
74 |
+
# Clamp to ensure it's in the proper range
|
75 |
+
tensor.clamp_(min=a, max=b)
|
76 |
+
return tensor
|
77 |
+
|
78 |
+
|
79 |
+
def trunc_normal_(tensor: Tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
80 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
81 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
82 |
+
normal distribution. The values are effectively drawn from the
|
83 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
84 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
85 |
+
the bounds. The method used for generating the random values works
|
86 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
87 |
+
Args:
|
88 |
+
tensor: an n-dimensional `Tensor`
|
89 |
+
mean: the mean of the normal distribution
|
90 |
+
std: the standard deviation of the normal distribution
|
91 |
+
a: the minimum cutoff value
|
92 |
+
b: the maximum cutoff value
|
93 |
+
Examples:
|
94 |
+
>>> w = torch.empty(3, 5)
|
95 |
+
>>> nn.init.trunc_normal_(w)
|
96 |
+
"""
|
97 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
98 |
+
|
99 |
+
|
100 |
+
def drop_path(x: Tensor, drop_prob: float = 0.0, training: bool = False):
|
101 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
102 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
103 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
104 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
105 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
106 |
+
'survival rate' as the argument.
|
107 |
+
"""
|
108 |
+
if drop_prob == 0.0 or not training:
|
109 |
+
return x
|
110 |
+
keep_prob = 1 - drop_prob
|
111 |
+
shape = (x.shape[0],) + (1,) * (
|
112 |
+
x.ndim - 1
|
113 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
114 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
115 |
+
random_tensor.floor_() # binarize
|
116 |
+
output = x.div(keep_prob) * random_tensor
|
117 |
+
return output
|
s_multimae/model/multimae.py
ADDED
@@ -0,0 +1,938 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import re
|
3 |
+
from collections import OrderedDict
|
4 |
+
from functools import partial
|
5 |
+
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torchvision.ops import MLP
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
from torch import Tensor, nn
|
11 |
+
|
12 |
+
from definition import PRETRAINED_BACKBONE
|
13 |
+
from ..configs.base_config import base_cfg
|
14 |
+
from ..utils import count_parameters
|
15 |
+
from .components import (
|
16 |
+
build_2d_sincos_posemb,
|
17 |
+
drop_path,
|
18 |
+
pair,
|
19 |
+
trunc_normal_,
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
class PatchedInputAdapter(nn.Module):
|
24 |
+
"""Adapter for spatial inputs, like images or feature maps.
|
25 |
+
Creates tokens from patches over the image.
|
26 |
+
|
27 |
+
:param num_channels: Number of input channels of the image/feature map
|
28 |
+
:param stride_level: Stride level compared to the full-sized image.
|
29 |
+
E.g. 4 for 1/4th the size of the image.
|
30 |
+
:param patch_size_full: Int or tuple of the patch size over the full image size.
|
31 |
+
Patch size for smaller inputs will be computed accordingly.
|
32 |
+
:param dim_tokens: Dimension of output tokens. Can be set using init method.
|
33 |
+
:param sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
|
34 |
+
:param learnable_pos_emb: Set to True to learn positional embeddings instead
|
35 |
+
:param image_size: Default image size. Used to initialize size of positional embeddings.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
num_channels: int,
|
41 |
+
stride_level: int,
|
42 |
+
patch_size_full: Union[int, Tuple[int, int]],
|
43 |
+
dim_tokens: Optional[int] = None,
|
44 |
+
sincos_pos_emb: bool = True,
|
45 |
+
learnable_pos_emb: bool = False,
|
46 |
+
image_size: Union[int, Tuple[int]] = 224,
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
self.num_channels = num_channels
|
50 |
+
self.stride_level = stride_level
|
51 |
+
self.patch_size_full = pair(patch_size_full)
|
52 |
+
self.dim_tokens = dim_tokens
|
53 |
+
self.sincos_pos_emb = sincos_pos_emb
|
54 |
+
self.learnable_pos_emb = learnable_pos_emb
|
55 |
+
self.image_size = pair(image_size)
|
56 |
+
self.num_patches = (self.image_size[0] // patch_size_full) * (
|
57 |
+
self.image_size[1] // patch_size_full
|
58 |
+
)
|
59 |
+
|
60 |
+
# Actual patch height and width, taking into account stride of input
|
61 |
+
self.P_H = max(1, self.patch_size_full[0] // stride_level)
|
62 |
+
self.P_W = max(1, self.patch_size_full[1] // stride_level)
|
63 |
+
|
64 |
+
if self.dim_tokens is not None:
|
65 |
+
self.init(dim_tokens=dim_tokens)
|
66 |
+
|
67 |
+
def init(self, dim_tokens: int = 768):
|
68 |
+
"""
|
69 |
+
Initialize parts of encoder that are dependent on dimension of tokens.
|
70 |
+
Should be called when setting up MultiMAE.
|
71 |
+
|
72 |
+
:param dim_tokens: Dimension of tokens
|
73 |
+
"""
|
74 |
+
self.dim_tokens = dim_tokens
|
75 |
+
|
76 |
+
# Task embedding identifying from which task a given token comes from
|
77 |
+
# Fixed-size positional embeddings. Can be interpolated to different input sizes
|
78 |
+
h_posemb = self.image_size[0] // (self.stride_level * self.P_H)
|
79 |
+
w_posemb = self.image_size[1] // (self.stride_level * self.P_W)
|
80 |
+
if self.sincos_pos_emb:
|
81 |
+
self.pos_emb = build_2d_sincos_posemb(
|
82 |
+
h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens
|
83 |
+
)
|
84 |
+
self.pos_emb = nn.Parameter(
|
85 |
+
self.pos_emb, requires_grad=self.learnable_pos_emb
|
86 |
+
)
|
87 |
+
else:
|
88 |
+
self.pos_emb = nn.Parameter(
|
89 |
+
torch.zeros(1, self.dim_tokens, h_posemb, w_posemb)
|
90 |
+
)
|
91 |
+
trunc_normal_(self.pos_emb, std=0.02)
|
92 |
+
|
93 |
+
# Image -> tokens projection
|
94 |
+
self.proj = nn.Conv2d(
|
95 |
+
in_channels=self.num_channels,
|
96 |
+
out_channels=self.dim_tokens,
|
97 |
+
kernel_size=(self.P_H, self.P_W),
|
98 |
+
stride=(self.P_H, self.P_W),
|
99 |
+
)
|
100 |
+
|
101 |
+
@torch.jit.ignore
|
102 |
+
def no_weight_decay(self):
|
103 |
+
return {"pos_emb"}
|
104 |
+
|
105 |
+
def forward(self, x: Tensor) -> Tensor:
|
106 |
+
"""
|
107 |
+
Forward pass through input adapter, transforming image to sequence of tokens.
|
108 |
+
Adds task and positional encodings.
|
109 |
+
|
110 |
+
:param x: Input image tensor
|
111 |
+
"""
|
112 |
+
B, C, H, W = x.shape
|
113 |
+
assert (
|
114 |
+
self.dim_tokens is not None
|
115 |
+
), "Need to call init(dim_tokens) function first"
|
116 |
+
assert (H % self.P_H == 0) and (
|
117 |
+
W % self.P_W == 0
|
118 |
+
), f"Image sizes {H}x{W} must be divisible by patch sizes {self.P_H}x{self.P_W}"
|
119 |
+
N_H, N_W = H // self.P_H, W // self.P_W # Number of patches in height and width
|
120 |
+
|
121 |
+
# Create patches [B, C, H, W] -> [B, (H*W), C]
|
122 |
+
projected_x = self.proj(x)
|
123 |
+
x_patch = rearrange(projected_x, "b d nh nw -> b (nh nw) d")
|
124 |
+
|
125 |
+
# Create positional embedding
|
126 |
+
x_pos_emb = F.interpolate(
|
127 |
+
self.pos_emb, size=(N_H, N_W), mode="bicubic", align_corners=False
|
128 |
+
)
|
129 |
+
x_pos_emb = rearrange(x_pos_emb, "b d nh nw -> b (nh nw) d")
|
130 |
+
|
131 |
+
# Add patches and positional embeddings
|
132 |
+
x = x_patch + x_pos_emb
|
133 |
+
|
134 |
+
return x
|
135 |
+
|
136 |
+
|
137 |
+
class DropPath(nn.Module):
|
138 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
139 |
+
|
140 |
+
def __init__(self, drop_prob=None):
|
141 |
+
super(DropPath, self).__init__()
|
142 |
+
self.drop_prob = drop_prob
|
143 |
+
|
144 |
+
def forward(self, x: Tensor) -> Tensor:
|
145 |
+
return drop_path(x, self.drop_prob, self.training)
|
146 |
+
|
147 |
+
def extra_repr(self) -> str:
|
148 |
+
return "p={}".format(self.drop_prob)
|
149 |
+
|
150 |
+
|
151 |
+
class ConvNeXtBlock(nn.Module):
|
152 |
+
r"""ConvNeXt Block. There are two equivalent implementations:
|
153 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
154 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
155 |
+
We use (2) as we find it slightly faster in PyTorch
|
156 |
+
|
157 |
+
Args:
|
158 |
+
dim (int): Number of input channels.
|
159 |
+
drop_path: Stochastic depth rate. Default: 0.0
|
160 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 0 (disabled for isotropic ConvNeXt).
|
161 |
+
|
162 |
+
Code from: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
|
163 |
+
"""
|
164 |
+
|
165 |
+
def __init__(self, dim, drop_path=0.0, layer_scale_init_value=0.0):
|
166 |
+
super().__init__()
|
167 |
+
self.dwconv = nn.Conv2d(
|
168 |
+
dim, dim, kernel_size=7, padding=3, groups=dim
|
169 |
+
) # depthwise conv
|
170 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
171 |
+
self.pwconv1 = nn.Linear(
|
172 |
+
dim, 4 * dim
|
173 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
174 |
+
self.act = nn.GELU()
|
175 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
176 |
+
self.gamma = (
|
177 |
+
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
178 |
+
if layer_scale_init_value > 0
|
179 |
+
else None
|
180 |
+
)
|
181 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
182 |
+
|
183 |
+
def forward(self, x: Tensor) -> Tensor:
|
184 |
+
input = x
|
185 |
+
x = self.dwconv(x)
|
186 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
187 |
+
x = self.norm(x)
|
188 |
+
x = self.pwconv1(x)
|
189 |
+
x = self.act(x)
|
190 |
+
x = self.pwconv2(x)
|
191 |
+
if self.gamma is not None:
|
192 |
+
x = self.gamma * x
|
193 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
194 |
+
|
195 |
+
x = input + self.drop_path(x)
|
196 |
+
return x
|
197 |
+
|
198 |
+
|
199 |
+
class ConvNeXtAdapter(nn.Module):
|
200 |
+
"""Output adapter with ConvNext blocks for semantic segmentation
|
201 |
+
|
202 |
+
:param num_classes: Number of classes
|
203 |
+
:param num_heads: Number of attention heads
|
204 |
+
:param embed_dim: Token dimension after projection, and before reshaping operation.
|
205 |
+
:param preds_per_patch: Increases size of feature map by reshaping each patch Each patch gets reshaped
|
206 |
+
from embed_dim x 1 x 1 to (embed_dim / preds_per_patch) x (preds_per_patch ** 0.5) x (preds_per_patch ** 0.5)
|
207 |
+
:param main_tasks: Tasks to use for the adapter. Only tokens coming from these tasks are kept.
|
208 |
+
:param patch_size: Size of patches
|
209 |
+
:param depth: Number of ConvNeXt blocks
|
210 |
+
:interpolate_mode: Interpolation mode for final upsampling
|
211 |
+
"""
|
212 |
+
|
213 |
+
def __init__(
|
214 |
+
self,
|
215 |
+
image_size: int,
|
216 |
+
num_classes: int,
|
217 |
+
embed_dim: int = 6144,
|
218 |
+
preds_per_patch: int = 16,
|
219 |
+
main_tasks: Iterable[str] = ("rgb",),
|
220 |
+
patch_size: int = 16,
|
221 |
+
depth: int = 4,
|
222 |
+
interpolate_mode: str = "bilinear",
|
223 |
+
act_fn: nn.Module = nn.GELU,
|
224 |
+
dec_kernel: int = 1,
|
225 |
+
):
|
226 |
+
super().__init__()
|
227 |
+
self.main_tasks = main_tasks
|
228 |
+
self.patch_size = patch_size
|
229 |
+
self.embed_dim = embed_dim
|
230 |
+
self.preds_per_patch = preds_per_patch
|
231 |
+
self.class_dim = embed_dim // preds_per_patch
|
232 |
+
self.num_classes = num_classes
|
233 |
+
self.interpolate_mode = interpolate_mode
|
234 |
+
self.image_size = image_size
|
235 |
+
|
236 |
+
self.blocks = nn.Sequential(
|
237 |
+
*[ConvNeXtBlock(dim=self.class_dim) for _ in range(depth)]
|
238 |
+
)
|
239 |
+
if dec_kernel == 1:
|
240 |
+
self.final_layer_1 = nn.Sequential(
|
241 |
+
nn.Conv2d(self.class_dim, self.class_dim // 4, 1),
|
242 |
+
nn.BatchNorm2d(self.class_dim // 4),
|
243 |
+
act_fn(),
|
244 |
+
nn.Upsample(scale_factor=2, mode=self.interpolate_mode),
|
245 |
+
)
|
246 |
+
|
247 |
+
self.final_layer_2 = nn.Sequential(
|
248 |
+
nn.Conv2d(self.class_dim // 4, self.class_dim // 16, 1),
|
249 |
+
nn.BatchNorm2d(self.class_dim // 16),
|
250 |
+
act_fn(),
|
251 |
+
nn.Upsample(size=image_size, mode=self.interpolate_mode),
|
252 |
+
)
|
253 |
+
|
254 |
+
self.final_layer = nn.Conv2d(self.class_dim // 16, self.num_classes, 1)
|
255 |
+
elif dec_kernel == 3:
|
256 |
+
self.final_layer_1 = nn.Sequential(
|
257 |
+
nn.Conv2d(
|
258 |
+
self.class_dim,
|
259 |
+
self.class_dim // 4,
|
260 |
+
kernel_size=3,
|
261 |
+
stride=1,
|
262 |
+
padding=1,
|
263 |
+
),
|
264 |
+
nn.BatchNorm2d(self.class_dim // 4),
|
265 |
+
act_fn(),
|
266 |
+
nn.Upsample(scale_factor=2, mode=self.interpolate_mode),
|
267 |
+
)
|
268 |
+
|
269 |
+
self.final_layer_2 = nn.Sequential(
|
270 |
+
nn.Conv2d(
|
271 |
+
self.class_dim // 4,
|
272 |
+
self.class_dim // 16,
|
273 |
+
kernel_size=3,
|
274 |
+
stride=1,
|
275 |
+
padding=1,
|
276 |
+
),
|
277 |
+
nn.BatchNorm2d(self.class_dim // 16),
|
278 |
+
act_fn(),
|
279 |
+
nn.Upsample(size=image_size, mode=self.interpolate_mode),
|
280 |
+
)
|
281 |
+
|
282 |
+
self.final_layer = nn.Conv2d(
|
283 |
+
self.class_dim // 16,
|
284 |
+
self.num_classes,
|
285 |
+
kernel_size=3,
|
286 |
+
stride=1,
|
287 |
+
padding=1,
|
288 |
+
)
|
289 |
+
else:
|
290 |
+
raise Exception(f"Unsupported dec_kernel {dec_kernel}")
|
291 |
+
|
292 |
+
self.apply(self._init_weights)
|
293 |
+
|
294 |
+
def init(self, dim_tokens_enc: int = 768):
|
295 |
+
"""
|
296 |
+
Initialize parts of decoder that are dependent on dimension of encoder tokens.
|
297 |
+
Should be called when setting up MultiMAE.
|
298 |
+
|
299 |
+
:param dim_tokens_enc: Dimension of tokens coming from encoder
|
300 |
+
"""
|
301 |
+
self.in_channels = dim_tokens_enc * len(self.main_tasks)
|
302 |
+
|
303 |
+
# Projection of encoder tokens to the patch dimension
|
304 |
+
self.proj_dec = nn.Linear(self.in_channels, self.embed_dim)
|
305 |
+
self._init_weights(self.proj_dec)
|
306 |
+
|
307 |
+
def _init_weights(self, m: nn.Module):
|
308 |
+
if isinstance(m, nn.Linear):
|
309 |
+
trunc_normal_(m.weight, std=0.02)
|
310 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
311 |
+
nn.init.constant_(m.bias, 0)
|
312 |
+
elif isinstance(m, nn.LayerNorm):
|
313 |
+
nn.init.constant_(m.bias, 0)
|
314 |
+
nn.init.constant_(m.weight, 1.0)
|
315 |
+
|
316 |
+
def adapt_tokens(self, encoder_tokens: Tensor, input_info: Dict):
|
317 |
+
# Adapt tokens
|
318 |
+
x = []
|
319 |
+
for task in self.main_tasks:
|
320 |
+
start_idx = input_info["tasks"][task]["start_idx"]
|
321 |
+
end_idx = input_info["tasks"][task]["end_idx"]
|
322 |
+
x.append(encoder_tokens[:, start_idx:end_idx])
|
323 |
+
|
324 |
+
x = torch.cat(x, dim=-1)
|
325 |
+
return x
|
326 |
+
|
327 |
+
def forward(self, encoder_tokens: Tensor, input_info: Dict) -> Tensor:
|
328 |
+
H, W = input_info["image_size"]
|
329 |
+
N_H, N_W = H // self.patch_size, W // self.patch_size
|
330 |
+
|
331 |
+
x = self.adapt_tokens(encoder_tokens, input_info)
|
332 |
+
|
333 |
+
x = self.proj_dec(x)
|
334 |
+
x = rearrange(
|
335 |
+
x,
|
336 |
+
"b n (p c) -> b (n p) c",
|
337 |
+
n=N_H * N_W,
|
338 |
+
p=self.preds_per_patch,
|
339 |
+
c=self.class_dim,
|
340 |
+
)
|
341 |
+
x = rearrange(
|
342 |
+
x,
|
343 |
+
"b (nh nw ph pw) c -> b c (nh ph) (nw pw)",
|
344 |
+
nh=N_H,
|
345 |
+
nw=N_W,
|
346 |
+
ph=int(self.preds_per_patch**0.5),
|
347 |
+
pw=int(self.preds_per_patch**0.5),
|
348 |
+
)
|
349 |
+
|
350 |
+
x = self.blocks(x)
|
351 |
+
|
352 |
+
# for block in self.blocks:
|
353 |
+
# x = block(x)
|
354 |
+
# print(x.shape)
|
355 |
+
|
356 |
+
# print(x.shape)
|
357 |
+
x = self.final_layer_1(x)
|
358 |
+
# print(x.shape)
|
359 |
+
x = self.final_layer_2(x)
|
360 |
+
# print(x.shape)
|
361 |
+
x = self.final_layer(x)
|
362 |
+
# print(x.shape)
|
363 |
+
|
364 |
+
# Interpolate to sod res
|
365 |
+
# x = F.interpolate(x, size=(H, W), mode=self.interpolate_mode)
|
366 |
+
|
367 |
+
return x
|
368 |
+
|
369 |
+
|
370 |
+
class Attention(nn.Module):
|
371 |
+
def __init__(
|
372 |
+
self,
|
373 |
+
dim: int,
|
374 |
+
num_heads=8,
|
375 |
+
qkv_bias=False,
|
376 |
+
attn_drop=0.0,
|
377 |
+
proj_drop=0.0,
|
378 |
+
):
|
379 |
+
super().__init__()
|
380 |
+
self.num_heads = num_heads
|
381 |
+
head_dim = dim // num_heads
|
382 |
+
self.scale = head_dim**-0.5
|
383 |
+
|
384 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
385 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
386 |
+
self.proj = nn.Linear(dim, dim)
|
387 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
388 |
+
|
389 |
+
def forward(self, x: Tensor) -> Tensor:
|
390 |
+
B, N, C = x.shape
|
391 |
+
qkv = (
|
392 |
+
self.qkv(x)
|
393 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
394 |
+
.permute(2, 0, 3, 1, 4)
|
395 |
+
)
|
396 |
+
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
397 |
+
|
398 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
399 |
+
attn = attn.softmax(dim=-1)
|
400 |
+
attn = self.attn_drop(attn)
|
401 |
+
|
402 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
403 |
+
x = self.proj(x)
|
404 |
+
x = self.proj_drop(x)
|
405 |
+
return x
|
406 |
+
|
407 |
+
|
408 |
+
class Mlp(nn.Module):
|
409 |
+
def __init__(
|
410 |
+
self,
|
411 |
+
in_features: int,
|
412 |
+
hidden_features: Optional[int] = None,
|
413 |
+
out_features: Optional[int] = None,
|
414 |
+
act_layer: nn.Module = nn.GELU,
|
415 |
+
drop: float = 0.0,
|
416 |
+
):
|
417 |
+
super().__init__()
|
418 |
+
out_features = out_features or in_features
|
419 |
+
hidden_features = hidden_features or in_features
|
420 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
421 |
+
self.act = act_layer()
|
422 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
423 |
+
self.drop = nn.Dropout(drop)
|
424 |
+
|
425 |
+
def forward(self, x: Tensor) -> Tensor:
|
426 |
+
x = self.fc1(x)
|
427 |
+
x = self.act(x)
|
428 |
+
# x = self.drop(x)
|
429 |
+
# commit this for the orignal BERT implement
|
430 |
+
x = self.fc2(x)
|
431 |
+
x = self.drop(x)
|
432 |
+
return x
|
433 |
+
|
434 |
+
|
435 |
+
class Block(nn.Module):
|
436 |
+
def __init__(
|
437 |
+
self,
|
438 |
+
dim: int,
|
439 |
+
num_heads: int,
|
440 |
+
mlp_ratio=4.0,
|
441 |
+
qkv_bias=False,
|
442 |
+
drop=0.0,
|
443 |
+
attn_drop=0.0,
|
444 |
+
drop_path=0.0,
|
445 |
+
act_layer=nn.GELU,
|
446 |
+
norm_layer=nn.LayerNorm,
|
447 |
+
):
|
448 |
+
super().__init__()
|
449 |
+
self.norm1 = norm_layer(dim)
|
450 |
+
self.attn = Attention(
|
451 |
+
dim,
|
452 |
+
num_heads=num_heads,
|
453 |
+
qkv_bias=qkv_bias,
|
454 |
+
attn_drop=attn_drop,
|
455 |
+
proj_drop=drop,
|
456 |
+
)
|
457 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
458 |
+
self.norm2 = norm_layer(dim)
|
459 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
460 |
+
self.mlp = Mlp(
|
461 |
+
in_features=dim,
|
462 |
+
hidden_features=mlp_hidden_dim,
|
463 |
+
act_layer=act_layer,
|
464 |
+
drop=drop,
|
465 |
+
)
|
466 |
+
|
467 |
+
def forward(self, x: Tensor) -> Tensor:
|
468 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
469 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
470 |
+
return x
|
471 |
+
|
472 |
+
|
473 |
+
class MultiMAE(nn.Module):
|
474 |
+
"""MultiMAE: Multi-task Multi-modal Masked Autoencoder
|
475 |
+
This module performs masking in its forward pass.
|
476 |
+
The MultiViT module defined below inherits from this module and performs a regular forward pass,
|
477 |
+
and should be used instead for downstream tasks
|
478 |
+
|
479 |
+
|
480 |
+
:param input_adapters: Dictionary of task -> input adapters
|
481 |
+
:param output_adapters: Optional dictionary of task -> output adapters
|
482 |
+
|
483 |
+
:param num_global_tokens: Number of additional global tokens to add (like cls tokens), default is 1
|
484 |
+
:param dim_tokens: Dimension of encoder tokens
|
485 |
+
:param depth: Depth of encoder
|
486 |
+
:param num_heads: Number of attention heads
|
487 |
+
:param mlp_ratio: MLP hidden dim ratio
|
488 |
+
:param qkv_bias: Set to False to disable bias
|
489 |
+
:param drop_rate: Dropout after MLPs and Attention
|
490 |
+
:param attn_drop_rate: Attention matrix drop rate
|
491 |
+
:param drop_path_rate: DropPath drop rate
|
492 |
+
:param norm_layer: Type of normalization layer
|
493 |
+
"""
|
494 |
+
|
495 |
+
def __init__(
|
496 |
+
self,
|
497 |
+
input_adapters: Dict[str, PatchedInputAdapter],
|
498 |
+
output_adapters: Dict[str, ConvNeXtAdapter],
|
499 |
+
num_global_tokens: int = 1,
|
500 |
+
dim_tokens: int = 768,
|
501 |
+
depth: int = 12,
|
502 |
+
num_heads: int = 12,
|
503 |
+
mlp_ratio: float = 4.0,
|
504 |
+
qkv_bias: bool = True,
|
505 |
+
drop_rate: float = 0.0,
|
506 |
+
attn_drop_rate: float = 0.0,
|
507 |
+
drop_path_rate: float = 0.0,
|
508 |
+
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
|
509 |
+
freeze_encoder: bool = False,
|
510 |
+
num_additional_gt_tokens: int = 0, # @deprecated
|
511 |
+
actual_num_additional_gt_tokens: int = 0, # @deprecated
|
512 |
+
learnable_additional_gt_tokens: bool = False,
|
513 |
+
additional_gt_tokens_mlp_channels: List[int] = [],
|
514 |
+
ground_truth_version: int = -1,
|
515 |
+
A: float = 0.5,
|
516 |
+
):
|
517 |
+
super().__init__()
|
518 |
+
self.dim_tokens = dim_tokens
|
519 |
+
self.ground_truth_version = ground_truth_version
|
520 |
+
# Initialize input and output adapters
|
521 |
+
for adapter in input_adapters.values():
|
522 |
+
adapter.init(dim_tokens=dim_tokens)
|
523 |
+
self.input_adapters = nn.ModuleDict(input_adapters)
|
524 |
+
for adapter in output_adapters.values():
|
525 |
+
adapter.init(dim_tokens_enc=dim_tokens)
|
526 |
+
self.output_adapters = nn.ModuleDict(output_adapters)
|
527 |
+
|
528 |
+
# Additional learnable tokens that can be used by encoder to process/store global information
|
529 |
+
self.num_global_tokens = num_global_tokens
|
530 |
+
self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, dim_tokens))
|
531 |
+
trunc_normal_(self.global_tokens, std=0.02)
|
532 |
+
|
533 |
+
self.num_additional_gt_tokens = num_additional_gt_tokens # @deprecated
|
534 |
+
self.actual_num_additional_gt_tokens = (
|
535 |
+
actual_num_additional_gt_tokens # @deprecated
|
536 |
+
)
|
537 |
+
self.A = A
|
538 |
+
self.additional_gt_tokens_mlp_channels = additional_gt_tokens_mlp_channels
|
539 |
+
self.learnable_additional_gt_tokens = learnable_additional_gt_tokens
|
540 |
+
self.init_gt_tokens()
|
541 |
+
|
542 |
+
# Transformer encoder
|
543 |
+
dpr = [
|
544 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
545 |
+
] # stochastic depth decay rule
|
546 |
+
self.encoder = nn.Sequential(
|
547 |
+
*[
|
548 |
+
Block(
|
549 |
+
dim=dim_tokens,
|
550 |
+
num_heads=num_heads,
|
551 |
+
mlp_ratio=mlp_ratio,
|
552 |
+
qkv_bias=qkv_bias,
|
553 |
+
drop=drop_rate,
|
554 |
+
attn_drop=attn_drop_rate,
|
555 |
+
drop_path=dpr[i],
|
556 |
+
norm_layer=norm_layer,
|
557 |
+
)
|
558 |
+
for i in range(depth)
|
559 |
+
]
|
560 |
+
)
|
561 |
+
|
562 |
+
print(f"Encoder {count_parameters(self.encoder)}")
|
563 |
+
|
564 |
+
if freeze_encoder:
|
565 |
+
print("Freeze encoder")
|
566 |
+
for param in self.encoder.parameters():
|
567 |
+
param.requires_grad = False
|
568 |
+
|
569 |
+
self.apply(self._init_weights)
|
570 |
+
for name, m in self.named_modules():
|
571 |
+
if isinstance(m, nn.Linear):
|
572 |
+
if "qkv" in name:
|
573 |
+
# treat the weights of Q, K, V separately
|
574 |
+
val = math.sqrt(
|
575 |
+
6.0 / float(m.weight.shape[0] // 3 + m.weight.shape[1])
|
576 |
+
)
|
577 |
+
nn.init.uniform_(m.weight, -val, val)
|
578 |
+
elif "kv" in name:
|
579 |
+
# treat the weights of K, V separately
|
580 |
+
val = math.sqrt(
|
581 |
+
6.0 / float(m.weight.shape[0] // 2 + m.weight.shape[1])
|
582 |
+
)
|
583 |
+
nn.init.uniform_(m.weight, -val, val)
|
584 |
+
|
585 |
+
if isinstance(m, nn.Conv2d):
|
586 |
+
if ".proj" in name:
|
587 |
+
# From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
|
588 |
+
w = m.weight.data
|
589 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
590 |
+
|
591 |
+
print(f"Total params: {count_parameters(self)}")
|
592 |
+
|
593 |
+
def init_gt_tokens(self):
|
594 |
+
"""Just prepare beforehand to save time in training
|
595 |
+
In inference, there is no need"""
|
596 |
+
addtional_gt_tokens: List[Tensor] = []
|
597 |
+
if self.num_additional_gt_tokens == 0:
|
598 |
+
self.token_mlp = nn.Identity()
|
599 |
+
return
|
600 |
+
if len(self.additional_gt_tokens_mlp_channels) > 0:
|
601 |
+
self.token_mlp = MLP(
|
602 |
+
self.dim_tokens,
|
603 |
+
self.additional_gt_tokens_mlp_channels + [self.dim_tokens],
|
604 |
+
)
|
605 |
+
else:
|
606 |
+
self.token_mlp = nn.Identity()
|
607 |
+
|
608 |
+
if self.ground_truth_version != 6:
|
609 |
+
T = 1 / (self.num_additional_gt_tokens * 4)
|
610 |
+
for i in range(self.actual_num_additional_gt_tokens):
|
611 |
+
t = [
|
612 |
+
2 * math.pi * (offset / self.dim_tokens - i * T)
|
613 |
+
for offset in range(self.dim_tokens)
|
614 |
+
]
|
615 |
+
addtional_gt_tokens.append(
|
616 |
+
nn.Parameter(
|
617 |
+
self.A * torch.cos(Tensor(t).unsqueeze(0).unsqueeze(0)),
|
618 |
+
requires_grad=self.learnable_additional_gt_tokens,
|
619 |
+
)
|
620 |
+
)
|
621 |
+
self.addtional_gt_tokens = nn.ParameterList(addtional_gt_tokens)
|
622 |
+
|
623 |
+
def _init_weights(self, m: nn.Module) -> None:
|
624 |
+
if isinstance(m, nn.Linear):
|
625 |
+
nn.init.xavier_uniform_(m.weight)
|
626 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
627 |
+
nn.init.constant_(m.bias, 0)
|
628 |
+
elif isinstance(m, nn.LayerNorm):
|
629 |
+
nn.init.constant_(m.bias, 0)
|
630 |
+
nn.init.constant_(m.weight, 1.0)
|
631 |
+
|
632 |
+
@torch.jit.ignore
|
633 |
+
def no_weight_decay(self):
|
634 |
+
no_wd_set = {"global_tokens"}
|
635 |
+
|
636 |
+
for task, adapter in self.input_adapters.items():
|
637 |
+
if hasattr(adapter, "no_weight_decay"):
|
638 |
+
to_skip = adapter.no_weight_decay()
|
639 |
+
to_skip = set([f"input_adapters.{task}.{name}" for name in to_skip])
|
640 |
+
no_wd_set = no_wd_set | to_skip
|
641 |
+
|
642 |
+
for task, adapter in self.output_adapters.items():
|
643 |
+
if hasattr(adapter, "no_weight_decay"):
|
644 |
+
to_skip = adapter.no_weight_decay()
|
645 |
+
to_skip = set([f"output_adapters.{task}.{name}" for name in to_skip])
|
646 |
+
no_wd_set = no_wd_set | to_skip
|
647 |
+
|
648 |
+
return no_wd_set
|
649 |
+
|
650 |
+
def generate_input_info(
|
651 |
+
self, input_task_tokens: Dict[str, Tensor], image_size: Tuple[int, int]
|
652 |
+
) -> Dict[str, Tensor]:
|
653 |
+
input_info = OrderedDict()
|
654 |
+
i = 0
|
655 |
+
input_info["tasks"] = {}
|
656 |
+
for domain, tensor in input_task_tokens.items():
|
657 |
+
num_tokens: Union[int, Tensor] = tensor.shape[1]
|
658 |
+
|
659 |
+
if type(num_tokens) == Tensor:
|
660 |
+
num_tokens = num_tokens.item()
|
661 |
+
|
662 |
+
d = {
|
663 |
+
"num_tokens": num_tokens,
|
664 |
+
"has_2d_posemb": True,
|
665 |
+
"start_idx": i,
|
666 |
+
"end_idx": i + num_tokens,
|
667 |
+
}
|
668 |
+
i += num_tokens
|
669 |
+
input_info["tasks"][domain] = d
|
670 |
+
|
671 |
+
input_info["image_size"] = image_size
|
672 |
+
input_info["num_task_tokens"] = i
|
673 |
+
input_info["num_global_tokens"] = self.num_global_tokens
|
674 |
+
|
675 |
+
return input_info
|
676 |
+
|
677 |
+
|
678 |
+
class MultiViT(MultiMAE):
|
679 |
+
def extract_B_H_W(self, x: Dict[str, Tensor]) -> Tuple[int, int, int]:
|
680 |
+
# If input x is a Tensor, assume it's RGB
|
681 |
+
# x = {'rgb': x} if isinstance(x, Tensor) else x
|
682 |
+
# Need image size for tokens->image reconstruction
|
683 |
+
if "rgb" in x:
|
684 |
+
B, _, H, W = x["rgb"].shape
|
685 |
+
elif "sod" in x:
|
686 |
+
B, H, W = x["sod"].shape
|
687 |
+
H *= self.input_adapters["sod"].stride_level
|
688 |
+
W *= self.input_adapters["sod"].stride_level
|
689 |
+
else:
|
690 |
+
B, _, H, W = list(x.values())[0].shape
|
691 |
+
return B, H, W
|
692 |
+
|
693 |
+
def process_input(
|
694 |
+
self,
|
695 |
+
x: Dict[str, Tensor],
|
696 |
+
gt_index_lst: List[int],
|
697 |
+
num_gts_lst: List[int],
|
698 |
+
) -> Tuple[Tensor, Dict[str, Tensor]]:
|
699 |
+
"""
|
700 |
+
len(gt_i) must equal to x.shape[0] when self.num_additional_gt_tokens > 0
|
701 |
+
"""
|
702 |
+
B, H, W = self.extract_B_H_W(x)
|
703 |
+
|
704 |
+
# Encode selected inputs to tokens
|
705 |
+
input_task_tokens: Dict[str, Tensor] = {
|
706 |
+
domain: self.input_adapters[domain](tensor)
|
707 |
+
for domain, tensor in x.items()
|
708 |
+
if domain in self.input_adapters
|
709 |
+
}
|
710 |
+
|
711 |
+
input_info = self.generate_input_info(
|
712 |
+
input_task_tokens=input_task_tokens, image_size=(H, W)
|
713 |
+
)
|
714 |
+
input_tokens = torch.cat(
|
715 |
+
[task_tokens for task_tokens in input_task_tokens.values()], dim=1
|
716 |
+
)
|
717 |
+
|
718 |
+
# Add global tokens to input tokens
|
719 |
+
global_tokens = repeat(self.global_tokens, "() n d -> b n d", b=B)
|
720 |
+
|
721 |
+
if self.ground_truth_version == 6:
|
722 |
+
# We need two inputs: gt_index, num_gts
|
723 |
+
assert len(gt_index_lst) == len(num_gts_lst)
|
724 |
+
additional_gt_tokens = []
|
725 |
+
for gt_index, num_gts in zip(gt_index_lst, num_gts_lst):
|
726 |
+
T = 1 / num_gts
|
727 |
+
i = gt_index
|
728 |
+
t = [
|
729 |
+
2 * math.pi * (offset / self.dim_tokens - i * T)
|
730 |
+
for offset in range(self.dim_tokens)
|
731 |
+
]
|
732 |
+
additional_gt_token = self.A * torch.cos(
|
733 |
+
Tensor(t).unsqueeze(0).unsqueeze(0)
|
734 |
+
)
|
735 |
+
additional_gt_tokens.append(additional_gt_token)
|
736 |
+
additional_gt_tokens = torch.cat(additional_gt_tokens, dim=0).to(
|
737 |
+
input_tokens.device
|
738 |
+
)
|
739 |
+
additional_gt_tokens = self.token_mlp(additional_gt_tokens)
|
740 |
+
input_tokens = torch.cat(
|
741 |
+
[input_tokens, global_tokens, additional_gt_tokens], dim=1
|
742 |
+
)
|
743 |
+
else:
|
744 |
+
if self.num_additional_gt_tokens > 0:
|
745 |
+
|
746 |
+
assert gt_index_lst is not None and len(gt_index_lst) == B
|
747 |
+
additional_gt_tokens: Tensor = torch.cat(
|
748 |
+
[self.addtional_gt_tokens[gt_i] for gt_i in gt_index_lst], dim=0
|
749 |
+
)
|
750 |
+
additional_gt_tokens = self.token_mlp(additional_gt_tokens)
|
751 |
+
input_tokens = torch.cat(
|
752 |
+
[input_tokens, global_tokens, additional_gt_tokens], dim=1
|
753 |
+
)
|
754 |
+
else:
|
755 |
+
input_tokens = torch.cat([input_tokens, global_tokens], dim=1)
|
756 |
+
|
757 |
+
return input_tokens, input_info
|
758 |
+
|
759 |
+
def forward(
|
760 |
+
self,
|
761 |
+
x: Dict[str, Tensor],
|
762 |
+
gt_index_lst: Optional[List[int]] = None,
|
763 |
+
max_gts_lst: Optional[List[int]] = None,
|
764 |
+
) -> Dict[str, Tensor]:
|
765 |
+
"""
|
766 |
+
Forward pass through input adapters, transformer encoder and output adapters.
|
767 |
+
|
768 |
+
:param x: Dictionary of tensors
|
769 |
+
:param outputs: List of outputs. For ex: outputs=['sod', 'depth']. Make sure 'sod' placed first!
|
770 |
+
"""
|
771 |
+
input_tokens, input_info = self.process_input(x, gt_index_lst, max_gts_lst)
|
772 |
+
|
773 |
+
# Pass tokens through Transformer
|
774 |
+
encoder_tokens = self.encoder(input_tokens)
|
775 |
+
|
776 |
+
# Decode tokens for each task using task-specific output adapters
|
777 |
+
preds = {
|
778 |
+
domain: self.output_adapters[domain](
|
779 |
+
encoder_tokens=encoder_tokens,
|
780 |
+
input_info=input_info,
|
781 |
+
)
|
782 |
+
for domain in self.output_adapters
|
783 |
+
}
|
784 |
+
|
785 |
+
return preds
|
786 |
+
|
787 |
+
|
788 |
+
def interpolate_pos_embed_multimae(
|
789 |
+
model: MultiViT,
|
790 |
+
checkpoint_model: Dict[str, Tensor],
|
791 |
+
) -> None:
|
792 |
+
pattern = "input_adapters\.(.*)\.pos_emb"
|
793 |
+
matched_keys = [k for k in checkpoint_model if bool(re.match(pattern, k))]
|
794 |
+
|
795 |
+
for key in matched_keys:
|
796 |
+
domain = re.match(pattern, key).group(1) # group(0) is entire matched regex
|
797 |
+
if getattr(model.input_adapters, domain, None) is not None:
|
798 |
+
pos_embed_checkpoint = checkpoint_model[key]
|
799 |
+
_, _, orig_H, orig_W = pos_embed_checkpoint.shape
|
800 |
+
_, _, new_H, new_W = getattr(model.input_adapters, domain).pos_emb.shape
|
801 |
+
if (orig_H != new_H) or (orig_W != new_W):
|
802 |
+
print(
|
803 |
+
f"Key {key}: Position interpolate from {orig_H}x{orig_W} to {new_H}x{new_W}"
|
804 |
+
)
|
805 |
+
pos_embed_checkpoint = torch.nn.functional.interpolate(
|
806 |
+
pos_embed_checkpoint,
|
807 |
+
size=(new_H, new_W),
|
808 |
+
mode="bicubic",
|
809 |
+
align_corners=False,
|
810 |
+
)
|
811 |
+
checkpoint_model[key] = pos_embed_checkpoint
|
812 |
+
|
813 |
+
|
814 |
+
def construct_adapters(cfg: base_cfg):
|
815 |
+
INPUT_ADAPTERS = {
|
816 |
+
"rgb": PatchedInputAdapter(
|
817 |
+
num_channels=3,
|
818 |
+
stride_level=1,
|
819 |
+
patch_size_full=cfg.input_patch_size,
|
820 |
+
image_size=cfg.image_size,
|
821 |
+
learnable_pos_emb=cfg.learnable_pos_emb,
|
822 |
+
),
|
823 |
+
"depth": PatchedInputAdapter(
|
824 |
+
num_channels=1,
|
825 |
+
stride_level=1,
|
826 |
+
patch_size_full=cfg.input_patch_size,
|
827 |
+
image_size=cfg.image_size,
|
828 |
+
learnable_pos_emb=cfg.learnable_pos_emb,
|
829 |
+
),
|
830 |
+
}
|
831 |
+
|
832 |
+
num_classes = cfg.num_classes
|
833 |
+
if cfg.ground_truth_version in [5, 6]:
|
834 |
+
num_classes = 1
|
835 |
+
|
836 |
+
OUTPUT_ADAPTERS = {
|
837 |
+
"sod": partial(
|
838 |
+
ConvNeXtAdapter,
|
839 |
+
num_classes=num_classes,
|
840 |
+
image_size=cfg.image_size,
|
841 |
+
embed_dim=cfg.embed_dim,
|
842 |
+
patch_size=cfg.input_patch_size,
|
843 |
+
preds_per_patch=cfg.output_patch_size,
|
844 |
+
depth=cfg.decoder_depth,
|
845 |
+
interpolate_mode=cfg.decoder_interpolate_mode,
|
846 |
+
main_tasks=cfg.decoder_main_tasks,
|
847 |
+
act_fn=cfg.act_fn,
|
848 |
+
dec_kernel=cfg.dec_kernel,
|
849 |
+
),
|
850 |
+
"rgb": partial(
|
851 |
+
ConvNeXtAdapter,
|
852 |
+
num_classes=3,
|
853 |
+
image_size=cfg.image_size,
|
854 |
+
embed_dim=cfg.embed_dim,
|
855 |
+
patch_size=cfg.input_patch_size,
|
856 |
+
preds_per_patch=cfg.output_patch_size,
|
857 |
+
depth=cfg.decoder_depth,
|
858 |
+
interpolate_mode=cfg.decoder_interpolate_mode,
|
859 |
+
main_tasks=cfg.decoder_main_tasks,
|
860 |
+
act_fn=cfg.act_fn,
|
861 |
+
dec_kernel=cfg.dec_kernel,
|
862 |
+
),
|
863 |
+
"depth": partial(
|
864 |
+
ConvNeXtAdapter,
|
865 |
+
num_classes=1,
|
866 |
+
image_size=cfg.image_size,
|
867 |
+
embed_dim=cfg.embed_dim,
|
868 |
+
patch_size=cfg.input_patch_size,
|
869 |
+
preds_per_patch=cfg.output_patch_size,
|
870 |
+
depth=cfg.decoder_depth,
|
871 |
+
interpolate_mode=cfg.decoder_interpolate_mode,
|
872 |
+
main_tasks=cfg.decoder_main_tasks,
|
873 |
+
act_fn=cfg.act_fn,
|
874 |
+
dec_kernel=cfg.dec_kernel,
|
875 |
+
),
|
876 |
+
}
|
877 |
+
|
878 |
+
if cfg.ground_truth_version == 3:
|
879 |
+
for i in range(cfg.num_classes):
|
880 |
+
OUTPUT_ADAPTERS[f"sod{i}"] = partial(
|
881 |
+
ConvNeXtAdapter,
|
882 |
+
num_classes=1,
|
883 |
+
image_size=cfg.image_size,
|
884 |
+
embed_dim=cfg.embed_dim,
|
885 |
+
patch_size=cfg.input_patch_size,
|
886 |
+
preds_per_patch=cfg.output_patch_size,
|
887 |
+
depth=cfg.decoder_depth,
|
888 |
+
interpolate_mode=cfg.decoder_interpolate_mode,
|
889 |
+
main_tasks=cfg.decoder_main_tasks,
|
890 |
+
act_fn=cfg.act_fn,
|
891 |
+
dec_kernel=cfg.dec_kernel,
|
892 |
+
)
|
893 |
+
return INPUT_ADAPTERS, OUTPUT_ADAPTERS
|
894 |
+
|
895 |
+
|
896 |
+
def generate_smultimae_model(cfg: base_cfg) -> Tuple[MultiViT, List[Dict]]:
|
897 |
+
"""MULTIMAE"""
|
898 |
+
assert len(cfg.decoder_main_tasks) == len(
|
899 |
+
cfg.outputs
|
900 |
+
), "Length of decoder main tasks must match length of outputs"
|
901 |
+
|
902 |
+
INPUT_ADAPTERS, OUTPUT_ADAPTERS = construct_adapters(cfg)
|
903 |
+
|
904 |
+
input_adapters = dict()
|
905 |
+
for input_key in cfg.inputs:
|
906 |
+
input_adapters[input_key] = INPUT_ADAPTERS[input_key]
|
907 |
+
|
908 |
+
output_adapters = dict()
|
909 |
+
for output_key, decoder_main_tasks_per_output in zip(
|
910 |
+
cfg.outputs, cfg.decoder_main_tasks
|
911 |
+
):
|
912 |
+
output_adapters[output_key] = OUTPUT_ADAPTERS[output_key](
|
913 |
+
main_tasks=decoder_main_tasks_per_output
|
914 |
+
)
|
915 |
+
|
916 |
+
num_additional_gt_tokens = 0 # @deprecated
|
917 |
+
actual_num_additional_gt_tokens = 0 # @deprecated
|
918 |
+
if cfg.ground_truth_version in [5, 6]: # @deprecated
|
919 |
+
num_additional_gt_tokens = cfg.num_classes # @deprecated
|
920 |
+
actual_num_additional_gt_tokens = cfg.actual_num_classes # @deprecated
|
921 |
+
model = MultiViT(
|
922 |
+
input_adapters=input_adapters,
|
923 |
+
output_adapters=output_adapters,
|
924 |
+
freeze_encoder=cfg.freeze_encoder,
|
925 |
+
drop_path_rate=0.1,
|
926 |
+
dim_tokens=cfg.dim_tokens,
|
927 |
+
depth=cfg.encoder_depth,
|
928 |
+
num_heads=cfg.num_heads,
|
929 |
+
mlp_ratio=4,
|
930 |
+
qkv_bias=True,
|
931 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
932 |
+
num_additional_gt_tokens=num_additional_gt_tokens, # @deprecated
|
933 |
+
actual_num_additional_gt_tokens=actual_num_additional_gt_tokens, # @deprecated
|
934 |
+
ground_truth_version=cfg.ground_truth_version,
|
935 |
+
)
|
936 |
+
|
937 |
+
# return load_pretrained_backbone(cfg, model)
|
938 |
+
return model, []
|
s_multimae/model_pl.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
import os
|
3 |
+
from typing import Any, Dict, List, Optional, Tuple
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
from torch import Tensor, nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from .configs.base_config import base_cfg
|
12 |
+
from .rgbd_model import RGBDModel
|
13 |
+
|
14 |
+
|
15 |
+
class ModelPL(pl.LightningModule):
|
16 |
+
def __init__(self, cfg: base_cfg):
|
17 |
+
super().__init__()
|
18 |
+
self.cfg = cfg
|
19 |
+
self.model = RGBDModel(cfg)
|
20 |
+
|
21 |
+
def forward(self, images: Tensor, depths: Tensor):
|
22 |
+
return self.model.forward(images, depths)
|
23 |
+
|
24 |
+
def __inference_v1(
|
25 |
+
self, outputs: Dict[str, Tensor], image_sizes: List[Tuple[int, int]]
|
26 |
+
):
|
27 |
+
res_lst: List[List[np.ndarray]] = []
|
28 |
+
for output, image_size in zip(outputs["sod"], image_sizes):
|
29 |
+
output: Tensor = F.interpolate(
|
30 |
+
output.unsqueeze(0),
|
31 |
+
size=(image_size[1], image_size[0]),
|
32 |
+
mode="bilinear",
|
33 |
+
align_corners=False,
|
34 |
+
)
|
35 |
+
res: np.ndarray = output.sigmoid().data.cpu().numpy().squeeze()
|
36 |
+
res = (res - res.min()) / (res.max() - res.min() + 1e-8)
|
37 |
+
if self.cfg.is_fp16:
|
38 |
+
res = np.float32(res)
|
39 |
+
res_lst.append([(res * 255).astype(np.uint8)])
|
40 |
+
return res_lst
|
41 |
+
|
42 |
+
def __inference_v2(
|
43 |
+
self, outputs: Dict[str, Tensor], image_sizes: List[Tuple[int, int]]
|
44 |
+
):
|
45 |
+
res_lst: List[List[np.ndarray]] = []
|
46 |
+
for output, image_size in zip(outputs["sod"], image_sizes):
|
47 |
+
output: Tensor = F.interpolate(
|
48 |
+
output.unsqueeze(0),
|
49 |
+
size=(image_size[1], image_size[0]),
|
50 |
+
mode="bilinear",
|
51 |
+
align_corners=False,
|
52 |
+
)
|
53 |
+
res: np.ndarray = torch.argmax(output, dim=1).cpu().numpy().squeeze()
|
54 |
+
res_lst.append([res])
|
55 |
+
return res_lst
|
56 |
+
|
57 |
+
def __inference_v3v5(
|
58 |
+
self, outputs: Dict[str, Tensor], image_sizes: List[Tuple[int, int]]
|
59 |
+
):
|
60 |
+
res_lst: List[List[np.ndarray]] = []
|
61 |
+
for bi, image_size in enumerate(image_sizes):
|
62 |
+
res_lst_per_sample: List[np.ndarray] = []
|
63 |
+
for i in range(self.cfg.num_classes):
|
64 |
+
pred = outputs[f"sod{i}"][bi]
|
65 |
+
pred: Tensor = F.interpolate(
|
66 |
+
pred.unsqueeze(0),
|
67 |
+
size=(image_size[1], image_size[0]),
|
68 |
+
mode="bilinear",
|
69 |
+
align_corners=False,
|
70 |
+
)
|
71 |
+
res: np.ndarray = pred.sigmoid().data.cpu().numpy().squeeze()
|
72 |
+
res = (res - res.min()) / (res.max() - res.min() + 1e-8)
|
73 |
+
if self.cfg.is_fp16:
|
74 |
+
res = np.float32(res)
|
75 |
+
res_lst_per_sample.append((res * 255).astype(np.uint8))
|
76 |
+
res_lst.append(res_lst_per_sample)
|
77 |
+
return res_lst
|
78 |
+
|
79 |
+
@torch.no_grad()
|
80 |
+
def inference(
|
81 |
+
self,
|
82 |
+
image_sizes: List[Tuple[int, int]],
|
83 |
+
images: Tensor,
|
84 |
+
depths: Optional[Tensor],
|
85 |
+
max_gts: Optional[List[int]],
|
86 |
+
) -> List[List[np.ndarray]]:
|
87 |
+
self.model.eval()
|
88 |
+
assert len(image_sizes) == len(
|
89 |
+
images
|
90 |
+
), "The number of image_sizes must equal to the number of images"
|
91 |
+
gpu_images: Tensor = images.to(self.device)
|
92 |
+
gpu_depths: Tensor = depths.to(self.device)
|
93 |
+
|
94 |
+
if self.cfg.ground_truth_version == 6:
|
95 |
+
with torch.cuda.amp.autocast(enabled=self.cfg.is_fp16):
|
96 |
+
outputs: Dict[str, Tensor] = dict()
|
97 |
+
for i in range(self.cfg.num_classes):
|
98 |
+
outputs[f"sod{i}"] = self.model.inference(
|
99 |
+
gpu_images, gpu_depths, [i] * gpu_images.shape[0], max_gts
|
100 |
+
)["sod"]
|
101 |
+
return self.__inference_v3v5(outputs, image_sizes)
|
102 |
+
else:
|
103 |
+
raise Exception(
|
104 |
+
f"Unsupported ground_truth_version {self.cfg.ground_truth_version}"
|
105 |
+
)
|
s_multimae/rgbd_model.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Optional
|
2 |
+
from torch import nn, Tensor
|
3 |
+
|
4 |
+
from .model.multimae import generate_smultimae_model as generate_smultimae_model_v1
|
5 |
+
from .configs.base_config import base_cfg
|
6 |
+
|
7 |
+
|
8 |
+
class RGBDModel(nn.Module):
|
9 |
+
def __init__(self, cfg: base_cfg):
|
10 |
+
super(RGBDModel, self).__init__()
|
11 |
+
|
12 |
+
self.inputs = cfg.inputs
|
13 |
+
self.outputs = cfg.outputs
|
14 |
+
|
15 |
+
self.is_no_depth = cfg.is_inference_with_no_depth
|
16 |
+
|
17 |
+
if cfg.model_version == 1:
|
18 |
+
self.model, self.opt_params = generate_smultimae_model_v1(cfg)
|
19 |
+
else:
|
20 |
+
raise Exception(f"Unsupported model version {cfg.model_version}")
|
21 |
+
|
22 |
+
def encode_decode(
|
23 |
+
self,
|
24 |
+
images: Tensor,
|
25 |
+
depths: Optional[Tensor],
|
26 |
+
gt_index_lst: Optional[List[int]] = None,
|
27 |
+
max_gts_lst: Optional[List[int]] = None,
|
28 |
+
) -> Dict[str, Tensor]:
|
29 |
+
"""Encode images with backbone and decode into a semantic segmentation
|
30 |
+
map of the same size as input.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
{
|
34 |
+
"sod": Tensor,
|
35 |
+
"depth": Optional[Tensor],
|
36 |
+
"rgb": Optional[tensor],
|
37 |
+
}
|
38 |
+
"""
|
39 |
+
inputs = {"rgb": images}
|
40 |
+
if "depth" in self.inputs:
|
41 |
+
inputs["depth"] = depths
|
42 |
+
return self.model.forward(inputs, gt_index_lst, max_gts_lst)
|
43 |
+
|
44 |
+
def forward(
|
45 |
+
self,
|
46 |
+
images: Tensor,
|
47 |
+
depths: Optional[Tensor],
|
48 |
+
gt_index_lst: Optional[List[int]] = None,
|
49 |
+
max_gts_lst: Optional[List[int]] = None,
|
50 |
+
) -> Dict[str, Tensor]:
|
51 |
+
return self.encode_decode(images, depths, gt_index_lst, max_gts_lst)
|
52 |
+
|
53 |
+
def inference(
|
54 |
+
self,
|
55 |
+
images: Tensor,
|
56 |
+
depths: Optional[Tensor],
|
57 |
+
gt_index_lst: Optional[List[int]] = None,
|
58 |
+
max_gts_lst: Optional[List[int]] = None,
|
59 |
+
) -> Dict[str, Tensor]:
|
60 |
+
return self.encode_decode(images, depths, gt_index_lst, max_gts_lst)
|
s_multimae/utils.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from glob import glob
|
3 |
+
import random
|
4 |
+
from typing import Dict, List
|
5 |
+
from torch import nn, Tensor
|
6 |
+
import os, shutil
|
7 |
+
import torch
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import numpy as np
|
10 |
+
import gc, cv2
|
11 |
+
|
12 |
+
from .visualizer import post_processing_depth
|
13 |
+
|
14 |
+
"""
|
15 |
+
This module should not depend on other s_multimae modules.
|
16 |
+
"""
|
17 |
+
|
18 |
+
num_format = "{:,}".format
|
19 |
+
|
20 |
+
|
21 |
+
def list_dirs(dir_root: str) -> List[str]:
|
22 |
+
return list(
|
23 |
+
sorted(
|
24 |
+
[
|
25 |
+
item
|
26 |
+
for item in os.listdir(dir_root)
|
27 |
+
if os.path.isdir(f"{dir_root}/{item}")
|
28 |
+
]
|
29 |
+
)
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
def clean_cache() -> None:
|
34 |
+
torch.cuda.empty_cache()
|
35 |
+
gc.collect()
|
36 |
+
|
37 |
+
|
38 |
+
def count_parameters(model: nn.Module) -> str:
|
39 |
+
"""Count the number of learnable parameters of a model"""
|
40 |
+
return num_format(sum(p.numel() for p in model.parameters() if p.requires_grad))
|
41 |
+
|
42 |
+
|
43 |
+
def ranking_gts_to_dict(
|
44 |
+
ranking_gts: List[np.ndarray | str],
|
45 |
+
) -> Dict[str, np.ndarray | str]:
|
46 |
+
"""
|
47 |
+
Return:
|
48 |
+
dict(
|
49 |
+
gt0=ranking_gts[0],
|
50 |
+
gt1=ranking_gts[1],
|
51 |
+
gt2=ranking_gts[2],
|
52 |
+
gt3=ranking_gts[3],
|
53 |
+
gt4=ranking_gts[4],
|
54 |
+
)
|
55 |
+
"""
|
56 |
+
return {f"gt{i}": v for i, v in enumerate(ranking_gts)}
|
57 |
+
|
58 |
+
|
59 |
+
def dict_to_ranking_gts(d: Dict[str, np.ndarray], l=5) -> List[np.ndarray]:
|
60 |
+
"""
|
61 |
+
Return: [ranking_gts["gt0"], ranking_gts["gt1"], ...]
|
62 |
+
"""
|
63 |
+
return [d[f"gt{i}"] for i in range(l)]
|
64 |
+
|
65 |
+
|
66 |
+
def random_choice(p: float) -> bool:
|
67 |
+
"""Return True if random float <= p"""
|
68 |
+
return random.random() <= p
|
69 |
+
|
70 |
+
|
71 |
+
def fname_without_ext(p: str) -> str:
|
72 |
+
return os.path.splitext(os.path.basename(p))[0]
|
73 |
+
|
74 |
+
|
75 |
+
def list_files(
|
76 |
+
dirpaths: List[str] = [
|
77 |
+
"datasets/v1/train/RGB",
|
78 |
+
"datasets/v1/train/GT",
|
79 |
+
"datasets/v1/train/depths",
|
80 |
+
],
|
81 |
+
) -> List[List[str]]:
|
82 |
+
assert len(dirpaths) >= 1, "dirnames must contain at least 1 item"
|
83 |
+
|
84 |
+
fullpaths_lst: List[List[str]] = []
|
85 |
+
names_lst: List[List[str]] = []
|
86 |
+
|
87 |
+
for dirname in dirpaths:
|
88 |
+
fullpaths = list(sorted(glob(os.path.join(dirname, "*"))))
|
89 |
+
names = [fname_without_ext(fullpath) for fullpath in fullpaths]
|
90 |
+
fullpaths_lst.append(fullpaths)
|
91 |
+
names_lst.append(names)
|
92 |
+
|
93 |
+
rs: List[List[str]] = [fullpaths_lst[0]] + [[] for _ in range(len(dirpaths) - 1)]
|
94 |
+
|
95 |
+
# Ensure integrity
|
96 |
+
assert (
|
97 |
+
len(set([len(e) for e in names_lst])) == 1
|
98 |
+
), f"Data is not integrity {[len(e) for e in names_lst]} | dirpath = {dirpaths}"
|
99 |
+
|
100 |
+
for name in names_lst[0]:
|
101 |
+
for i, names in enumerate(names_lst[1:]):
|
102 |
+
idx = names.index(name)
|
103 |
+
rs[i + 1].append(fullpaths_lst[i + 1][idx])
|
104 |
+
|
105 |
+
return rs
|
106 |
+
|
107 |
+
|
108 |
+
def scale_saliency_maps(inputs: Tensor) -> Tensor:
|
109 |
+
"""Input: Tensor, shape of (B, C, H, W)"""
|
110 |
+
min_v = (
|
111 |
+
torch.min(torch.flatten(inputs, 1), dim=1)[0]
|
112 |
+
.unsqueeze(1)
|
113 |
+
.unsqueeze(1)
|
114 |
+
.unsqueeze(1)
|
115 |
+
)
|
116 |
+
max_v = (
|
117 |
+
torch.max(torch.flatten(inputs, 1), dim=1)[0]
|
118 |
+
.unsqueeze(1)
|
119 |
+
.unsqueeze(1)
|
120 |
+
.unsqueeze(1)
|
121 |
+
)
|
122 |
+
return (inputs - min_v) / (max_v - min_v + 1e-8)
|
123 |
+
|
124 |
+
|
125 |
+
def get_epoch_from_ckpt_path(ckpt_path: str) -> int:
|
126 |
+
"""Example ckpt_path
|
127 |
+
os.path.join(experiment_dir_path, 'cfgv2.3', 'checkpoint_100.pt')
|
128 |
+
"""
|
129 |
+
return int(ckpt_path.split("_")[-1].split(".")[0])
|
130 |
+
|
131 |
+
|
132 |
+
def clean_dir(dir_path: str) -> None:
|
133 |
+
"""Remove a directory if existed and create an empty directory"""
|
134 |
+
if os.path.isdir(dir_path):
|
135 |
+
shutil.rmtree(dir_path)
|
136 |
+
os.makedirs(dir_path, exist_ok=True)
|
137 |
+
|
138 |
+
|
139 |
+
def get_sota_type(experiment_name: str) -> int:
|
140 |
+
"""0 for SOTAs, 4 for experiment version 4, e.g. ..."""
|
141 |
+
if "cfgv" not in experiment_name:
|
142 |
+
return 0
|
143 |
+
|
144 |
+
half_right = experiment_name.split("cfgv")[1]
|
145 |
+
return int(half_right.split("_")[0])
|
146 |
+
|
147 |
+
|
148 |
+
def hex_to_rgb(hex: str) -> np.ndarray:
|
149 |
+
"""Convert hex color to rgb color
|
150 |
+
|
151 |
+
Args:
|
152 |
+
hex (str): "#00f900"
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
np.ndarray: numpy array of rgb color
|
156 |
+
"""
|
157 |
+
hex = hex[1:]
|
158 |
+
rgb = []
|
159 |
+
for i in (0, 2, 4):
|
160 |
+
decimal = int(hex[i : i + 2], 16)
|
161 |
+
rgb.append(decimal)
|
162 |
+
|
163 |
+
return (np.array(rgb) / 255.0)[::-1]
|
164 |
+
|
165 |
+
|
166 |
+
def normalize(data: np.ndarray) -> np.ndarray:
|
167 |
+
return (data - data.min()) / (data.max() - data.min() + 1e-8)
|
168 |
+
|
169 |
+
|
170 |
+
def post_processing_depth(depth_path: str) -> np.ndarray:
|
171 |
+
depth = np.array(Image.open(depth_path).convert("L"))
|
172 |
+
depth = (normalize(depth) * 255).astype(np.uint8)
|
173 |
+
return cv2.applyColorMap(depth, cv2.COLORMAP_SUMMER)
|
174 |
+
|
175 |
+
|
176 |
+
def convert_batch_tensors_to_numpy_images(images: Tensor) -> np.ndarray:
|
177 |
+
"""images of shape (batch_size, channels, width, height)"""
|
178 |
+
images = torch.permute(images, (0, 2, 3, 1))
|
179 |
+
images = images.numpy()
|
180 |
+
if images.shape[3] == 1:
|
181 |
+
return np.squeeze(images, axis=3)
|
182 |
+
else:
|
183 |
+
return images
|
184 |
+
|
185 |
+
|
186 |
+
def join_horizontally(lst: List[np.ndarray]) -> np.ndarray:
|
187 |
+
return np.concatenate(lst, axis=1)
|
188 |
+
|
189 |
+
|
190 |
+
def join_vertically(lst: List[np.ndarray]) -> np.ndarray:
|
191 |
+
return np.concatenate(lst, axis=0)
|
192 |
+
|
193 |
+
|
194 |
+
def plot_batch_of_pairs(
|
195 |
+
images: Tensor,
|
196 |
+
depths: Tensor,
|
197 |
+
gts: Tensor,
|
198 |
+
save_file_path: str,
|
199 |
+
) -> None:
|
200 |
+
images = convert_batch_tensors_to_numpy_images(images)
|
201 |
+
depths = convert_batch_tensors_to_numpy_images(depths)
|
202 |
+
gts = convert_batch_tensors_to_numpy_images(gts)
|
203 |
+
batch_size = images.shape[0]
|
204 |
+
samples: List[np.ndarray] = []
|
205 |
+
|
206 |
+
# fig, axes = plt.subplots(batch_size, 3, figsize=(3*batch_size, 20)) # (number of images, 3)
|
207 |
+
for i in range(batch_size):
|
208 |
+
samples.append(
|
209 |
+
join_horizontally(
|
210 |
+
[
|
211 |
+
((images[i] + 1.0) / 2 * 255).astype(np.uint8),
|
212 |
+
post_processing_depth(depths[i]),
|
213 |
+
post_processing_depth(gts[i]),
|
214 |
+
]
|
215 |
+
)
|
216 |
+
)
|
217 |
+
# axes[i, 0].imshow(images[i])
|
218 |
+
# axes[i, 1].imshow(depths[i])
|
219 |
+
# axes[i, 2].imshow(gts[i])
|
220 |
+
# plt.show()
|
221 |
+
|
222 |
+
final = join_vertically(samples)
|
223 |
+
cv2.imwrite(save_file_path, cv2.cvtColor(final, cv2.COLOR_RGB2BGR))
|
224 |
+
print(f"Saved to file {save_file_path}")
|
225 |
+
|
226 |
+
|
227 |
+
def plot_pairs(image: np.ndarray, depth: np.ndarray, gt: np.ndarray) -> None:
|
228 |
+
batch_size = 1
|
229 |
+
fig, axes = plt.subplots(
|
230 |
+
batch_size, 3, figsize=(3 * batch_size, 20)
|
231 |
+
) # (number of images, 3)
|
232 |
+
for i in range(batch_size):
|
233 |
+
axes[i, 0].imshow(image)
|
234 |
+
axes[i, 1].imshow(depth)
|
235 |
+
axes[i, 2].imshow(gt)
|
236 |
+
plt.show()
|
s_multimae/visualize_2d_posemb.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from torch import Tensor
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
|
5 |
+
from s_multimae.model.multimae import build_2d_sincos_posemb
|
6 |
+
|
7 |
+
|
8 |
+
def visualize_2d_posemb():
|
9 |
+
NH, NW = 14, 14
|
10 |
+
dim_tokens = 768
|
11 |
+
|
12 |
+
colors = [
|
13 |
+
"Greys",
|
14 |
+
"Purples",
|
15 |
+
"Blues",
|
16 |
+
"Greens",
|
17 |
+
"Oranges",
|
18 |
+
"Reds",
|
19 |
+
"YlOrBr",
|
20 |
+
"YlOrRd",
|
21 |
+
"OrRd",
|
22 |
+
"PuRd",
|
23 |
+
"RdPu",
|
24 |
+
"BuPu",
|
25 |
+
"GnBu",
|
26 |
+
"PuBu",
|
27 |
+
"YlGnBu",
|
28 |
+
"PuBuGn",
|
29 |
+
"BuGn",
|
30 |
+
"YlGn",
|
31 |
+
]
|
32 |
+
|
33 |
+
pos_emb: Tensor = build_2d_sincos_posemb(NH, NW, dim_tokens)
|
34 |
+
pos_emb_numpy: np.ndarray = (
|
35 |
+
pos_emb.squeeze(0).permute(1, 2, 0).numpy()
|
36 |
+
) # 14 x 14 x 768
|
37 |
+
|
38 |
+
x = np.linspace(0, NH - 1, NH)
|
39 |
+
y = np.linspace(0, NW - 1, NW)
|
40 |
+
X, Y = np.meshgrid(x, y)
|
41 |
+
|
42 |
+
for color, i in zip(colors, range(0, pos_emb_numpy.shape[2], 100)):
|
43 |
+
ax = plt.axes(projection="3d")
|
44 |
+
Z = pos_emb_numpy[:, :, i]
|
45 |
+
|
46 |
+
# plt.imshow(Z, cmap='viridis')
|
47 |
+
# plt.savefig(f'posemb_visualization/test_{i}.png')
|
48 |
+
|
49 |
+
ax.plot_surface(
|
50 |
+
X,
|
51 |
+
Y,
|
52 |
+
Z,
|
53 |
+
# rstride=1, cstride=1,
|
54 |
+
cmap="viridis",
|
55 |
+
edgecolor="none",
|
56 |
+
)
|
57 |
+
plt.show()
|
58 |
+
plt.savefig(f"posemb_visualization/test_{i}.png")
|
s_multimae/visualizer.py
ADDED
@@ -0,0 +1,711 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import colorsys
|
2 |
+
from typing import Union
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
import matplotlib.colors as mplc
|
6 |
+
import pycocotools.mask as mask_util
|
7 |
+
import matplotlib.figure as mplfigure
|
8 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
9 |
+
import matplotlib as mpl
|
10 |
+
from enum import Enum, unique
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
_LARGE_MASK_AREA_THRESH = 120000
|
14 |
+
_COLORS = (
|
15 |
+
np.array(
|
16 |
+
[
|
17 |
+
0.000,
|
18 |
+
0.447,
|
19 |
+
0.741,
|
20 |
+
0.850,
|
21 |
+
0.325,
|
22 |
+
0.098,
|
23 |
+
0.929,
|
24 |
+
0.694,
|
25 |
+
0.125,
|
26 |
+
0.494,
|
27 |
+
0.184,
|
28 |
+
0.556,
|
29 |
+
0.466,
|
30 |
+
0.674,
|
31 |
+
0.188,
|
32 |
+
0.301,
|
33 |
+
0.745,
|
34 |
+
0.933,
|
35 |
+
0.635,
|
36 |
+
0.078,
|
37 |
+
0.184,
|
38 |
+
0.300,
|
39 |
+
0.300,
|
40 |
+
0.300,
|
41 |
+
0.600,
|
42 |
+
0.600,
|
43 |
+
0.600,
|
44 |
+
1.000,
|
45 |
+
0.000,
|
46 |
+
0.000,
|
47 |
+
1.000,
|
48 |
+
0.500,
|
49 |
+
0.000,
|
50 |
+
0.749,
|
51 |
+
0.749,
|
52 |
+
0.000,
|
53 |
+
0.000,
|
54 |
+
1.000,
|
55 |
+
0.000,
|
56 |
+
0.000,
|
57 |
+
0.000,
|
58 |
+
1.000,
|
59 |
+
0.667,
|
60 |
+
0.000,
|
61 |
+
1.000,
|
62 |
+
0.333,
|
63 |
+
0.333,
|
64 |
+
0.000,
|
65 |
+
0.333,
|
66 |
+
0.667,
|
67 |
+
0.000,
|
68 |
+
0.333,
|
69 |
+
1.000,
|
70 |
+
0.000,
|
71 |
+
0.667,
|
72 |
+
0.333,
|
73 |
+
0.000,
|
74 |
+
0.667,
|
75 |
+
0.667,
|
76 |
+
0.000,
|
77 |
+
0.667,
|
78 |
+
1.000,
|
79 |
+
0.000,
|
80 |
+
1.000,
|
81 |
+
0.333,
|
82 |
+
0.000,
|
83 |
+
1.000,
|
84 |
+
0.667,
|
85 |
+
0.000,
|
86 |
+
1.000,
|
87 |
+
1.000,
|
88 |
+
0.000,
|
89 |
+
0.000,
|
90 |
+
0.333,
|
91 |
+
0.500,
|
92 |
+
0.000,
|
93 |
+
0.667,
|
94 |
+
0.500,
|
95 |
+
0.000,
|
96 |
+
1.000,
|
97 |
+
0.500,
|
98 |
+
0.333,
|
99 |
+
0.000,
|
100 |
+
0.500,
|
101 |
+
0.333,
|
102 |
+
0.333,
|
103 |
+
0.500,
|
104 |
+
0.333,
|
105 |
+
0.667,
|
106 |
+
0.500,
|
107 |
+
0.333,
|
108 |
+
1.000,
|
109 |
+
0.500,
|
110 |
+
0.667,
|
111 |
+
0.000,
|
112 |
+
0.500,
|
113 |
+
0.667,
|
114 |
+
0.333,
|
115 |
+
0.500,
|
116 |
+
0.667,
|
117 |
+
0.667,
|
118 |
+
0.500,
|
119 |
+
0.667,
|
120 |
+
1.000,
|
121 |
+
0.500,
|
122 |
+
1.000,
|
123 |
+
0.000,
|
124 |
+
0.500,
|
125 |
+
1.000,
|
126 |
+
0.333,
|
127 |
+
0.500,
|
128 |
+
1.000,
|
129 |
+
0.667,
|
130 |
+
0.500,
|
131 |
+
1.000,
|
132 |
+
1.000,
|
133 |
+
0.500,
|
134 |
+
0.000,
|
135 |
+
0.333,
|
136 |
+
1.000,
|
137 |
+
0.000,
|
138 |
+
0.667,
|
139 |
+
1.000,
|
140 |
+
0.000,
|
141 |
+
1.000,
|
142 |
+
1.000,
|
143 |
+
0.333,
|
144 |
+
0.000,
|
145 |
+
1.000,
|
146 |
+
0.333,
|
147 |
+
0.333,
|
148 |
+
1.000,
|
149 |
+
0.333,
|
150 |
+
0.667,
|
151 |
+
1.000,
|
152 |
+
0.333,
|
153 |
+
1.000,
|
154 |
+
1.000,
|
155 |
+
0.667,
|
156 |
+
0.000,
|
157 |
+
1.000,
|
158 |
+
0.667,
|
159 |
+
0.333,
|
160 |
+
1.000,
|
161 |
+
0.667,
|
162 |
+
0.667,
|
163 |
+
1.000,
|
164 |
+
0.667,
|
165 |
+
1.000,
|
166 |
+
1.000,
|
167 |
+
1.000,
|
168 |
+
0.000,
|
169 |
+
1.000,
|
170 |
+
1.000,
|
171 |
+
0.333,
|
172 |
+
1.000,
|
173 |
+
1.000,
|
174 |
+
0.667,
|
175 |
+
1.000,
|
176 |
+
0.333,
|
177 |
+
0.000,
|
178 |
+
0.000,
|
179 |
+
0.500,
|
180 |
+
0.000,
|
181 |
+
0.000,
|
182 |
+
0.667,
|
183 |
+
0.000,
|
184 |
+
0.000,
|
185 |
+
0.833,
|
186 |
+
0.000,
|
187 |
+
0.000,
|
188 |
+
1.000,
|
189 |
+
0.000,
|
190 |
+
0.000,
|
191 |
+
0.000,
|
192 |
+
0.167,
|
193 |
+
0.000,
|
194 |
+
0.000,
|
195 |
+
0.333,
|
196 |
+
0.000,
|
197 |
+
0.000,
|
198 |
+
0.500,
|
199 |
+
0.000,
|
200 |
+
0.000,
|
201 |
+
0.667,
|
202 |
+
0.000,
|
203 |
+
0.000,
|
204 |
+
0.833,
|
205 |
+
0.000,
|
206 |
+
0.000,
|
207 |
+
1.000,
|
208 |
+
0.000,
|
209 |
+
0.000,
|
210 |
+
0.000,
|
211 |
+
0.167,
|
212 |
+
0.000,
|
213 |
+
0.000,
|
214 |
+
0.333,
|
215 |
+
0.000,
|
216 |
+
0.000,
|
217 |
+
0.500,
|
218 |
+
0.000,
|
219 |
+
0.000,
|
220 |
+
0.667,
|
221 |
+
0.000,
|
222 |
+
0.000,
|
223 |
+
0.833,
|
224 |
+
0.000,
|
225 |
+
0.000,
|
226 |
+
1.000,
|
227 |
+
0.000,
|
228 |
+
0.000,
|
229 |
+
0.000,
|
230 |
+
0.143,
|
231 |
+
0.143,
|
232 |
+
0.143,
|
233 |
+
0.857,
|
234 |
+
0.857,
|
235 |
+
0.857,
|
236 |
+
1.000,
|
237 |
+
1.000,
|
238 |
+
1.000,
|
239 |
+
]
|
240 |
+
)
|
241 |
+
.astype(np.float32)
|
242 |
+
.reshape(-1, 3)
|
243 |
+
)
|
244 |
+
|
245 |
+
|
246 |
+
def random_color(rgb=False, maximum=255):
|
247 |
+
"""
|
248 |
+
Args:
|
249 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
250 |
+
maximum (int): either 255 or 1
|
251 |
+
|
252 |
+
Returns:
|
253 |
+
ndarray: a vector of 3 numbers
|
254 |
+
"""
|
255 |
+
idx = np.random.randint(0, len(_COLORS))
|
256 |
+
ret = _COLORS[idx] * maximum
|
257 |
+
if not rgb:
|
258 |
+
ret = ret[::-1]
|
259 |
+
return ret
|
260 |
+
|
261 |
+
|
262 |
+
@unique
|
263 |
+
class ColorMode(Enum):
|
264 |
+
"""
|
265 |
+
Enum of different color modes to use for instance visualizations.
|
266 |
+
"""
|
267 |
+
|
268 |
+
IMAGE = 0
|
269 |
+
"""
|
270 |
+
Picks a random color for every instance and overlay segmentations with low opacity.
|
271 |
+
"""
|
272 |
+
SEGMENTATION = 1
|
273 |
+
"""
|
274 |
+
Let instances of the same category have similar colors
|
275 |
+
(from metadata.thing_colors), and overlay them with
|
276 |
+
high opacity. This provides more attention on the quality of segmentation.
|
277 |
+
"""
|
278 |
+
IMAGE_BW = 2
|
279 |
+
"""
|
280 |
+
Same as IMAGE, but convert all areas without masks to gray-scale.
|
281 |
+
Only available for drawing per-instance mask predictions.
|
282 |
+
"""
|
283 |
+
|
284 |
+
|
285 |
+
class VisImage:
|
286 |
+
def __init__(self, img, scale=1.0):
|
287 |
+
"""
|
288 |
+
Args:
|
289 |
+
img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
|
290 |
+
scale (float): scale the input image
|
291 |
+
"""
|
292 |
+
self.img = img
|
293 |
+
self.scale = scale
|
294 |
+
self.width, self.height = img.shape[1], img.shape[0]
|
295 |
+
self._setup_figure(img)
|
296 |
+
|
297 |
+
def _setup_figure(self, img):
|
298 |
+
"""
|
299 |
+
Args:
|
300 |
+
Same as in :meth:`__init__()`.
|
301 |
+
|
302 |
+
Returns:
|
303 |
+
fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
|
304 |
+
ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
|
305 |
+
"""
|
306 |
+
fig = mplfigure.Figure(frameon=False)
|
307 |
+
self.dpi = fig.get_dpi()
|
308 |
+
# add a small 1e-2 to avoid precision lost due to matplotlib's truncation
|
309 |
+
# (https://github.com/matplotlib/matplotlib/issues/15363)
|
310 |
+
fig.set_size_inches(
|
311 |
+
(self.width * self.scale + 1e-2) / self.dpi,
|
312 |
+
(self.height * self.scale + 1e-2) / self.dpi,
|
313 |
+
)
|
314 |
+
self.canvas = FigureCanvasAgg(fig)
|
315 |
+
# self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
|
316 |
+
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
|
317 |
+
ax.axis("off")
|
318 |
+
self.fig = fig
|
319 |
+
self.ax = ax
|
320 |
+
self.reset_image(img)
|
321 |
+
|
322 |
+
def reset_image(self, img):
|
323 |
+
"""
|
324 |
+
Args:
|
325 |
+
img: same as in __init__
|
326 |
+
"""
|
327 |
+
img = img.astype("uint8")
|
328 |
+
self.ax.imshow(
|
329 |
+
img, extent=(0, self.width, self.height, 0), interpolation="nearest"
|
330 |
+
)
|
331 |
+
|
332 |
+
def save(self, filepath):
|
333 |
+
"""
|
334 |
+
Args:
|
335 |
+
filepath (str): a string that contains the absolute path, including the file name, where
|
336 |
+
the visualized image will be saved.
|
337 |
+
"""
|
338 |
+
self.fig.savefig(filepath)
|
339 |
+
|
340 |
+
def get_image(self):
|
341 |
+
"""
|
342 |
+
Returns:
|
343 |
+
ndarray:
|
344 |
+
the visualized image of shape (H, W, 3) (RGB) in uint8 type.
|
345 |
+
The shape is scaled w.r.t the input image using the given `scale` argument.
|
346 |
+
"""
|
347 |
+
canvas = self.canvas
|
348 |
+
s, (width, height) = canvas.print_to_buffer()
|
349 |
+
# buf = io.BytesIO() # works for cairo backend
|
350 |
+
# canvas.print_rgba(buf)
|
351 |
+
# width, height = self.width, self.height
|
352 |
+
# s = buf.getvalue()
|
353 |
+
|
354 |
+
buffer = np.frombuffer(s, dtype="uint8")
|
355 |
+
|
356 |
+
img_rgba = buffer.reshape(height, width, 4)
|
357 |
+
rgb, alpha = np.split(img_rgba, [3], axis=2)
|
358 |
+
return rgb.astype("uint8")
|
359 |
+
|
360 |
+
|
361 |
+
class GenericMask:
|
362 |
+
"""
|
363 |
+
Attribute:
|
364 |
+
polygons (list[ndarray]): list[ndarray]: polygons for this mask.
|
365 |
+
Each ndarray has format [x, y, x, y, ...]
|
366 |
+
mask (ndarray): a binary mask
|
367 |
+
"""
|
368 |
+
|
369 |
+
def __init__(self, mask_or_polygons, height, width):
|
370 |
+
self._mask = self._polygons = self._has_holes = None
|
371 |
+
self.height = height
|
372 |
+
self.width = width
|
373 |
+
|
374 |
+
m = mask_or_polygons
|
375 |
+
if isinstance(m, dict):
|
376 |
+
# RLEs
|
377 |
+
assert "counts" in m and "size" in m
|
378 |
+
if isinstance(m["counts"], list): # uncompressed RLEs
|
379 |
+
h, w = m["size"]
|
380 |
+
assert h == height and w == width
|
381 |
+
m = mask_util.frPyObjects(m, h, w)
|
382 |
+
self._mask = mask_util.decode(m)[:, :]
|
383 |
+
return
|
384 |
+
|
385 |
+
if isinstance(m, list): # list[ndarray]
|
386 |
+
self._polygons = [np.asarray(x).reshape(-1) for x in m]
|
387 |
+
return
|
388 |
+
|
389 |
+
if isinstance(m, np.ndarray): # assumed to be a binary mask
|
390 |
+
assert m.shape[1] != 2, m.shape
|
391 |
+
assert m.shape == (
|
392 |
+
height,
|
393 |
+
width,
|
394 |
+
), f"mask shape: {m.shape}, target dims: {height}, {width}"
|
395 |
+
self._mask = m.astype("uint8")
|
396 |
+
return
|
397 |
+
|
398 |
+
raise ValueError(
|
399 |
+
"GenericMask cannot handle object {} of type '{}'".format(m, type(m))
|
400 |
+
)
|
401 |
+
|
402 |
+
@property
|
403 |
+
def mask(self):
|
404 |
+
if self._mask is None:
|
405 |
+
self._mask = self.polygons_to_mask(self._polygons)
|
406 |
+
return self._mask
|
407 |
+
|
408 |
+
@property
|
409 |
+
def polygons(self):
|
410 |
+
if self._polygons is None:
|
411 |
+
self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
|
412 |
+
return self._polygons
|
413 |
+
|
414 |
+
@property
|
415 |
+
def has_holes(self):
|
416 |
+
if self._has_holes is None:
|
417 |
+
if self._mask is not None:
|
418 |
+
self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
|
419 |
+
else:
|
420 |
+
self._has_holes = (
|
421 |
+
False # if original format is polygon, does not have holes
|
422 |
+
)
|
423 |
+
return self._has_holes
|
424 |
+
|
425 |
+
def mask_to_polygons(self, mask):
|
426 |
+
# cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
|
427 |
+
# hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
|
428 |
+
# Internal contours (holes) are placed in hierarchy-2.
|
429 |
+
# cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
|
430 |
+
mask = np.ascontiguousarray(
|
431 |
+
mask
|
432 |
+
) # some versions of cv2 does not support incontiguous arr
|
433 |
+
res = cv2.findContours(
|
434 |
+
mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
|
435 |
+
)
|
436 |
+
hierarchy = res[-1]
|
437 |
+
if hierarchy is None: # empty mask
|
438 |
+
return [], False
|
439 |
+
has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
|
440 |
+
res = res[-2]
|
441 |
+
res = [x.flatten() for x in res]
|
442 |
+
# These coordinates from OpenCV are integers in range [0, W-1 or H-1].
|
443 |
+
# We add 0.5 to turn them into real-value coordinate space. A better solution
|
444 |
+
# would be to first +0.5 and then dilate the returned polygon by 0.5.
|
445 |
+
res = [x + 0.5 for x in res if len(x) >= 6]
|
446 |
+
return res, has_holes
|
447 |
+
|
448 |
+
def polygons_to_mask(self, polygons):
|
449 |
+
rle = mask_util.frPyObjects(polygons, self.height, self.width)
|
450 |
+
rle = mask_util.merge(rle)
|
451 |
+
return mask_util.decode(rle)[:, :]
|
452 |
+
|
453 |
+
def area(self):
|
454 |
+
return self.mask.sum()
|
455 |
+
|
456 |
+
def bbox(self):
|
457 |
+
p = mask_util.frPyObjects(self.polygons, self.height, self.width)
|
458 |
+
p = mask_util.merge(p)
|
459 |
+
bbox = mask_util.toBbox(p)
|
460 |
+
bbox[2] += bbox[0]
|
461 |
+
bbox[3] += bbox[1]
|
462 |
+
return bbox
|
463 |
+
|
464 |
+
|
465 |
+
class Visualizer:
|
466 |
+
"""
|
467 |
+
Visualizer that draws data about detection/segmentation on images.
|
468 |
+
|
469 |
+
It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
|
470 |
+
that draw primitive objects to images, as well as high-level wrappers like
|
471 |
+
`draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
|
472 |
+
that draw composite data in some pre-defined style.
|
473 |
+
|
474 |
+
Note that the exact visualization style for the high-level wrappers are subject to change.
|
475 |
+
Style such as color, opacity, label contents, visibility of labels, or even the visibility
|
476 |
+
of objects themselves (e.g. when the object is too small) may change according
|
477 |
+
to different heuristics, as long as the results still look visually reasonable.
|
478 |
+
|
479 |
+
To obtain a consistent style, you can implement custom drawing functions with the
|
480 |
+
abovementioned primitive methods instead. If you need more customized visualization
|
481 |
+
styles, you can process the data yourself following their format documented in
|
482 |
+
tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
|
483 |
+
intend to satisfy everyone's preference on drawing styles.
|
484 |
+
|
485 |
+
This visualizer focuses on high rendering quality rather than performance. It is not
|
486 |
+
designed to be used for real-time applications.
|
487 |
+
"""
|
488 |
+
|
489 |
+
# TODO implement a fast, rasterized version using OpenCV
|
490 |
+
|
491 |
+
def __init__(
|
492 |
+
self,
|
493 |
+
img_rgb: Union[Image.Image, np.ndarray],
|
494 |
+
scale=1.0,
|
495 |
+
instance_mode=ColorMode.IMAGE,
|
496 |
+
):
|
497 |
+
"""
|
498 |
+
Args:
|
499 |
+
img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
|
500 |
+
the height and width of the image respectively. C is the number of
|
501 |
+
color channels. The image is required to be in RGB format since that
|
502 |
+
is a requirement of the Matplotlib library. The image is also expected
|
503 |
+
to be in the range [0, 255].
|
504 |
+
instance_mode (ColorMode): defines one of the pre-defined style for drawing
|
505 |
+
instances on an image.
|
506 |
+
"""
|
507 |
+
if type(img_rgb) == np.ndarray:
|
508 |
+
img_rgb = img_rgb[:, :, ::-1]
|
509 |
+
else:
|
510 |
+
img_rgb = np.array(img_rgb)[:, :, ::-1]
|
511 |
+
self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
|
512 |
+
self.output = VisImage(self.img, scale=scale)
|
513 |
+
|
514 |
+
# too small texts are useless, therefore clamp to 9
|
515 |
+
self._default_font_size = max(
|
516 |
+
np.sqrt(self.output.height * self.output.width) // 90, 10 // scale
|
517 |
+
)
|
518 |
+
self._instance_mode = instance_mode
|
519 |
+
|
520 |
+
def draw_binary_mask(
|
521 |
+
self,
|
522 |
+
binary_mask,
|
523 |
+
color=None,
|
524 |
+
*,
|
525 |
+
edge_color=None,
|
526 |
+
text=None,
|
527 |
+
alpha=0.5,
|
528 |
+
area_threshold=10,
|
529 |
+
):
|
530 |
+
"""
|
531 |
+
Args:
|
532 |
+
binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
|
533 |
+
W is the image width. Each value in the array is either a 0 or 1 value of uint8
|
534 |
+
type.
|
535 |
+
color: color of the mask. Refer to `matplotlib.colors` for a full list of
|
536 |
+
formats that are accepted. If None, will pick a random color.
|
537 |
+
edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
|
538 |
+
full list of formats that are accepted.
|
539 |
+
text (str): if None, will be drawn on the object
|
540 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
541 |
+
area_threshold (float): a connected component smaller than this area will not be shown.
|
542 |
+
|
543 |
+
Returns:
|
544 |
+
output (VisImage): image object with mask drawn.
|
545 |
+
"""
|
546 |
+
if color is None:
|
547 |
+
color = random_color(rgb=True, maximum=1)
|
548 |
+
color = mplc.to_rgb(color)
|
549 |
+
|
550 |
+
has_valid_segment = False
|
551 |
+
binary_mask = binary_mask.astype("uint8") # opencv needs uint8
|
552 |
+
mask = GenericMask(binary_mask, self.output.height, self.output.width)
|
553 |
+
shape2d = (binary_mask.shape[0], binary_mask.shape[1])
|
554 |
+
|
555 |
+
if not mask.has_holes:
|
556 |
+
# draw polygons for regular masks
|
557 |
+
for segment in mask.polygons:
|
558 |
+
area = mask_util.area(
|
559 |
+
mask_util.frPyObjects([segment], shape2d[0], shape2d[1])
|
560 |
+
)
|
561 |
+
if area < (area_threshold or 0):
|
562 |
+
continue
|
563 |
+
has_valid_segment = True
|
564 |
+
segment = segment.reshape(-1, 2)
|
565 |
+
self.draw_polygon(
|
566 |
+
segment, color=color, edge_color=edge_color, alpha=alpha
|
567 |
+
)
|
568 |
+
else:
|
569 |
+
# TODO: Use Path/PathPatch to draw vector graphics:
|
570 |
+
# https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
|
571 |
+
rgba = np.zeros(shape2d + (4,), dtype="float32")
|
572 |
+
rgba[:, :, :3] = color
|
573 |
+
rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
|
574 |
+
has_valid_segment = True
|
575 |
+
self.output.ax.imshow(
|
576 |
+
rgba, extent=(0, self.output.width, self.output.height, 0)
|
577 |
+
)
|
578 |
+
|
579 |
+
if text is not None and has_valid_segment:
|
580 |
+
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
581 |
+
self._draw_text_in_mask(binary_mask, text, lighter_color)
|
582 |
+
return self.output
|
583 |
+
|
584 |
+
def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
|
585 |
+
"""
|
586 |
+
Args:
|
587 |
+
segment: numpy array of shape Nx2, containing all the points in the polygon.
|
588 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
589 |
+
formats that are accepted.
|
590 |
+
edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
|
591 |
+
full list of formats that are accepted. If not provided, a darker shade
|
592 |
+
of the polygon color will be used instead.
|
593 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
594 |
+
|
595 |
+
Returns:
|
596 |
+
output (VisImage): image object with polygon drawn.
|
597 |
+
"""
|
598 |
+
if edge_color is None:
|
599 |
+
# make edge color darker than the polygon color
|
600 |
+
if alpha > 0.8:
|
601 |
+
edge_color = self._change_color_brightness(
|
602 |
+
color, brightness_factor=-0.7
|
603 |
+
)
|
604 |
+
else:
|
605 |
+
edge_color = color
|
606 |
+
edge_color = mplc.to_rgb(edge_color) + (1,)
|
607 |
+
|
608 |
+
polygon = mpl.patches.Polygon(
|
609 |
+
segment,
|
610 |
+
fill=True,
|
611 |
+
facecolor=mplc.to_rgb(color) + (alpha,),
|
612 |
+
edgecolor=edge_color,
|
613 |
+
linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
|
614 |
+
)
|
615 |
+
self.output.ax.add_patch(polygon)
|
616 |
+
return self.output
|
617 |
+
|
618 |
+
"""
|
619 |
+
Internal methods:
|
620 |
+
"""
|
621 |
+
|
622 |
+
def _change_color_brightness(self, color, brightness_factor):
|
623 |
+
"""
|
624 |
+
Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
|
625 |
+
less or more saturation than the original color.
|
626 |
+
|
627 |
+
Args:
|
628 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
629 |
+
formats that are accepted.
|
630 |
+
brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
|
631 |
+
0 will correspond to no change, a factor in [-1.0, 0) range will result in
|
632 |
+
a darker color and a factor in (0, 1.0] range will result in a lighter color.
|
633 |
+
|
634 |
+
Returns:
|
635 |
+
modified_color (tuple[double]): a tuple containing the RGB values of the
|
636 |
+
modified color. Each value in the tuple is in the [0.0, 1.0] range.
|
637 |
+
"""
|
638 |
+
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
|
639 |
+
color = mplc.to_rgb(color)
|
640 |
+
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
|
641 |
+
modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
|
642 |
+
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
|
643 |
+
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
|
644 |
+
modified_color = colorsys.hls_to_rgb(
|
645 |
+
polygon_color[0], modified_lightness, polygon_color[2]
|
646 |
+
)
|
647 |
+
return modified_color
|
648 |
+
|
649 |
+
def _draw_text_in_mask(self, binary_mask, text, color):
|
650 |
+
"""
|
651 |
+
Find proper places to draw text given a binary mask.
|
652 |
+
"""
|
653 |
+
# TODO sometimes drawn on wrong objects. the heuristics here can improve.
|
654 |
+
_num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(
|
655 |
+
binary_mask, 8
|
656 |
+
)
|
657 |
+
if stats[1:, -1].size == 0:
|
658 |
+
return
|
659 |
+
largest_component_id = np.argmax(stats[1:, -1]) + 1
|
660 |
+
|
661 |
+
# draw text on the largest component, as well as other very large components.
|
662 |
+
for cid in range(1, _num_cc):
|
663 |
+
if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
|
664 |
+
# median is more stable than centroid
|
665 |
+
# center = centroids[largest_component_id]
|
666 |
+
center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
|
667 |
+
self.draw_text(text, center, color=color)
|
668 |
+
|
669 |
+
def get_output(self):
|
670 |
+
"""
|
671 |
+
Returns:
|
672 |
+
output (VisImage): the image output containing the visualizations added
|
673 |
+
to the image.
|
674 |
+
"""
|
675 |
+
return self.output
|
676 |
+
|
677 |
+
|
678 |
+
def apply_threshold(pred: np.ndarray) -> np.ndarray:
|
679 |
+
"""Apply threshold to a salient map
|
680 |
+
|
681 |
+
Args:
|
682 |
+
pred (np.ndarray): each pixel is in range [0, 255]
|
683 |
+
|
684 |
+
Returns:
|
685 |
+
np.ndarray: each pixel is only 0.0 or 1.0
|
686 |
+
"""
|
687 |
+
binary_mask = pred / 255.0
|
688 |
+
binary_mask[binary_mask >= 0.5] = 1.0
|
689 |
+
binary_mask[binary_mask < 0.5] = 0.0
|
690 |
+
return binary_mask
|
691 |
+
|
692 |
+
|
693 |
+
def normalize(data: np.ndarray) -> np.ndarray:
|
694 |
+
return (data - data.min()) / (data.max() - data.min() + 1e-8)
|
695 |
+
|
696 |
+
|
697 |
+
def post_processing_depth(depth: np.ndarray) -> np.ndarray:
|
698 |
+
depth = (normalize(depth) * 255).astype(np.uint8)
|
699 |
+
return cv2.applyColorMap(depth, cv2.COLORMAP_OCEAN)
|
700 |
+
|
701 |
+
|
702 |
+
def apply_vis_to_image(
|
703 |
+
rgb: np.ndarray, binary_mask: np.ndarray, color: np.ndarray
|
704 |
+
) -> np.ndarray:
|
705 |
+
if rgb.shape[:2] != binary_mask.shape[:2]:
|
706 |
+
print(rgb.shape, binary_mask.shape)
|
707 |
+
binary_mask = cv2.resize(binary_mask, [rgb.shape[1], rgb.shape[0]])
|
708 |
+
visualizer = Visualizer(rgb)
|
709 |
+
vis_image: VisImage = visualizer.draw_binary_mask(binary_mask, color)
|
710 |
+
vis_image = vis_image.get_image()[:, :, ::-1]
|
711 |
+
return vis_image
|
streamlit_apps/__init__.py
ADDED
File without changes
|
streamlit_apps/app.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
sys.path.append(os.getcwd())
|
4 |
+
|
5 |
+
import multiprocessing
|
6 |
+
|
7 |
+
import streamlit as st
|
8 |
+
|
9 |
+
from app_utils.color_selection_ui import color_selection_ui
|
10 |
+
from app_utils.depth_selection_ui import depth_selection_ui
|
11 |
+
from app_utils.device import device
|
12 |
+
from app_utils.sod_selection_ui import sod_selection_ui
|
13 |
+
|
14 |
+
|
15 |
+
class MODE:
|
16 |
+
IMAGE = "image"
|
17 |
+
VIDEO = "video"
|
18 |
+
WEBRTC = "webrtc"
|
19 |
+
DEMO = "demo"
|
20 |
+
|
21 |
+
|
22 |
+
TITLE = "S-MultiMAE: A Multi-Ground Truth approach for RGB-D Saliency Detection"
|
23 |
+
|
24 |
+
st.set_page_config(
|
25 |
+
page_title=TITLE,
|
26 |
+
page_icon="🧊",
|
27 |
+
layout="wide",
|
28 |
+
# initial_sidebar_state="expanded",
|
29 |
+
# menu_items={
|
30 |
+
# 'Get Help': 'https://www.extremelycoolapp.com/help',
|
31 |
+
# 'Report a bug': "https://www.extremelycoolapp.com/bug",
|
32 |
+
# 'About': "# This is a header. This is an *extremely* cool app!"
|
33 |
+
# }
|
34 |
+
)
|
35 |
+
st.title(TITLE)
|
36 |
+
|
37 |
+
with st.expander("INTRODUCTION"):
|
38 |
+
st.text(
|
39 |
+
f"""Demo for S-MultiMAE.
|
40 |
+
Device: {device.type}
|
41 |
+
Number of CPU(s): {multiprocessing.cpu_count()}"""
|
42 |
+
)
|
43 |
+
st.image("docs/figures/proposed_method_v5.drawio.png", use_column_width="always")
|
44 |
+
|
45 |
+
with st.expander("SETTINGS", expanded=True):
|
46 |
+
col1, col2 = st.columns(2)
|
47 |
+
|
48 |
+
with col1:
|
49 |
+
mode = st.radio(
|
50 |
+
"Mode",
|
51 |
+
(
|
52 |
+
MODE.IMAGE,
|
53 |
+
# MODE.VIDEO,
|
54 |
+
# MODE.WEBRTC,
|
55 |
+
# MODE.DEMO,
|
56 |
+
),
|
57 |
+
)
|
58 |
+
st.markdown("---")
|
59 |
+
color = color_selection_ui()
|
60 |
+
|
61 |
+
with col2:
|
62 |
+
depth_model = depth_selection_ui()
|
63 |
+
st.markdown("---")
|
64 |
+
sod_model, da = sod_selection_ui()
|
65 |
+
|
66 |
+
with st.expander("HOW TO USE", expanded=True):
|
67 |
+
st.text(
|
68 |
+
"(1) You can change the model type (using different backbones) in the settings."
|
69 |
+
)
|
70 |
+
st.text("(2) Upload an RGB image.")
|
71 |
+
st.text(
|
72 |
+
"(3) (Optional) Provide its corresponding depth. If not present, a pseudo-depth will be inferred by a rgb2depth model."
|
73 |
+
)
|
74 |
+
st.text(
|
75 |
+
"(4) You may try a different number of sets of salient objects the model can produce."
|
76 |
+
)
|
77 |
+
st.text("""(5) Click "Predict Salient Objects".""")
|
78 |
+
|
79 |
+
if mode == MODE.IMAGE:
|
80 |
+
from app_utils.image_inference import image_inference
|
81 |
+
|
82 |
+
image_inference(depth_model, sod_model, da, color)
|
83 |
+
# elif mode == MODE.VIDEO:
|
84 |
+
# from video_inference import video_inference
|
85 |
+
# video_inference(depth_model, sod_model, color)
|
86 |
+
# elif mode == MODE.WEBRTC:
|
87 |
+
# from webrtc_app import webrtc_app
|
88 |
+
# webrtc_app(depth_model, sod_model, color)
|
89 |
+
# elif mode == MODE.DEMO:
|
90 |
+
# from demo import demo
|
91 |
+
# demo()
|
streamlit_apps/app_utils/__init__.py
ADDED
File without changes
|
streamlit_apps/app_utils/app_env.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
app_env = os.environ.get("APP_ENVIRONMENT", "HUGGINGFACE")
|
4 |
+
|
5 |
+
IMAGE_SIZE = 224
|
6 |
+
|
7 |
+
|
8 |
+
class DEPTH_MODEL_TYPE:
|
9 |
+
DPT_DEPTH = "DPTDepth"
|
10 |
+
REL_DEPTH = "RelDepth"
|
11 |
+
|
12 |
+
|
13 |
+
class SOD_MODEL_TYPE:
|
14 |
+
S_MULTIMAE = "S-MultiMAE"
|
15 |
+
SPNET = "SPNet"
|
16 |
+
BBSNET = "BBSNet"
|
streamlit_apps/app_utils/app_utils.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
from typing import Tuple, Union
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import streamlit as st
|
8 |
+
from PIL import Image
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
num_format = "{:,}".format
|
12 |
+
|
13 |
+
|
14 |
+
def count_parameters(model: nn.Module) -> str:
|
15 |
+
"""Count the number of parameters of a model"""
|
16 |
+
return num_format(sum(p.numel() for p in model.parameters() if p.requires_grad))
|
17 |
+
|
18 |
+
|
19 |
+
class FrameRate:
|
20 |
+
def __init__(self) -> None:
|
21 |
+
self.c: int = 0
|
22 |
+
self.start_time: float = None
|
23 |
+
self.NO_FRAMES = 100
|
24 |
+
self.fps: float = -1
|
25 |
+
|
26 |
+
def reset(self) -> None:
|
27 |
+
self.start_time = time.time()
|
28 |
+
self.c = 0
|
29 |
+
self.fps = -1
|
30 |
+
|
31 |
+
def count(self) -> None:
|
32 |
+
self.c += 1
|
33 |
+
if self.c % self.NO_FRAMES == 0:
|
34 |
+
self.c = 0
|
35 |
+
end_time = time.time()
|
36 |
+
self.fps = self.NO_FRAMES / (end_time - self.start_time)
|
37 |
+
self.start_time = end_time
|
38 |
+
|
39 |
+
def show_fps(self, image: np.ndarray) -> np.ndarray:
|
40 |
+
if self.fps != -1:
|
41 |
+
return cv2.putText(
|
42 |
+
image,
|
43 |
+
f"FPS {self.fps:.0f}",
|
44 |
+
(50, 50),
|
45 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
46 |
+
fontScale=1,
|
47 |
+
color=(255, 0, 0),
|
48 |
+
thickness=2,
|
49 |
+
)
|
50 |
+
else:
|
51 |
+
return image
|
52 |
+
|
53 |
+
|
54 |
+
class ImgContainer:
|
55 |
+
img: np.ndarray = None # raw image
|
56 |
+
frame_rate: FrameRate = FrameRate()
|
57 |
+
|
58 |
+
|
59 |
+
def load_video(video_path: str) -> bytes:
|
60 |
+
if not os.path.isfile(video_path):
|
61 |
+
return
|
62 |
+
with st.spinner(f"Loading video {video_path} ..."):
|
63 |
+
video_bytes = open(video_path, "rb").read()
|
64 |
+
st.video(video_bytes, format="video/mp4")
|
65 |
+
|
66 |
+
|
67 |
+
def normalize(data: np.ndarray) -> np.ndarray:
|
68 |
+
return (data - data.min()) / (data.max() - data.min() + 1e-8)
|
69 |
+
|
70 |
+
|
71 |
+
def get_size(image: Union[Image.Image, np.ndarray]) -> Tuple[int, int]:
|
72 |
+
"""Get resolution (w, h) of an image
|
73 |
+
An input image can be Pillow Image or CV2 Image
|
74 |
+
"""
|
75 |
+
if type(image) == np.ndarray:
|
76 |
+
return (image.shape[1], image.shape[0])
|
77 |
+
else:
|
78 |
+
return image.size
|
79 |
+
|
80 |
+
|
81 |
+
def random_choice(p: float) -> bool:
|
82 |
+
"""Return True if random float <= p"""
|
83 |
+
return random.random() <= p
|
streamlit_apps/app_utils/base_model.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
import numpy as np
|
3 |
+
from torch import Tensor, nn
|
4 |
+
|
5 |
+
|
6 |
+
class BaseRGBDModel(nn.Module):
|
7 |
+
def __init__(self):
|
8 |
+
super(BaseRGBDModel, self).__init__()
|
9 |
+
"""
|
10 |
+
Requirements:
|
11 |
+
1. Construct a model
|
12 |
+
2. Load pretrained weights
|
13 |
+
3. Load model into device
|
14 |
+
4. Construct preprocessing
|
15 |
+
"""
|
16 |
+
|
17 |
+
def inference(
|
18 |
+
self,
|
19 |
+
image: Tensor,
|
20 |
+
depth: Tensor,
|
21 |
+
origin_shape: np.array,
|
22 |
+
) -> List[np.ndarray]:
|
23 |
+
"""
|
24 |
+
Given:
|
25 |
+
- An image (Tensor) with original shape [c, h, w]
|
26 |
+
- A depth image (Tensor) with a shape of [c, h, w], do not need to be the same shape as image
|
27 |
+
|
28 |
+
Requirements:
|
29 |
+
1. Preprocessing
|
30 |
+
2. Inference
|
31 |
+
3. Return saliency maps np.float32 between 0.0 and 1.0,
|
32 |
+
with the same size as original size
|
33 |
+
|
34 |
+
"""
|
35 |
+
raise NotImplementedError()
|
36 |
+
|
37 |
+
def batch_inference(
|
38 |
+
self,
|
39 |
+
images: Tensor,
|
40 |
+
depths: Tensor,
|
41 |
+
) -> List[np.ndarray]:
|
42 |
+
"""
|
43 |
+
Given:
|
44 |
+
- A batch of images (Tensor) with original shape [b, c, h, w]
|
45 |
+
- A batch of depths (Tensor) with a shape of [b, c, h, w], do not need to be the same shape as image
|
46 |
+
|
47 |
+
Requirements:
|
48 |
+
1. Preprocessing
|
49 |
+
2. Inference
|
50 |
+
3. Return saliency maps np.float32 between 0.0 and 1.0,
|
51 |
+
with the same size as original size
|
52 |
+
|
53 |
+
"""
|
54 |
+
raise NotImplementedError()
|
streamlit_apps/app_utils/color_selection_ui.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
from s_multimae.utils import hex_to_rgb
|
5 |
+
|
6 |
+
|
7 |
+
def color_selection_ui() -> np.ndarray:
|
8 |
+
color = st.color_picker("Pick A Color", value="#00f900", key="color")
|
9 |
+
color = hex_to_rgb(color)
|
10 |
+
return color
|
streamlit_apps/app_utils/depth_model.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
import torchvision.transforms.functional as TF
|
5 |
+
from torch import Tensor, nn
|
6 |
+
|
7 |
+
from .app_utils import count_parameters
|
8 |
+
from .device import device
|
9 |
+
from .dpt.models import DPTDepthModel
|
10 |
+
|
11 |
+
|
12 |
+
class BaseDepthModel:
|
13 |
+
def __init__(self, image_size: int) -> None:
|
14 |
+
self.image_size = image_size
|
15 |
+
self.model: nn.Module = None
|
16 |
+
|
17 |
+
def forward(self, image: Tensor) -> Tensor:
|
18 |
+
"""Perform forward inference for an image
|
19 |
+
Input image of shape [c, h, w]
|
20 |
+
Return of shape [c, h, w]
|
21 |
+
"""
|
22 |
+
raise NotImplementedError()
|
23 |
+
|
24 |
+
def batch_forward(self, images: Tensor) -> Tensor:
|
25 |
+
"""Perform forward inference for a batch of images
|
26 |
+
Input images of shape [b, c, h, w]
|
27 |
+
Return of shape [b, c, h, w]"""
|
28 |
+
raise NotImplementedError()
|
29 |
+
|
30 |
+
def get_number_of_parameters(self) -> int:
|
31 |
+
return count_parameters(self.model)
|
32 |
+
|
33 |
+
|
34 |
+
class DPTDepth(BaseDepthModel):
|
35 |
+
def __init__(self, image_size: int) -> None:
|
36 |
+
super().__init__(image_size)
|
37 |
+
print("DPTDepthconstructor")
|
38 |
+
weights_fname = "omnidata_rgb2depth_dpt_hybrid.pth"
|
39 |
+
weights_path = os.path.join("weights", weights_fname)
|
40 |
+
if not os.path.isfile(weights_path):
|
41 |
+
from huggingface_hub import hf_hub_download
|
42 |
+
hf_hub_download(repo_id="RGBD-SOD/S-MultiMAE", filename=weights_fname)
|
43 |
+
os.system(f"mv {weights_fname} weights")
|
44 |
+
omnidata_ckpt = torch.load(
|
45 |
+
weights_path,
|
46 |
+
map_location="cpu",
|
47 |
+
)
|
48 |
+
|
49 |
+
self.model = DPTDepthModel()
|
50 |
+
self.model.load_state_dict(omnidata_ckpt)
|
51 |
+
self.model: DPTDepthModel = self.model.to(device).eval()
|
52 |
+
|
53 |
+
self.transform = transforms.Compose(
|
54 |
+
[
|
55 |
+
transforms.Resize(
|
56 |
+
(self.image_size, self.image_size),
|
57 |
+
interpolation=TF.InterpolationMode.BICUBIC,
|
58 |
+
),
|
59 |
+
transforms.Normalize(
|
60 |
+
(0.5, 0.5, 0.5),
|
61 |
+
(0.5, 0.5, 0.5),
|
62 |
+
),
|
63 |
+
]
|
64 |
+
)
|
65 |
+
|
66 |
+
def forward(self, image: Tensor) -> Tensor:
|
67 |
+
depth_model_input = self.transform(image.unsqueeze(0))
|
68 |
+
return self.model.forward(depth_model_input.to(device)).squeeze(0)
|
69 |
+
|
70 |
+
def batch_forward(self, images: Tensor) -> Tensor:
|
71 |
+
images: Tensor = TF.resize(
|
72 |
+
images,
|
73 |
+
(self.image_size, self.image_size),
|
74 |
+
interpolation=TF.InterpolationMode.BICUBIC,
|
75 |
+
)
|
76 |
+
depth_model_input = (images - 0.5) / 0.5
|
77 |
+
return self.model(depth_model_input.to(device))
|
streamlit_apps/app_utils/depth_selection_ui.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
from .app_env import DEPTH_MODEL_TYPE, IMAGE_SIZE
|
4 |
+
from .depth_model import BaseDepthModel, DPTDepth
|
5 |
+
|
6 |
+
|
7 |
+
@st.cache_resource
|
8 |
+
def load_depth_model(depth_model_type: DEPTH_MODEL_TYPE) -> DPTDepth:
|
9 |
+
if depth_model_type == DEPTH_MODEL_TYPE.DPT_DEPTH:
|
10 |
+
return DPTDepth(IMAGE_SIZE)
|
11 |
+
else:
|
12 |
+
return DPTDepth(IMAGE_SIZE)
|
13 |
+
|
14 |
+
|
15 |
+
def depth_selection_ui() -> BaseDepthModel:
|
16 |
+
depth_model: BaseDepthModel = None
|
17 |
+
depth_model_type = st.selectbox(
|
18 |
+
"Choose depth model",
|
19 |
+
(
|
20 |
+
DEPTH_MODEL_TYPE.DPT_DEPTH,
|
21 |
+
# DEPTH_MODEL_TYPE.REL_DEPTH,
|
22 |
+
),
|
23 |
+
key="depth_model_type",
|
24 |
+
)
|
25 |
+
depth_model = load_depth_model(depth_model_type)
|
26 |
+
st.text(f"Number of parameters {depth_model.get_number_of_parameters()}")
|
27 |
+
return depth_model
|
streamlit_apps/app_utils/device.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
cpu_device = torch.device("cpu")
|
4 |
+
# device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
|
5 |
+
device = cpu_device
|
streamlit_apps/app_utils/dpt/__init__.py
ADDED
File without changes
|
streamlit_apps/app_utils/dpt/base_model.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class BaseModel(torch.nn.Module):
|
5 |
+
def load(self, path):
|
6 |
+
"""Load model from file.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
path (str): file path
|
10 |
+
"""
|
11 |
+
parameters = torch.load(path, map_location=torch.device("cpu"))
|
12 |
+
|
13 |
+
if "optimizer" in parameters:
|
14 |
+
parameters = parameters["model"]
|
15 |
+
|
16 |
+
self.load_state_dict(parameters)
|
streamlit_apps/app_utils/dpt/blocks.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .vit import (
|
5 |
+
_make_pretrained_vitb_rn50_384,
|
6 |
+
_make_pretrained_vitl16_384,
|
7 |
+
_make_pretrained_vitb16_384,
|
8 |
+
forward_vit,
|
9 |
+
)
|
10 |
+
|
11 |
+
|
12 |
+
def _make_encoder(
|
13 |
+
backbone,
|
14 |
+
features,
|
15 |
+
use_pretrained,
|
16 |
+
groups=1,
|
17 |
+
expand=False,
|
18 |
+
exportable=True,
|
19 |
+
hooks=None,
|
20 |
+
use_vit_only=False,
|
21 |
+
use_readout="ignore",
|
22 |
+
enable_attention_hooks=False,
|
23 |
+
):
|
24 |
+
if backbone == "vitl16_384":
|
25 |
+
pretrained = _make_pretrained_vitl16_384(
|
26 |
+
use_pretrained,
|
27 |
+
hooks=hooks,
|
28 |
+
use_readout=use_readout,
|
29 |
+
enable_attention_hooks=enable_attention_hooks,
|
30 |
+
)
|
31 |
+
scratch = _make_scratch(
|
32 |
+
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
33 |
+
) # ViT-L/16 - 85.0% Top1 (backbone)
|
34 |
+
elif backbone == "vitb_rn50_384":
|
35 |
+
pretrained = _make_pretrained_vitb_rn50_384(
|
36 |
+
use_pretrained,
|
37 |
+
hooks=hooks,
|
38 |
+
use_vit_only=use_vit_only,
|
39 |
+
use_readout=use_readout,
|
40 |
+
enable_attention_hooks=enable_attention_hooks,
|
41 |
+
)
|
42 |
+
scratch = _make_scratch(
|
43 |
+
[256, 512, 768, 768], features, groups=groups, expand=expand
|
44 |
+
) # ViT-H/16 - 85.0% Top1 (backbone)
|
45 |
+
elif backbone == "vitb16_384":
|
46 |
+
pretrained = _make_pretrained_vitb16_384(
|
47 |
+
use_pretrained,
|
48 |
+
hooks=hooks,
|
49 |
+
use_readout=use_readout,
|
50 |
+
enable_attention_hooks=enable_attention_hooks,
|
51 |
+
)
|
52 |
+
scratch = _make_scratch(
|
53 |
+
[96, 192, 384, 768], features, groups=groups, expand=expand
|
54 |
+
) # ViT-B/16 - 84.6% Top1 (backbone)
|
55 |
+
elif backbone == "resnext101_wsl":
|
56 |
+
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
57 |
+
scratch = _make_scratch(
|
58 |
+
[256, 512, 1024, 2048], features, groups=groups, expand=expand
|
59 |
+
) # efficientnet_lite3
|
60 |
+
else:
|
61 |
+
print(f"Backbone '{backbone}' not implemented")
|
62 |
+
assert False
|
63 |
+
|
64 |
+
return pretrained, scratch
|
65 |
+
|
66 |
+
|
67 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
68 |
+
scratch = nn.Module()
|
69 |
+
|
70 |
+
out_shape1 = out_shape
|
71 |
+
out_shape2 = out_shape
|
72 |
+
out_shape3 = out_shape
|
73 |
+
out_shape4 = out_shape
|
74 |
+
if expand == True:
|
75 |
+
out_shape1 = out_shape
|
76 |
+
out_shape2 = out_shape * 2
|
77 |
+
out_shape3 = out_shape * 4
|
78 |
+
out_shape4 = out_shape * 8
|
79 |
+
|
80 |
+
scratch.layer1_rn = nn.Conv2d(
|
81 |
+
in_shape[0],
|
82 |
+
out_shape1,
|
83 |
+
kernel_size=3,
|
84 |
+
stride=1,
|
85 |
+
padding=1,
|
86 |
+
bias=False,
|
87 |
+
groups=groups,
|
88 |
+
)
|
89 |
+
scratch.layer2_rn = nn.Conv2d(
|
90 |
+
in_shape[1],
|
91 |
+
out_shape2,
|
92 |
+
kernel_size=3,
|
93 |
+
stride=1,
|
94 |
+
padding=1,
|
95 |
+
bias=False,
|
96 |
+
groups=groups,
|
97 |
+
)
|
98 |
+
scratch.layer3_rn = nn.Conv2d(
|
99 |
+
in_shape[2],
|
100 |
+
out_shape3,
|
101 |
+
kernel_size=3,
|
102 |
+
stride=1,
|
103 |
+
padding=1,
|
104 |
+
bias=False,
|
105 |
+
groups=groups,
|
106 |
+
)
|
107 |
+
scratch.layer4_rn = nn.Conv2d(
|
108 |
+
in_shape[3],
|
109 |
+
out_shape4,
|
110 |
+
kernel_size=3,
|
111 |
+
stride=1,
|
112 |
+
padding=1,
|
113 |
+
bias=False,
|
114 |
+
groups=groups,
|
115 |
+
)
|
116 |
+
|
117 |
+
return scratch
|
118 |
+
|
119 |
+
|
120 |
+
def _make_resnet_backbone(resnet):
|
121 |
+
pretrained = nn.Module()
|
122 |
+
pretrained.layer1 = nn.Sequential(
|
123 |
+
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
124 |
+
)
|
125 |
+
|
126 |
+
pretrained.layer2 = resnet.layer2
|
127 |
+
pretrained.layer3 = resnet.layer3
|
128 |
+
pretrained.layer4 = resnet.layer4
|
129 |
+
|
130 |
+
return pretrained
|
131 |
+
|
132 |
+
|
133 |
+
def _make_pretrained_resnext101_wsl(use_pretrained):
|
134 |
+
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
135 |
+
return _make_resnet_backbone(resnet)
|
136 |
+
|
137 |
+
|
138 |
+
class Interpolate(nn.Module):
|
139 |
+
"""Interpolation module."""
|
140 |
+
|
141 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
142 |
+
"""Init.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
scale_factor (float): scaling
|
146 |
+
mode (str): interpolation mode
|
147 |
+
"""
|
148 |
+
super(Interpolate, self).__init__()
|
149 |
+
|
150 |
+
self.interp = nn.functional.interpolate
|
151 |
+
self.scale_factor = scale_factor
|
152 |
+
self.mode = mode
|
153 |
+
self.align_corners = align_corners
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
"""Forward pass.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
x (tensor): input
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
tensor: interpolated data
|
163 |
+
"""
|
164 |
+
|
165 |
+
x = self.interp(
|
166 |
+
x,
|
167 |
+
scale_factor=self.scale_factor,
|
168 |
+
mode=self.mode,
|
169 |
+
align_corners=self.align_corners,
|
170 |
+
)
|
171 |
+
|
172 |
+
return x
|
173 |
+
|
174 |
+
|
175 |
+
class ResidualConvUnit(nn.Module):
|
176 |
+
"""Residual convolution module."""
|
177 |
+
|
178 |
+
def __init__(self, features):
|
179 |
+
"""Init.
|
180 |
+
|
181 |
+
Args:
|
182 |
+
features (int): number of features
|
183 |
+
"""
|
184 |
+
super().__init__()
|
185 |
+
|
186 |
+
self.conv1 = nn.Conv2d(
|
187 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
188 |
+
)
|
189 |
+
|
190 |
+
self.conv2 = nn.Conv2d(
|
191 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
192 |
+
)
|
193 |
+
|
194 |
+
self.relu = nn.ReLU(inplace=True)
|
195 |
+
|
196 |
+
def forward(self, x):
|
197 |
+
"""Forward pass.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
x (tensor): input
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
tensor: output
|
204 |
+
"""
|
205 |
+
out = self.relu(x)
|
206 |
+
out = self.conv1(out)
|
207 |
+
out = self.relu(out)
|
208 |
+
out = self.conv2(out)
|
209 |
+
|
210 |
+
return out + x
|
211 |
+
|
212 |
+
|
213 |
+
class FeatureFusionBlock(nn.Module):
|
214 |
+
"""Feature fusion block."""
|
215 |
+
|
216 |
+
def __init__(self, features):
|
217 |
+
"""Init.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
features (int): number of features
|
221 |
+
"""
|
222 |
+
super(FeatureFusionBlock, self).__init__()
|
223 |
+
|
224 |
+
self.resConfUnit1 = ResidualConvUnit(features)
|
225 |
+
self.resConfUnit2 = ResidualConvUnit(features)
|
226 |
+
|
227 |
+
def forward(self, *xs):
|
228 |
+
"""Forward pass.
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
tensor: output
|
232 |
+
"""
|
233 |
+
output = xs[0]
|
234 |
+
|
235 |
+
if len(xs) == 2:
|
236 |
+
output += self.resConfUnit1(xs[1])
|
237 |
+
|
238 |
+
output = self.resConfUnit2(output)
|
239 |
+
|
240 |
+
output = nn.functional.interpolate(
|
241 |
+
output, scale_factor=2, mode="bilinear", align_corners=True
|
242 |
+
)
|
243 |
+
|
244 |
+
return output
|
245 |
+
|
246 |
+
|
247 |
+
class ResidualConvUnit_custom(nn.Module):
|
248 |
+
"""Residual convolution module."""
|
249 |
+
|
250 |
+
def __init__(self, features, activation, bn):
|
251 |
+
"""Init.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
features (int): number of features
|
255 |
+
"""
|
256 |
+
super().__init__()
|
257 |
+
|
258 |
+
self.bn = bn
|
259 |
+
|
260 |
+
self.groups = 1
|
261 |
+
|
262 |
+
self.conv1 = nn.Conv2d(
|
263 |
+
features,
|
264 |
+
features,
|
265 |
+
kernel_size=3,
|
266 |
+
stride=1,
|
267 |
+
padding=1,
|
268 |
+
bias=not self.bn,
|
269 |
+
groups=self.groups,
|
270 |
+
)
|
271 |
+
|
272 |
+
self.conv2 = nn.Conv2d(
|
273 |
+
features,
|
274 |
+
features,
|
275 |
+
kernel_size=3,
|
276 |
+
stride=1,
|
277 |
+
padding=1,
|
278 |
+
bias=not self.bn,
|
279 |
+
groups=self.groups,
|
280 |
+
)
|
281 |
+
|
282 |
+
if self.bn == True:
|
283 |
+
self.bn1 = nn.BatchNorm2d(features)
|
284 |
+
self.bn2 = nn.BatchNorm2d(features)
|
285 |
+
|
286 |
+
self.activation = activation
|
287 |
+
|
288 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
289 |
+
|
290 |
+
def forward(self, x):
|
291 |
+
"""Forward pass.
|
292 |
+
|
293 |
+
Args:
|
294 |
+
x (tensor): input
|
295 |
+
|
296 |
+
Returns:
|
297 |
+
tensor: output
|
298 |
+
"""
|
299 |
+
|
300 |
+
out = self.activation(x)
|
301 |
+
out = self.conv1(out)
|
302 |
+
if self.bn == True:
|
303 |
+
out = self.bn1(out)
|
304 |
+
|
305 |
+
out = self.activation(out)
|
306 |
+
out = self.conv2(out)
|
307 |
+
if self.bn == True:
|
308 |
+
out = self.bn2(out)
|
309 |
+
|
310 |
+
if self.groups > 1:
|
311 |
+
out = self.conv_merge(out)
|
312 |
+
|
313 |
+
return self.skip_add.add(out, x)
|
314 |
+
|
315 |
+
# return out + x
|
316 |
+
|
317 |
+
|
318 |
+
class FeatureFusionBlock_custom(nn.Module):
|
319 |
+
"""Feature fusion block."""
|
320 |
+
|
321 |
+
def __init__(
|
322 |
+
self,
|
323 |
+
features,
|
324 |
+
activation,
|
325 |
+
deconv=False,
|
326 |
+
bn=False,
|
327 |
+
expand=False,
|
328 |
+
align_corners=True,
|
329 |
+
):
|
330 |
+
"""Init.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
features (int): number of features
|
334 |
+
"""
|
335 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
336 |
+
|
337 |
+
self.deconv = deconv
|
338 |
+
self.align_corners = align_corners
|
339 |
+
|
340 |
+
self.groups = 1
|
341 |
+
|
342 |
+
self.expand = expand
|
343 |
+
out_features = features
|
344 |
+
if self.expand == True:
|
345 |
+
out_features = features // 2
|
346 |
+
|
347 |
+
self.out_conv = nn.Conv2d(
|
348 |
+
features,
|
349 |
+
out_features,
|
350 |
+
kernel_size=1,
|
351 |
+
stride=1,
|
352 |
+
padding=0,
|
353 |
+
bias=True,
|
354 |
+
groups=1,
|
355 |
+
)
|
356 |
+
|
357 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
358 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
359 |
+
|
360 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
361 |
+
|
362 |
+
def forward(self, *xs):
|
363 |
+
"""Forward pass.
|
364 |
+
|
365 |
+
Returns:
|
366 |
+
tensor: output
|
367 |
+
"""
|
368 |
+
output = xs[0]
|
369 |
+
|
370 |
+
if len(xs) == 2:
|
371 |
+
res = self.resConfUnit1(xs[1])
|
372 |
+
output = self.skip_add.add(output, res)
|
373 |
+
# output += res
|
374 |
+
|
375 |
+
output = self.resConfUnit2(output)
|
376 |
+
|
377 |
+
output = nn.functional.interpolate(
|
378 |
+
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
379 |
+
)
|
380 |
+
|
381 |
+
output = self.out_conv(output)
|
382 |
+
|
383 |
+
return output
|
streamlit_apps/app_utils/dpt/midas_net.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from .base_model import BaseModel
|
10 |
+
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
11 |
+
|
12 |
+
|
13 |
+
class MidasNet_large(BaseModel):
|
14 |
+
"""Network for monocular depth estimation."""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=256, non_negative=True):
|
17 |
+
"""Init.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
path (str, optional): Path to saved model. Defaults to None.
|
21 |
+
features (int, optional): Number of features. Defaults to 256.
|
22 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
23 |
+
"""
|
24 |
+
print("Loading weights: ", path)
|
25 |
+
|
26 |
+
super(MidasNet_large, self).__init__()
|
27 |
+
|
28 |
+
use_pretrained = False if path is None else True
|
29 |
+
|
30 |
+
self.pretrained, self.scratch = _make_encoder(
|
31 |
+
backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained
|
32 |
+
)
|
33 |
+
|
34 |
+
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
35 |
+
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
36 |
+
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
37 |
+
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
38 |
+
|
39 |
+
self.scratch.output_conv = nn.Sequential(
|
40 |
+
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
41 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
42 |
+
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
43 |
+
nn.ReLU(True),
|
44 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
45 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
46 |
+
)
|
47 |
+
|
48 |
+
if path:
|
49 |
+
self.load(path)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
"""Forward pass.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
x (tensor): input data (image)
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
tensor: depth
|
59 |
+
"""
|
60 |
+
|
61 |
+
layer_1 = self.pretrained.layer1(x)
|
62 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
63 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
64 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
65 |
+
|
66 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
67 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
68 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
69 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
70 |
+
|
71 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
72 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
73 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
74 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
75 |
+
|
76 |
+
out = self.scratch.output_conv(path_1)
|
77 |
+
|
78 |
+
return torch.squeeze(out, dim=1)
|
streamlit_apps/app_utils/dpt/models.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch import Tensor
|
4 |
+
|
5 |
+
from .base_model import BaseModel
|
6 |
+
from .blocks import (
|
7 |
+
FeatureFusionBlock_custom,
|
8 |
+
Interpolate,
|
9 |
+
_make_encoder,
|
10 |
+
forward_vit,
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
def _make_fusion_block(features, use_bn):
|
15 |
+
return FeatureFusionBlock_custom(
|
16 |
+
features,
|
17 |
+
nn.ReLU(False),
|
18 |
+
deconv=False,
|
19 |
+
bn=use_bn,
|
20 |
+
expand=False,
|
21 |
+
align_corners=True,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
class DPT(BaseModel):
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
head,
|
29 |
+
features=256,
|
30 |
+
backbone="vitb_rn50_384",
|
31 |
+
readout="project",
|
32 |
+
channels_last=False,
|
33 |
+
use_bn=False,
|
34 |
+
enable_attention_hooks=False,
|
35 |
+
):
|
36 |
+
super(DPT, self).__init__()
|
37 |
+
|
38 |
+
self.channels_last = channels_last
|
39 |
+
|
40 |
+
hooks = {
|
41 |
+
"vitb_rn50_384": [0, 1, 8, 11],
|
42 |
+
"vitb16_384": [2, 5, 8, 11],
|
43 |
+
"vitl16_384": [5, 11, 17, 23],
|
44 |
+
}
|
45 |
+
|
46 |
+
# Instantiate backbone and reassemble blocks
|
47 |
+
self.pretrained, self.scratch = _make_encoder(
|
48 |
+
backbone,
|
49 |
+
features,
|
50 |
+
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
51 |
+
groups=1,
|
52 |
+
expand=False,
|
53 |
+
exportable=False,
|
54 |
+
hooks=hooks[backbone],
|
55 |
+
use_readout=readout,
|
56 |
+
enable_attention_hooks=enable_attention_hooks,
|
57 |
+
)
|
58 |
+
|
59 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
60 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
61 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
62 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
63 |
+
|
64 |
+
self.scratch.output_conv = head
|
65 |
+
|
66 |
+
def forward(self, x: Tensor) -> Tensor:
|
67 |
+
if self.channels_last == True:
|
68 |
+
x.contiguous(memory_format=torch.channels_last)
|
69 |
+
|
70 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
71 |
+
|
72 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
73 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
74 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
75 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
76 |
+
|
77 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
78 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
79 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
80 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
81 |
+
|
82 |
+
out = self.scratch.output_conv(path_1)
|
83 |
+
|
84 |
+
return out
|
85 |
+
|
86 |
+
|
87 |
+
class DPTDepthModel(DPT):
|
88 |
+
def __init__(
|
89 |
+
self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs
|
90 |
+
):
|
91 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
92 |
+
|
93 |
+
self.scale = scale
|
94 |
+
self.shift = shift
|
95 |
+
self.invert = invert
|
96 |
+
|
97 |
+
head = nn.Sequential(
|
98 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
99 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
100 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
101 |
+
nn.ReLU(True),
|
102 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
103 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
104 |
+
nn.Identity(),
|
105 |
+
)
|
106 |
+
|
107 |
+
super().__init__(head, **kwargs)
|
108 |
+
|
109 |
+
if path is not None:
|
110 |
+
self.load(path)
|
111 |
+
|
112 |
+
def forward(self, x: Tensor) -> Tensor:
|
113 |
+
"""Input x of shape [b, c, h, w]
|
114 |
+
Return tensor of shape [b, c, h, w]
|
115 |
+
"""
|
116 |
+
inv_depth = super().forward(x)
|
117 |
+
|
118 |
+
if self.invert:
|
119 |
+
depth = self.scale * inv_depth + self.shift
|
120 |
+
depth[depth < 1e-8] = 1e-8
|
121 |
+
depth = 1.0 / depth
|
122 |
+
return depth
|
123 |
+
else:
|
124 |
+
return inv_depth
|
streamlit_apps/app_utils/dpt/transforms.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
7 |
+
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
sample (dict): sample
|
11 |
+
size (tuple): image size
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
tuple: new size
|
15 |
+
"""
|
16 |
+
shape = list(sample["disparity"].shape)
|
17 |
+
|
18 |
+
if shape[0] >= size[0] and shape[1] >= size[1]:
|
19 |
+
return sample
|
20 |
+
|
21 |
+
scale = [0, 0]
|
22 |
+
scale[0] = size[0] / shape[0]
|
23 |
+
scale[1] = size[1] / shape[1]
|
24 |
+
|
25 |
+
scale = max(scale)
|
26 |
+
|
27 |
+
shape[0] = math.ceil(scale * shape[0])
|
28 |
+
shape[1] = math.ceil(scale * shape[1])
|
29 |
+
|
30 |
+
# resize
|
31 |
+
sample["image"] = cv2.resize(
|
32 |
+
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
33 |
+
)
|
34 |
+
|
35 |
+
sample["disparity"] = cv2.resize(
|
36 |
+
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
37 |
+
)
|
38 |
+
sample["mask"] = cv2.resize(
|
39 |
+
sample["mask"].astype(np.float32),
|
40 |
+
tuple(shape[::-1]),
|
41 |
+
interpolation=cv2.INTER_NEAREST,
|
42 |
+
)
|
43 |
+
sample["mask"] = sample["mask"].astype(bool)
|
44 |
+
|
45 |
+
return tuple(shape)
|
46 |
+
|
47 |
+
|
48 |
+
class Resize(object):
|
49 |
+
"""Resize sample to given size (width, height)."""
|
50 |
+
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
width,
|
54 |
+
height,
|
55 |
+
resize_target=True,
|
56 |
+
keep_aspect_ratio=False,
|
57 |
+
ensure_multiple_of=1,
|
58 |
+
resize_method="lower_bound",
|
59 |
+
image_interpolation_method=cv2.INTER_AREA,
|
60 |
+
):
|
61 |
+
"""Init.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
width (int): desired output width
|
65 |
+
height (int): desired output height
|
66 |
+
resize_target (bool, optional):
|
67 |
+
True: Resize the full sample (image, mask, target).
|
68 |
+
False: Resize image only.
|
69 |
+
Defaults to True.
|
70 |
+
keep_aspect_ratio (bool, optional):
|
71 |
+
True: Keep the aspect ratio of the input sample.
|
72 |
+
Output sample might not have the given width and height, and
|
73 |
+
resize behaviour depends on the parameter 'resize_method'.
|
74 |
+
Defaults to False.
|
75 |
+
ensure_multiple_of (int, optional):
|
76 |
+
Output width and height is constrained to be multiple of this parameter.
|
77 |
+
Defaults to 1.
|
78 |
+
resize_method (str, optional):
|
79 |
+
"lower_bound": Output will be at least as large as the given size.
|
80 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
81 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
82 |
+
Defaults to "lower_bound".
|
83 |
+
"""
|
84 |
+
self.__width = width
|
85 |
+
self.__height = height
|
86 |
+
|
87 |
+
self.__resize_target = resize_target
|
88 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
89 |
+
self.__multiple_of = ensure_multiple_of
|
90 |
+
self.__resize_method = resize_method
|
91 |
+
self.__image_interpolation_method = image_interpolation_method
|
92 |
+
|
93 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
94 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
95 |
+
|
96 |
+
if max_val is not None and y > max_val:
|
97 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
98 |
+
|
99 |
+
if y < min_val:
|
100 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
101 |
+
|
102 |
+
return y
|
103 |
+
|
104 |
+
def get_size(self, width, height):
|
105 |
+
# determine new height and width
|
106 |
+
scale_height = self.__height / height
|
107 |
+
scale_width = self.__width / width
|
108 |
+
|
109 |
+
if self.__keep_aspect_ratio:
|
110 |
+
if self.__resize_method == "lower_bound":
|
111 |
+
# scale such that output size is lower bound
|
112 |
+
if scale_width > scale_height:
|
113 |
+
# fit width
|
114 |
+
scale_height = scale_width
|
115 |
+
else:
|
116 |
+
# fit height
|
117 |
+
scale_width = scale_height
|
118 |
+
elif self.__resize_method == "upper_bound":
|
119 |
+
# scale such that output size is upper bound
|
120 |
+
if scale_width < scale_height:
|
121 |
+
# fit width
|
122 |
+
scale_height = scale_width
|
123 |
+
else:
|
124 |
+
# fit height
|
125 |
+
scale_width = scale_height
|
126 |
+
elif self.__resize_method == "minimal":
|
127 |
+
# scale as least as possbile
|
128 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
129 |
+
# fit width
|
130 |
+
scale_height = scale_width
|
131 |
+
else:
|
132 |
+
# fit height
|
133 |
+
scale_width = scale_height
|
134 |
+
else:
|
135 |
+
raise ValueError(
|
136 |
+
f"resize_method {self.__resize_method} not implemented"
|
137 |
+
)
|
138 |
+
|
139 |
+
if self.__resize_method == "lower_bound":
|
140 |
+
new_height = self.constrain_to_multiple_of(
|
141 |
+
scale_height * height, min_val=self.__height
|
142 |
+
)
|
143 |
+
new_width = self.constrain_to_multiple_of(
|
144 |
+
scale_width * width, min_val=self.__width
|
145 |
+
)
|
146 |
+
elif self.__resize_method == "upper_bound":
|
147 |
+
new_height = self.constrain_to_multiple_of(
|
148 |
+
scale_height * height, max_val=self.__height
|
149 |
+
)
|
150 |
+
new_width = self.constrain_to_multiple_of(
|
151 |
+
scale_width * width, max_val=self.__width
|
152 |
+
)
|
153 |
+
elif self.__resize_method == "minimal":
|
154 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
155 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
156 |
+
else:
|
157 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
158 |
+
|
159 |
+
return (new_width, new_height)
|
160 |
+
|
161 |
+
def __call__(self, sample):
|
162 |
+
width, height = self.get_size(
|
163 |
+
sample["image"].shape[1], sample["image"].shape[0]
|
164 |
+
)
|
165 |
+
|
166 |
+
# resize sample
|
167 |
+
sample["image"] = cv2.resize(
|
168 |
+
sample["image"],
|
169 |
+
(width, height),
|
170 |
+
interpolation=self.__image_interpolation_method,
|
171 |
+
)
|
172 |
+
|
173 |
+
if self.__resize_target:
|
174 |
+
if "disparity" in sample:
|
175 |
+
sample["disparity"] = cv2.resize(
|
176 |
+
sample["disparity"],
|
177 |
+
(width, height),
|
178 |
+
interpolation=cv2.INTER_NEAREST,
|
179 |
+
)
|
180 |
+
|
181 |
+
if "depth" in sample:
|
182 |
+
sample["depth"] = cv2.resize(
|
183 |
+
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
184 |
+
)
|
185 |
+
|
186 |
+
sample["mask"] = cv2.resize(
|
187 |
+
sample["mask"].astype(np.float32),
|
188 |
+
(width, height),
|
189 |
+
interpolation=cv2.INTER_NEAREST,
|
190 |
+
)
|
191 |
+
sample["mask"] = sample["mask"].astype(bool)
|
192 |
+
|
193 |
+
return sample
|
194 |
+
|
195 |
+
|
196 |
+
class NormalizeImage(object):
|
197 |
+
"""Normlize image by given mean and std."""
|
198 |
+
|
199 |
+
def __init__(self, mean, std):
|
200 |
+
self.__mean = mean
|
201 |
+
self.__std = std
|
202 |
+
|
203 |
+
def __call__(self, sample):
|
204 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
205 |
+
|
206 |
+
return sample
|
207 |
+
|
208 |
+
|
209 |
+
class PrepareForNet(object):
|
210 |
+
"""Prepare sample for usage as network input."""
|
211 |
+
|
212 |
+
def __init__(self):
|
213 |
+
pass
|
214 |
+
|
215 |
+
def __call__(self, sample):
|
216 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
217 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
218 |
+
|
219 |
+
if "mask" in sample:
|
220 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
221 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
222 |
+
|
223 |
+
if "disparity" in sample:
|
224 |
+
disparity = sample["disparity"].astype(np.float32)
|
225 |
+
sample["disparity"] = np.ascontiguousarray(disparity)
|
226 |
+
|
227 |
+
if "depth" in sample:
|
228 |
+
depth = sample["depth"].astype(np.float32)
|
229 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
230 |
+
|
231 |
+
return sample
|
streamlit_apps/app_utils/dpt/vit.py
ADDED
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import timm
|
4 |
+
import types
|
5 |
+
import math
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
activations = {}
|
10 |
+
|
11 |
+
|
12 |
+
def get_activation(name):
|
13 |
+
def hook(model, input, output):
|
14 |
+
activations[name] = output
|
15 |
+
|
16 |
+
return hook
|
17 |
+
|
18 |
+
|
19 |
+
attention = {}
|
20 |
+
|
21 |
+
|
22 |
+
def get_attention(name):
|
23 |
+
def hook(module, input, output):
|
24 |
+
x = input[0]
|
25 |
+
B, N, C = x.shape
|
26 |
+
qkv = (
|
27 |
+
module.qkv(x)
|
28 |
+
.reshape(B, N, 3, module.num_heads, C // module.num_heads)
|
29 |
+
.permute(2, 0, 3, 1, 4)
|
30 |
+
)
|
31 |
+
q, k, v = (
|
32 |
+
qkv[0],
|
33 |
+
qkv[1],
|
34 |
+
qkv[2],
|
35 |
+
) # make torchscript happy (cannot use tensor as tuple)
|
36 |
+
|
37 |
+
attn = (q @ k.transpose(-2, -1)) * module.scale
|
38 |
+
|
39 |
+
attn = attn.softmax(dim=-1) # [:,:,1,1:]
|
40 |
+
attention[name] = attn
|
41 |
+
|
42 |
+
return hook
|
43 |
+
|
44 |
+
|
45 |
+
def get_mean_attention_map(attn, token, shape):
|
46 |
+
attn = attn[:, :, token, 1:]
|
47 |
+
attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()
|
48 |
+
attn = torch.nn.functional.interpolate(
|
49 |
+
attn, size=shape[2:], mode="bicubic", align_corners=False
|
50 |
+
).squeeze(0)
|
51 |
+
|
52 |
+
all_attn = torch.mean(attn, 0)
|
53 |
+
|
54 |
+
return all_attn
|
55 |
+
|
56 |
+
|
57 |
+
class Slice(nn.Module):
|
58 |
+
def __init__(self, start_index=1):
|
59 |
+
super(Slice, self).__init__()
|
60 |
+
self.start_index = start_index
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
return x[:, self.start_index :]
|
64 |
+
|
65 |
+
|
66 |
+
class AddReadout(nn.Module):
|
67 |
+
def __init__(self, start_index=1):
|
68 |
+
super(AddReadout, self).__init__()
|
69 |
+
self.start_index = start_index
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
if self.start_index == 2:
|
73 |
+
readout = (x[:, 0] + x[:, 1]) / 2
|
74 |
+
else:
|
75 |
+
readout = x[:, 0]
|
76 |
+
return x[:, self.start_index :] + readout.unsqueeze(1)
|
77 |
+
|
78 |
+
|
79 |
+
class ProjectReadout(nn.Module):
|
80 |
+
def __init__(self, in_features, start_index=1):
|
81 |
+
super(ProjectReadout, self).__init__()
|
82 |
+
self.start_index = start_index
|
83 |
+
|
84 |
+
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
88 |
+
features = torch.cat((x[:, self.start_index :], readout), -1)
|
89 |
+
|
90 |
+
return self.project(features)
|
91 |
+
|
92 |
+
|
93 |
+
class Transpose(nn.Module):
|
94 |
+
def __init__(self, dim0, dim1):
|
95 |
+
super(Transpose, self).__init__()
|
96 |
+
self.dim0 = dim0
|
97 |
+
self.dim1 = dim1
|
98 |
+
|
99 |
+
def forward(self, x):
|
100 |
+
x = x.transpose(self.dim0, self.dim1)
|
101 |
+
return x
|
102 |
+
|
103 |
+
|
104 |
+
def forward_vit(pretrained, x):
|
105 |
+
b, c, h, w = x.shape
|
106 |
+
|
107 |
+
glob = pretrained.model.forward_flex(x)
|
108 |
+
|
109 |
+
layer_1 = pretrained.activations["1"]
|
110 |
+
layer_2 = pretrained.activations["2"]
|
111 |
+
layer_3 = pretrained.activations["3"]
|
112 |
+
layer_4 = pretrained.activations["4"]
|
113 |
+
|
114 |
+
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
115 |
+
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
116 |
+
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
117 |
+
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
118 |
+
|
119 |
+
unflatten = nn.Sequential(
|
120 |
+
nn.Unflatten(
|
121 |
+
2,
|
122 |
+
torch.Size(
|
123 |
+
[
|
124 |
+
h // pretrained.model.patch_size[1],
|
125 |
+
w // pretrained.model.patch_size[0],
|
126 |
+
]
|
127 |
+
),
|
128 |
+
)
|
129 |
+
)
|
130 |
+
|
131 |
+
if layer_1.ndim == 3:
|
132 |
+
layer_1 = unflatten(layer_1)
|
133 |
+
if layer_2.ndim == 3:
|
134 |
+
layer_2 = unflatten(layer_2)
|
135 |
+
if layer_3.ndim == 3:
|
136 |
+
layer_3 = unflatten(layer_3)
|
137 |
+
if layer_4.ndim == 3:
|
138 |
+
layer_4 = unflatten(layer_4)
|
139 |
+
|
140 |
+
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
141 |
+
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
142 |
+
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
143 |
+
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
144 |
+
|
145 |
+
return layer_1, layer_2, layer_3, layer_4
|
146 |
+
|
147 |
+
|
148 |
+
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
149 |
+
posemb_tok, posemb_grid = (
|
150 |
+
posemb[:, : self.start_index],
|
151 |
+
posemb[0, self.start_index :],
|
152 |
+
)
|
153 |
+
|
154 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
155 |
+
|
156 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
157 |
+
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
158 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
159 |
+
|
160 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
161 |
+
|
162 |
+
return posemb
|
163 |
+
|
164 |
+
|
165 |
+
def forward_flex(self, x):
|
166 |
+
b, c, h, w = x.shape
|
167 |
+
|
168 |
+
pos_embed = self._resize_pos_embed(
|
169 |
+
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
|
170 |
+
)
|
171 |
+
|
172 |
+
B = x.shape[0]
|
173 |
+
|
174 |
+
if hasattr(self.patch_embed, "backbone"):
|
175 |
+
x = self.patch_embed.backbone(x)
|
176 |
+
if isinstance(x, (list, tuple)):
|
177 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
178 |
+
|
179 |
+
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
180 |
+
|
181 |
+
if getattr(self, "dist_token", None) is not None:
|
182 |
+
cls_tokens = self.cls_token.expand(
|
183 |
+
B, -1, -1
|
184 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
185 |
+
dist_token = self.dist_token.expand(B, -1, -1)
|
186 |
+
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
187 |
+
else:
|
188 |
+
cls_tokens = self.cls_token.expand(
|
189 |
+
B, -1, -1
|
190 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
191 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
192 |
+
|
193 |
+
x = x + pos_embed
|
194 |
+
x = self.pos_drop(x)
|
195 |
+
|
196 |
+
for blk in self.blocks:
|
197 |
+
x = blk(x)
|
198 |
+
|
199 |
+
x = self.norm(x)
|
200 |
+
|
201 |
+
return x
|
202 |
+
|
203 |
+
|
204 |
+
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
205 |
+
if use_readout == "ignore":
|
206 |
+
readout_oper = [Slice(start_index)] * len(features)
|
207 |
+
elif use_readout == "add":
|
208 |
+
readout_oper = [AddReadout(start_index)] * len(features)
|
209 |
+
elif use_readout == "project":
|
210 |
+
readout_oper = [
|
211 |
+
ProjectReadout(vit_features, start_index) for out_feat in features
|
212 |
+
]
|
213 |
+
else:
|
214 |
+
assert (
|
215 |
+
False
|
216 |
+
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
217 |
+
|
218 |
+
return readout_oper
|
219 |
+
|
220 |
+
|
221 |
+
def _make_vit_b16_backbone(
|
222 |
+
model,
|
223 |
+
features=[96, 192, 384, 768],
|
224 |
+
size=[384, 384],
|
225 |
+
hooks=[2, 5, 8, 11],
|
226 |
+
vit_features=768,
|
227 |
+
use_readout="ignore",
|
228 |
+
start_index=1,
|
229 |
+
enable_attention_hooks=False,
|
230 |
+
):
|
231 |
+
pretrained = nn.Module()
|
232 |
+
|
233 |
+
pretrained.model = model
|
234 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
235 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
236 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
237 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
238 |
+
|
239 |
+
pretrained.activations = activations
|
240 |
+
|
241 |
+
if enable_attention_hooks:
|
242 |
+
pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
|
243 |
+
get_attention("attn_1")
|
244 |
+
)
|
245 |
+
pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
|
246 |
+
get_attention("attn_2")
|
247 |
+
)
|
248 |
+
pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
|
249 |
+
get_attention("attn_3")
|
250 |
+
)
|
251 |
+
pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
|
252 |
+
get_attention("attn_4")
|
253 |
+
)
|
254 |
+
pretrained.attention = attention
|
255 |
+
|
256 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
257 |
+
|
258 |
+
# 32, 48, 136, 384
|
259 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
260 |
+
readout_oper[0],
|
261 |
+
Transpose(1, 2),
|
262 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
263 |
+
nn.Conv2d(
|
264 |
+
in_channels=vit_features,
|
265 |
+
out_channels=features[0],
|
266 |
+
kernel_size=1,
|
267 |
+
stride=1,
|
268 |
+
padding=0,
|
269 |
+
),
|
270 |
+
nn.ConvTranspose2d(
|
271 |
+
in_channels=features[0],
|
272 |
+
out_channels=features[0],
|
273 |
+
kernel_size=4,
|
274 |
+
stride=4,
|
275 |
+
padding=0,
|
276 |
+
bias=True,
|
277 |
+
dilation=1,
|
278 |
+
groups=1,
|
279 |
+
),
|
280 |
+
)
|
281 |
+
|
282 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
283 |
+
readout_oper[1],
|
284 |
+
Transpose(1, 2),
|
285 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
286 |
+
nn.Conv2d(
|
287 |
+
in_channels=vit_features,
|
288 |
+
out_channels=features[1],
|
289 |
+
kernel_size=1,
|
290 |
+
stride=1,
|
291 |
+
padding=0,
|
292 |
+
),
|
293 |
+
nn.ConvTranspose2d(
|
294 |
+
in_channels=features[1],
|
295 |
+
out_channels=features[1],
|
296 |
+
kernel_size=2,
|
297 |
+
stride=2,
|
298 |
+
padding=0,
|
299 |
+
bias=True,
|
300 |
+
dilation=1,
|
301 |
+
groups=1,
|
302 |
+
),
|
303 |
+
)
|
304 |
+
|
305 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
306 |
+
readout_oper[2],
|
307 |
+
Transpose(1, 2),
|
308 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
309 |
+
nn.Conv2d(
|
310 |
+
in_channels=vit_features,
|
311 |
+
out_channels=features[2],
|
312 |
+
kernel_size=1,
|
313 |
+
stride=1,
|
314 |
+
padding=0,
|
315 |
+
),
|
316 |
+
)
|
317 |
+
|
318 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
319 |
+
readout_oper[3],
|
320 |
+
Transpose(1, 2),
|
321 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
322 |
+
nn.Conv2d(
|
323 |
+
in_channels=vit_features,
|
324 |
+
out_channels=features[3],
|
325 |
+
kernel_size=1,
|
326 |
+
stride=1,
|
327 |
+
padding=0,
|
328 |
+
),
|
329 |
+
nn.Conv2d(
|
330 |
+
in_channels=features[3],
|
331 |
+
out_channels=features[3],
|
332 |
+
kernel_size=3,
|
333 |
+
stride=2,
|
334 |
+
padding=1,
|
335 |
+
),
|
336 |
+
)
|
337 |
+
|
338 |
+
pretrained.model.start_index = start_index
|
339 |
+
pretrained.model.patch_size = [16, 16]
|
340 |
+
|
341 |
+
# We inject this function into the VisionTransformer instances so that
|
342 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
343 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
344 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
345 |
+
_resize_pos_embed, pretrained.model
|
346 |
+
)
|
347 |
+
|
348 |
+
return pretrained
|
349 |
+
|
350 |
+
|
351 |
+
def _make_vit_b_rn50_backbone(
|
352 |
+
model,
|
353 |
+
features=[256, 512, 768, 768],
|
354 |
+
size=[384, 384],
|
355 |
+
hooks=[0, 1, 8, 11],
|
356 |
+
vit_features=768,
|
357 |
+
use_vit_only=False,
|
358 |
+
use_readout="ignore",
|
359 |
+
start_index=1,
|
360 |
+
enable_attention_hooks=False,
|
361 |
+
):
|
362 |
+
pretrained = nn.Module()
|
363 |
+
|
364 |
+
pretrained.model = model
|
365 |
+
|
366 |
+
if use_vit_only == True:
|
367 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
368 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
369 |
+
else:
|
370 |
+
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
371 |
+
get_activation("1")
|
372 |
+
)
|
373 |
+
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
374 |
+
get_activation("2")
|
375 |
+
)
|
376 |
+
|
377 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
378 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
379 |
+
|
380 |
+
if enable_attention_hooks:
|
381 |
+
pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1"))
|
382 |
+
pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2"))
|
383 |
+
pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3"))
|
384 |
+
pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4"))
|
385 |
+
pretrained.attention = attention
|
386 |
+
|
387 |
+
pretrained.activations = activations
|
388 |
+
|
389 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
390 |
+
|
391 |
+
if use_vit_only == True:
|
392 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
393 |
+
readout_oper[0],
|
394 |
+
Transpose(1, 2),
|
395 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
396 |
+
nn.Conv2d(
|
397 |
+
in_channels=vit_features,
|
398 |
+
out_channels=features[0],
|
399 |
+
kernel_size=1,
|
400 |
+
stride=1,
|
401 |
+
padding=0,
|
402 |
+
),
|
403 |
+
nn.ConvTranspose2d(
|
404 |
+
in_channels=features[0],
|
405 |
+
out_channels=features[0],
|
406 |
+
kernel_size=4,
|
407 |
+
stride=4,
|
408 |
+
padding=0,
|
409 |
+
bias=True,
|
410 |
+
dilation=1,
|
411 |
+
groups=1,
|
412 |
+
),
|
413 |
+
)
|
414 |
+
|
415 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
416 |
+
readout_oper[1],
|
417 |
+
Transpose(1, 2),
|
418 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
419 |
+
nn.Conv2d(
|
420 |
+
in_channels=vit_features,
|
421 |
+
out_channels=features[1],
|
422 |
+
kernel_size=1,
|
423 |
+
stride=1,
|
424 |
+
padding=0,
|
425 |
+
),
|
426 |
+
nn.ConvTranspose2d(
|
427 |
+
in_channels=features[1],
|
428 |
+
out_channels=features[1],
|
429 |
+
kernel_size=2,
|
430 |
+
stride=2,
|
431 |
+
padding=0,
|
432 |
+
bias=True,
|
433 |
+
dilation=1,
|
434 |
+
groups=1,
|
435 |
+
),
|
436 |
+
)
|
437 |
+
else:
|
438 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
439 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
440 |
+
)
|
441 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
442 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
443 |
+
)
|
444 |
+
|
445 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
446 |
+
readout_oper[2],
|
447 |
+
Transpose(1, 2),
|
448 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
449 |
+
nn.Conv2d(
|
450 |
+
in_channels=vit_features,
|
451 |
+
out_channels=features[2],
|
452 |
+
kernel_size=1,
|
453 |
+
stride=1,
|
454 |
+
padding=0,
|
455 |
+
),
|
456 |
+
)
|
457 |
+
|
458 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
459 |
+
readout_oper[3],
|
460 |
+
Transpose(1, 2),
|
461 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
462 |
+
nn.Conv2d(
|
463 |
+
in_channels=vit_features,
|
464 |
+
out_channels=features[3],
|
465 |
+
kernel_size=1,
|
466 |
+
stride=1,
|
467 |
+
padding=0,
|
468 |
+
),
|
469 |
+
nn.Conv2d(
|
470 |
+
in_channels=features[3],
|
471 |
+
out_channels=features[3],
|
472 |
+
kernel_size=3,
|
473 |
+
stride=2,
|
474 |
+
padding=1,
|
475 |
+
),
|
476 |
+
)
|
477 |
+
|
478 |
+
pretrained.model.start_index = start_index
|
479 |
+
pretrained.model.patch_size = [16, 16]
|
480 |
+
|
481 |
+
# We inject this function into the VisionTransformer instances so that
|
482 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
483 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
484 |
+
|
485 |
+
# We inject this function into the VisionTransformer instances so that
|
486 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
487 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
488 |
+
_resize_pos_embed, pretrained.model
|
489 |
+
)
|
490 |
+
|
491 |
+
return pretrained
|
492 |
+
|
493 |
+
|
494 |
+
def _make_pretrained_vitb_rn50_384(
|
495 |
+
pretrained,
|
496 |
+
use_readout="ignore",
|
497 |
+
hooks=None,
|
498 |
+
use_vit_only=False,
|
499 |
+
enable_attention_hooks=False,
|
500 |
+
):
|
501 |
+
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
|
502 |
+
|
503 |
+
hooks = [0, 1, 8, 11] if hooks == None else hooks
|
504 |
+
return _make_vit_b_rn50_backbone(
|
505 |
+
model,
|
506 |
+
features=[256, 512, 768, 768],
|
507 |
+
size=[384, 384],
|
508 |
+
hooks=hooks,
|
509 |
+
use_vit_only=use_vit_only,
|
510 |
+
use_readout=use_readout,
|
511 |
+
enable_attention_hooks=enable_attention_hooks,
|
512 |
+
)
|
513 |
+
|
514 |
+
|
515 |
+
def _make_pretrained_vitl16_384(
|
516 |
+
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
|
517 |
+
):
|
518 |
+
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
519 |
+
|
520 |
+
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
521 |
+
return _make_vit_b16_backbone(
|
522 |
+
model,
|
523 |
+
features=[256, 512, 1024, 1024],
|
524 |
+
hooks=hooks,
|
525 |
+
vit_features=1024,
|
526 |
+
use_readout=use_readout,
|
527 |
+
enable_attention_hooks=enable_attention_hooks,
|
528 |
+
)
|
529 |
+
|
530 |
+
|
531 |
+
def _make_pretrained_vitb16_384(
|
532 |
+
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
|
533 |
+
):
|
534 |
+
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
|
535 |
+
|
536 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
537 |
+
return _make_vit_b16_backbone(
|
538 |
+
model,
|
539 |
+
features=[96, 192, 384, 768],
|
540 |
+
hooks=hooks,
|
541 |
+
use_readout=use_readout,
|
542 |
+
enable_attention_hooks=enable_attention_hooks,
|
543 |
+
)
|
544 |
+
|
545 |
+
|
546 |
+
def _make_pretrained_deitb16_384(
|
547 |
+
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
|
548 |
+
):
|
549 |
+
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
|
550 |
+
|
551 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
552 |
+
return _make_vit_b16_backbone(
|
553 |
+
model,
|
554 |
+
features=[96, 192, 384, 768],
|
555 |
+
hooks=hooks,
|
556 |
+
use_readout=use_readout,
|
557 |
+
enable_attention_hooks=enable_attention_hooks,
|
558 |
+
)
|
559 |
+
|
560 |
+
|
561 |
+
def _make_pretrained_deitb16_distil_384(
|
562 |
+
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
|
563 |
+
):
|
564 |
+
model = timm.create_model(
|
565 |
+
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
|
566 |
+
)
|
567 |
+
|
568 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
569 |
+
return _make_vit_b16_backbone(
|
570 |
+
model,
|
571 |
+
features=[96, 192, 384, 768],
|
572 |
+
hooks=hooks,
|
573 |
+
use_readout=use_readout,
|
574 |
+
start_index=2,
|
575 |
+
enable_attention_hooks=enable_attention_hooks,
|
576 |
+
)
|
streamlit_apps/app_utils/image_inference.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import numpy as np
|
3 |
+
import streamlit as st
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
from s_multimae.da.base_da import BaseDataAugmentation
|
7 |
+
from .base_model import BaseRGBDModel
|
8 |
+
from .depth_model import BaseDepthModel
|
9 |
+
from .model import base_inference
|
10 |
+
|
11 |
+
if "depth" not in st.session_state:
|
12 |
+
st.session_state.depth = None
|
13 |
+
|
14 |
+
|
15 |
+
def image_inference(
|
16 |
+
depth_model: BaseDepthModel,
|
17 |
+
sod_model: BaseRGBDModel,
|
18 |
+
da: BaseDataAugmentation,
|
19 |
+
color: np.ndarray,
|
20 |
+
) -> None:
|
21 |
+
col1, col2 = st.columns(2)
|
22 |
+
image: Image = None
|
23 |
+
# depth: Image = None
|
24 |
+
|
25 |
+
with col1:
|
26 |
+
img_file_buffer = st.file_uploader(
|
27 |
+
"Upload an RGB image", key="img_file_buffer", type=["png", "jpg", "jpeg"]
|
28 |
+
)
|
29 |
+
if img_file_buffer is not None:
|
30 |
+
image = Image.open(img_file_buffer).convert("RGB")
|
31 |
+
st.image(image, caption="RGB")
|
32 |
+
|
33 |
+
with col2:
|
34 |
+
depth_file_buffer = st.file_uploader(
|
35 |
+
"Upload a depth image (Optional)",
|
36 |
+
key="depth_file_buffer",
|
37 |
+
type=["png", "jpg", "jpeg"],
|
38 |
+
)
|
39 |
+
if depth_file_buffer is not None:
|
40 |
+
st.session_state.depth = Image.open(depth_file_buffer).convert("L")
|
41 |
+
if st.session_state.depth is not None:
|
42 |
+
st.image(st.session_state.depth, caption="Depth")
|
43 |
+
|
44 |
+
if sod_model.cfg.ground_truth_version == 6:
|
45 |
+
num_sets_of_salient_objects = st.number_input(
|
46 |
+
"Number of sets of salient objects", value=1, min_value=1, max_value=10
|
47 |
+
)
|
48 |
+
else:
|
49 |
+
num_sets_of_salient_objects = 1
|
50 |
+
|
51 |
+
is_predict = st.button(
|
52 |
+
"Predict Salient Objects",
|
53 |
+
key="predict_salient_objects",
|
54 |
+
disabled=img_file_buffer is None,
|
55 |
+
)
|
56 |
+
if is_predict:
|
57 |
+
with st.spinner("Processing..."):
|
58 |
+
start_time = time.time()
|
59 |
+
pred_depth, pred_sods, pred_sms = base_inference(
|
60 |
+
depth_model,
|
61 |
+
sod_model,
|
62 |
+
da,
|
63 |
+
image,
|
64 |
+
st.session_state.depth,
|
65 |
+
color,
|
66 |
+
num_sets_of_salient_objects,
|
67 |
+
)
|
68 |
+
if st.session_state.depth is None:
|
69 |
+
st.session_state.depth = Image.fromarray(pred_depth).convert("L")
|
70 |
+
col2.image(st.session_state.depth, "Pseudo-depth")
|
71 |
+
|
72 |
+
if num_sets_of_salient_objects == 1:
|
73 |
+
st.warning(
|
74 |
+
"HINT: To view a wider variety of sets of salient objects, try to increase the number of sets the model can produce."
|
75 |
+
)
|
76 |
+
elif num_sets_of_salient_objects > 1:
|
77 |
+
st.warning(
|
78 |
+
"NOTE: As single-GT accounts for 77.61% of training samples, the model may not consistently yield different sets. The best approach is to gradually increase the number of sets of salient objects until you achieve the desired result."
|
79 |
+
)
|
80 |
+
|
81 |
+
st.info(f"Inference time: {time.time() - start_time:.4f} seconds")
|
82 |
+
|
83 |
+
sod_cols = st.columns(len(pred_sods))
|
84 |
+
|
85 |
+
for i, (pred_sod, pred_sm) in enumerate(zip(pred_sods, pred_sms)):
|
86 |
+
with sod_cols[i]:
|
87 |
+
st.image(pred_sod, "Salient Objects (Otsu threshold)")
|
88 |
+
st.image(pred_sm, "Salient Map")
|
streamlit_apps/app_utils/model.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple, Union
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torchvision.transforms.functional as TF
|
6 |
+
from PIL import Image
|
7 |
+
from torch import Tensor, nn
|
8 |
+
import torch
|
9 |
+
from skimage.filters import threshold_otsu
|
10 |
+
|
11 |
+
from s_multimae.da.base_da import BaseDataAugmentation
|
12 |
+
from s_multimae.model_pl import ModelPL
|
13 |
+
from s_multimae.visualizer import apply_vis_to_image
|
14 |
+
|
15 |
+
from .base_model import BaseRGBDModel
|
16 |
+
from .app_utils import get_size, normalize
|
17 |
+
from .depth_model import BaseDepthModel
|
18 |
+
|
19 |
+
|
20 |
+
# Environment
|
21 |
+
torch.set_grad_enabled(False)
|
22 |
+
from .device import device
|
23 |
+
|
24 |
+
print(f"device: {device}")
|
25 |
+
|
26 |
+
|
27 |
+
def post_processing_depth(depth: np.ndarray) -> np.ndarray:
|
28 |
+
depth = (normalize(depth) * 255).astype(np.uint8)
|
29 |
+
return cv2.applyColorMap(depth, cv2.COLORMAP_OCEAN)
|
30 |
+
|
31 |
+
|
32 |
+
def base_inference(
|
33 |
+
depth_model: BaseDepthModel,
|
34 |
+
sod_model: BaseRGBDModel,
|
35 |
+
da: BaseDataAugmentation,
|
36 |
+
raw_image: Union[Image.Image, np.ndarray],
|
37 |
+
raw_depth: Optional[Union[Image.Image, np.ndarray]] = None,
|
38 |
+
color: np.ndarray = None,
|
39 |
+
num_sets_of_salient_objects: int = 1,
|
40 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
41 |
+
"""Inference a pair of rgb image and depth image
|
42 |
+
if depth image is not provided, the depth_model will predict a depth image based on image
|
43 |
+
"""
|
44 |
+
origin_size = get_size(raw_image)
|
45 |
+
|
46 |
+
# Predict depth
|
47 |
+
image = TF.to_tensor(raw_image)
|
48 |
+
origin_shape = image.shape
|
49 |
+
if raw_depth is None:
|
50 |
+
depth: Tensor = depth_model.forward(image)
|
51 |
+
else:
|
52 |
+
depth = TF.to_tensor(raw_depth)
|
53 |
+
|
54 |
+
# Preprocessing
|
55 |
+
image, depth = da.forward(
|
56 |
+
raw_image, depth.cpu().detach().squeeze(0).numpy(), is_transform=False
|
57 |
+
)
|
58 |
+
|
59 |
+
# Inference
|
60 |
+
sms = sod_model.inference(image, depth, origin_shape, num_sets_of_salient_objects)
|
61 |
+
|
62 |
+
# Postprocessing
|
63 |
+
sods = []
|
64 |
+
|
65 |
+
for sm in sms:
|
66 |
+
binary_mask = np.array(sm)
|
67 |
+
t = threshold_otsu(binary_mask)
|
68 |
+
binary_mask[binary_mask < t] = 0.0
|
69 |
+
binary_mask[binary_mask >= t] = 1.0
|
70 |
+
|
71 |
+
sod = apply_vis_to_image(np.array(raw_image), binary_mask, color)
|
72 |
+
sods.append(sod)
|
73 |
+
|
74 |
+
depth = depth.permute(1, 2, 0).detach().cpu().numpy()
|
75 |
+
depth = cv2.resize(depth, origin_size)
|
76 |
+
depth = post_processing_depth(depth)
|
77 |
+
|
78 |
+
return depth, sods, [e / 255.0 for e in sms]
|
79 |
+
|
80 |
+
|
81 |
+
def transform_images(inputs: List[Image.Image], transform: nn.Module) -> Tensor:
|
82 |
+
if len(inputs) == 1:
|
83 |
+
return transform(inputs[0]).unsqueeze(0)
|
84 |
+
return torch.cat([transform(input).unsqueeze(0) for input in inputs])
|
streamlit_apps/app_utils/smultimae_model.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from torch import Tensor
|
3 |
+
from torchvision.transforms import Resize
|
4 |
+
|
5 |
+
from s_multimae.model_pl import ModelPL
|
6 |
+
from s_multimae.configs.base_config import base_cfg
|
7 |
+
|
8 |
+
from .base_model import BaseRGBDModel
|
9 |
+
|
10 |
+
|
11 |
+
class RGBDSMultiMAEModel(BaseRGBDModel):
|
12 |
+
def __init__(self, cfg: base_cfg, model: ModelPL):
|
13 |
+
"""Wrapper of RGBDModel"""
|
14 |
+
super(RGBDSMultiMAEModel, self).__init__()
|
15 |
+
self.model: ModelPL = model
|
16 |
+
self.cfg = cfg
|
17 |
+
self.resize = Resize([self.cfg.image_size, self.cfg.image_size])
|
18 |
+
|
19 |
+
def inference(
|
20 |
+
self,
|
21 |
+
image: Tensor,
|
22 |
+
depth: Tensor,
|
23 |
+
origin_shape: np.array,
|
24 |
+
num_sets_of_salient_objects: int = 1,
|
25 |
+
) -> np.ndarray:
|
26 |
+
# 1. Preprocessing
|
27 |
+
images = image.unsqueeze(0)
|
28 |
+
depths = depth.unsqueeze(0)
|
29 |
+
|
30 |
+
# images = self.resize(images)
|
31 |
+
# depths = self.resize(depths)
|
32 |
+
|
33 |
+
# 2. Inference
|
34 |
+
images, depths = images.to(self.model.device), depths.to(self.model.device)
|
35 |
+
if self.cfg.ground_truth_version == 6:
|
36 |
+
self.cfg.num_classes = num_sets_of_salient_objects
|
37 |
+
res = self.model.inference(
|
38 |
+
[[origin_shape[2], origin_shape[1]]],
|
39 |
+
images,
|
40 |
+
depths,
|
41 |
+
[num_sets_of_salient_objects],
|
42 |
+
)
|
43 |
+
return res[0]
|
streamlit_apps/app_utils/sod_selection_ui.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
import streamlit as st
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from .app_env import SOD_MODEL_TYPE
|
7 |
+
from .app_utils import count_parameters
|
8 |
+
from .smultimae_model import RGBDSMultiMAEModel
|
9 |
+
from .base_model import BaseRGBDModel
|
10 |
+
from .device import device
|
11 |
+
|
12 |
+
from s_multimae.da.dav6 import DataAugmentationV6
|
13 |
+
from s_multimae.configs.base_config import base_cfg
|
14 |
+
from s_multimae.configs.experiment_config import arg_cfg
|
15 |
+
from s_multimae.model_pl import ModelPL
|
16 |
+
|
17 |
+
# from spnet_model import SPNetModel
|
18 |
+
|
19 |
+
|
20 |
+
@st.cache_resource
|
21 |
+
def load_smultimae_model(
|
22 |
+
sod_model_config_key: str, top: int
|
23 |
+
) -> Tuple[BaseRGBDModel, base_cfg]:
|
24 |
+
"""
|
25 |
+
1. Construct model
|
26 |
+
2. Load pretrained weights
|
27 |
+
3. Load model into device
|
28 |
+
"""
|
29 |
+
cfg = arg_cfg[sod_model_config_key]()
|
30 |
+
|
31 |
+
weights_fname = f"s-multimae-{cfg.experiment_name}-top{top}.pth"
|
32 |
+
ckpt_path = os.path.join(
|
33 |
+
"weights", weights_fname
|
34 |
+
)
|
35 |
+
print(ckpt_path)
|
36 |
+
if not os.path.isfile(ckpt_path):
|
37 |
+
from huggingface_hub import hf_hub_download
|
38 |
+
hf_hub_download(repo_id="RGBD-SOD/S-MultiMAE", filename=weights_fname)
|
39 |
+
os.system(f"mv {weights_fname} weights")
|
40 |
+
assert os.path.isfile(ckpt_path)
|
41 |
+
|
42 |
+
# sod_model = ModelPL.load_from_checkpoint(
|
43 |
+
# ckpt_path,
|
44 |
+
# cfg=cfg,
|
45 |
+
# map_location=device,
|
46 |
+
# )
|
47 |
+
sod_model = ModelPL(cfg)
|
48 |
+
sod_model.model.load_state_dict(
|
49 |
+
torch.load(ckpt_path, map_location="cpu"), strict=False
|
50 |
+
)
|
51 |
+
da = DataAugmentationV6(cfg)
|
52 |
+
return RGBDSMultiMAEModel(cfg, sod_model), cfg, da
|
53 |
+
|
54 |
+
|
55 |
+
# @st.cache_resource
|
56 |
+
# def load_spnet_model() -> BaseRGBDModel:
|
57 |
+
# """
|
58 |
+
# 1. Construct model
|
59 |
+
# 2. Load pretrained weights
|
60 |
+
# 3. Load model into device
|
61 |
+
# """
|
62 |
+
# sod_model = SPNetModel()
|
63 |
+
# return sod_model
|
64 |
+
|
65 |
+
|
66 |
+
# @st.cache_resource
|
67 |
+
# def load_bbsnet_model() -> BaseRGBDModel:
|
68 |
+
# """
|
69 |
+
# 1. Construct model
|
70 |
+
# 2. Load pretrained weights
|
71 |
+
# 3. Load model into device
|
72 |
+
# """
|
73 |
+
# sod_model = BBSNetModel()
|
74 |
+
# return sod_model
|
75 |
+
|
76 |
+
|
77 |
+
def sod_selection_ui() -> BaseRGBDModel:
|
78 |
+
sod_model_type = st.selectbox(
|
79 |
+
"Choose SOD model",
|
80 |
+
(
|
81 |
+
SOD_MODEL_TYPE.S_MULTIMAE,
|
82 |
+
# SOD_MODEL_TYPE.SPNET,
|
83 |
+
# SOD_MODEL_TYPE.BBSNET,
|
84 |
+
),
|
85 |
+
key="sod_model_type",
|
86 |
+
)
|
87 |
+
|
88 |
+
if sod_model_type == SOD_MODEL_TYPE.S_MULTIMAE:
|
89 |
+
d = {
|
90 |
+
"S-MultiMAE [ViT-L] Multi-GT": {"top": 1, "cfg": "cfgv4_0_2006"},
|
91 |
+
"S-MultiMAE [ViT-B] Multi-GT": {"top": 1, "cfg": "cfgv4_0_2007"},
|
92 |
+
}
|
93 |
+
|
94 |
+
sod_model_config_key = st.selectbox(
|
95 |
+
"Choose config",
|
96 |
+
list(d.keys()),
|
97 |
+
key="sod_model_config_key",
|
98 |
+
)
|
99 |
+
sod_model, cfg, da = load_smultimae_model(
|
100 |
+
d[sod_model_config_key]["cfg"], d[sod_model_config_key]["top"]
|
101 |
+
)
|
102 |
+
# st.text(f"Model description: {cfg.description}")
|
103 |
+
# elif sod_model_type == SOD_MODEL_TYPE.SPNET:
|
104 |
+
# sod_model = load_spnet_model()
|
105 |
+
# st.text(f"Model description: SPNet (https://github.com/taozh2017/SPNet)")
|
106 |
+
# elif sod_model_type == SOD_MODEL_TYPE.BBSNET:
|
107 |
+
# sod_model = load_bbsnet_model()
|
108 |
+
# st.text(f"Model description: BBSNet (https://github.com/DengPingFan/BBS-Net)")
|
109 |
+
st.text(f"Number of parameters {count_parameters(sod_model)}")
|
110 |
+
|
111 |
+
return sod_model, da
|