wondervictor committed
Commit f773839 · verified · 1 Parent(s): 66b8239

Upload 10 files

Files changed (9)
  1. GETTING_STARTED.md +101 -0
  2. INSTALL.md +41 -0
  3. LICENSE +209 -0
  4. README copy.md +68 -0
  5. app.py +188 -0
  6. requirements.txt +26 -0
  7. train_net_fcclip.py +349 -0
  8. train_net_maftp.py +298 -0
  9. train_net_maskadapter.py +360 -0
GETTING_STARTED.md ADDED
@@ -0,0 +1,101 @@
+ ## Getting Started with Mask-Adapter
+
+ This document provides a brief introduction to using Mask-Adapter.
+
+ Please see [Getting Started with Detectron2](https://github.com/facebookresearch/detectron2/blob/master/GETTING_STARTED.md) for full usage.
+
+
+ ### Inference Demo with Pre-trained Models
+
+ We provide `demo.py`, which can run a demo with the built-in configs. Run it with:
+ ```
+ cd demo/
+ python demo.py \
+ --input input1.jpg input2.jpg \
+ [--other-options]
+ --opts MODEL.WEIGHTS /path/to/checkpoint_file
+ ```
+ The configs are made for training, so for evaluation you need to point `MODEL.WEIGHTS` to a checkpoint from the model zoo.
+ This command runs inference and shows the visualizations in an OpenCV window.
+
+ For details of the command-line arguments, see `demo.py -h` or look at its source code
+ to understand its behavior. Some common arguments are:
+ * To run __on your webcam__, replace `--input files` with `--webcam`.
+ * To run __on a video__, replace `--input files` with `--video-input video.mp4`.
+ * To run __on cpu__, add `MODEL.DEVICE cpu` after `--opts`.
+ * To save outputs to a directory (for images) or a file (for webcam or video), use `--output`.
+
+
+ ### Ground-truth Warmup Training
+ We provide the script `train_net_maskadapter.py` to train the mask-adapter using ground-truth masks. To train a model with `train_net_maskadapter.py`, first set up the corresponding datasets as described in [datasets/README.md](datasets/README.md), and then run the following command:
+
+ ```
+ python train_net_maskadapter.py --num-gpus 4 \
+ --config-file configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml
+ ```
+
+ For the MAFTP model, run:
+
+
+ ```
+ python train_net_maskadapter.py --num-gpus 4 \
+ --config-file configs/ground-truth-warmup/mask-adapter/mask_adapter_maft_convnext_large_cocostuff_eval_ade20k.yaml \
+ MODEL.WEIGHTS /path/to/maftp_l.pth
+ ```
+
+ The configurations are set for 4-GPU training. Since we use the ADAMW optimizer, it is unclear how to scale the learning rate with batch size. If you train with a single GPU, you will need to adjust the learning rate and batch size manually:
+
+
+ ```
+ python train_net_maskadapter.py \
+ --config-file configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml \
+ --num-gpus 1 SOLVER.IMS_PER_BATCH SET_TO_SOME_REASONABLE_VALUE SOLVER.BASE_LR SET_TO_SOME_REASONABLE_VALUE
+ ```
+
+ ### Combining Mask-Adapter Weights with Mask2Former
+
+ Since the ground-truth warmup phase only trains the mask-adapter and not Mask2Former, the first-phase checkpoint does not contain Mask2Former weights. To combine the weights, run the following command (a sketch of what this fusion step does follows the command block):
+
+
+ ```
+ python tools/weight_fuse.py \
+ --model_first_phase_path /path/to/first_phase.pth \
+ --model_sem_seg_path /path/to/maftp_l.pth \
+ --output_path /path/to/maftp_l_withadapter.pth
+ ```
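`tools/weight_fuse.py` itself is not part of this upload, so the following is only a rough sketch of what such a fusion step typically does, under the assumption (borrowed from the checkpoint-loading code in `app.py`) that checkpoints store their parameters under a `"model"` key and that adapter parameters are prefixed with `mask_adapter`/`adapter`; the actual script may differ:

```python
# Hypothetical sketch of the weight-fusion step; argument names mirror the
# command above, but the real tools/weight_fuse.py may differ in details.
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--model_first_phase_path", required=True)
parser.add_argument("--model_sem_seg_path", required=True)
parser.add_argument("--output_path", required=True)
args = parser.parse_args()

# Assumption: both checkpoints store a "model" state dict.
first_phase = torch.load(args.model_first_phase_path, map_location="cpu")["model"]
sem_seg_ckpt = torch.load(args.model_sem_seg_path, map_location="cpu")

# Copy only the adapter parameters from the warmup checkpoint into the
# full segmentation checkpoint (prefixes follow the loading code in app.py).
for key, value in first_phase.items():
    if key.startswith(("mask_adapter", "adapter")):
        sem_seg_ckpt["model"][key] = value

torch.save(sem_seg_ckpt, args.output_path)
```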
+
+ ### Mixed-Masks Training
+ For the mixed-masks training phase, we provide two scripts, `train_net_fcclip.py` and `train_net_maftp.py`, which train the mask-adapter for the FC-CLIP and MAFTP models, respectively. The two models use different CLIP backbones and training data.
+ For FC-CLIP, run:
+
+
+ ```
+ python train_net_fcclip.py --num-gpus 4 \
+ --config-file configs/mixed-mask-training/fc-clip/fcclip/fcclip_convnext_large_eval_ade20k.yaml MODEL.WEIGHTS /path/to/checkpoint_file
+ ```
+
+ For MAFTP, run:
+
+
+ ```
+ python train_net_maftp.py --num-gpus 4 \
+ --config-file configs/mixed-mask-training/maftp/semantic/train_semantic_large_eval_a150.yaml MODEL.WEIGHTS /path/to/checkpoint_file
+ ```
+
+ To evaluate a model's performance, use the following for FC-CLIP:
+
+
+ ```
+ python train_net_fcclip.py \
+ --config-file configs/mixed-mask-training/fc-clip/fcclip/fcclip_convnext_large_eval_ade20k.yaml \
+ --eval-only MODEL.WEIGHTS /path/to/checkpoint_file
+ ```
+
+ For MAFTP, use:
+
+
+ ```
+ python train_net_maftp.py \
+ --config-file configs/mixed-mask-training/maftp/semantic/train_semantic_large_eval_a150.yaml \
+ --eval-only MODEL.WEIGHTS /path/to/checkpoint_file
+ ```
INSTALL.md ADDED
@@ -0,0 +1,41 @@
+ ## Installation
+ ### Requirements
+ 1. Clone this repository:
+ ```
+ git clone https://github.com/hustvl/MaskAdapter.git
+ ```
+ 2. Install the appropriate version of PyTorch for your CUDA version. Ensure that the PyTorch version is ≥ 1.9 and compatible with the version required by Detectron2. For CUDA 11.8, you can install the following:
+ ```
+ pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu118
+ ```
+ 3. Follow the [Detectron2 installation instructions](https://detectron2.readthedocs.io/tutorials/install.html) to install Detectron2.
+ ```
+ git clone https://github.com/facebookresearch/detectron2.git
+ python -m pip install -e detectron2
+ ```
+ 4. Install the other requirements.
+ ```
+ pip install -r requirements.txt
+ cd fcclip/modeling/pixel_decoder/ops
+ sh make.sh
+ ```
+
+ ### Example conda environment configuration
+
+
+ ```bash
+ conda create --name mask_adapter python=3.8
+ conda activate mask_adapter
+ pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu118
+
+ git clone https://github.com/facebookresearch/detectron2.git
+ python -m pip install -e detectron2
+
+ pip install git+https://github.com/cocodataset/panopticapi.git
+ git clone https://github.com/hustvl/MaskAdapter.git
+ cd MaskAdapter
+ pip install -r requirements.txt
+ cd fcclip/modeling/pixel_decoder/ops
+ sh make.sh
+ cd ../../../..
+ ```
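As a quick sanity check after installation (a suggestion, not part of the original instructions), the core imports and CUDA availability can be verified before compiling the deformable-attention ops:

```python
# Suggested post-install sanity check (not part of the original INSTALL.md).
import torch
import detectron2

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("detectron2:", detectron2.__version__)
```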
LICENSE ADDED
@@ -0,0 +1,209 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README copy.md ADDED
@@ -0,0 +1,68 @@
+ # MaskAdapter
+ <div align="center">
+ <img src="./assets/logo.jpeg" width="20%">
+ <h1> Mask-Adapter </h1>
+ <h3> Mask-Adapter: The Devil is in the Masks for Open-Vocabulary Segmentation </h3>
+
+ YongKang Li<sup>1,\*</sup>, [Tianheng Cheng](https://scholar.google.com/citations?user=PH8rJHYAAAAJ&hl=zh-CN)<sup>1,\*</sup>, [Wenyu Liu](http://eic.hust.edu.cn/professor/liuwenyu)<sup>1</sup>, [Xinggang Wang](https://xwcv.github.io/)<sup>1,📧</sup>
+
+ <sup>1</sup> Huazhong University of Science and Technology
+
+
+ (\* equal contribution, 📧 corresponding author)
+
+ [![arxiv paper](https://img.shields.io/badge/arXiv-Paper-red)]()
+ [![checkpoints](https://img.shields.io/badge/HuggingFace-🤗-orange)]()
+ [![🤗 HuggingFace Demo](https://img.shields.io/badge/Mask_Adapter-🤗_HF_Demo-orange)]()
+
+ </div>
+
+
+ <div align="center">
+ <img src="./assets/main_fig.png">
+ </div>
+
+ ## Highlights
+
+ * Mask-Adapter is a simple yet remarkably effective method that can be seamlessly integrated into open-vocabulary segmentation methods, e.g., [FC-CLIP](https://github.com/bytedance/fc-clip) and [MAFT-Plus](https://github.com/jiaosiyu1999/MAFT-Plus), to tackle their existing bottlenecks.
+
+ * Mask-Adapter effectively extends to SAM without training, achieving impressive results across multiple open-vocabulary segmentation benchmarks.
+
+ ## Updates
+ - [x] Release code
+ - [x] Release weights
+ - [x] Release demo with SAM-2 👉 [🤗 Mask-Adapter]()
+ - [ ] Release weights trained with additional data
+
+
+ ## Installation
+ Please follow [installation](INSTALL.md).
+
+ ## Getting Started
+
+ See [Preparing Datasets for Mask-Adapter](datasets/README.md). Follow [FC-CLIP](https://github.com/bytedance/fc-clip) and [MAFT-Plus](https://github.com/jiaosiyu1999/MAFT-Plus) to prepare the datasets.
+
+ See [Getting Started with Mask-Adapter](GETTING_STARTED.md).
+
+
+
+ ## <a name="Citing Mask-Adapter"></a>Citing Mask-Adapter
+
+ If you use Mask-Adapter in your research, please use the following BibTeX entry.
+
+ ```BibTeX
+
+ ```
+
+ ## Acknowledgement
+
+ [Mask2Former](https://github.com/facebookresearch/Mask2Former)
+
+ [ODISE](https://github.com/NVlabs/ODISE)
+
+ [FC-CLIP](https://github.com/bytedance/fc-clip)
+
+ [MAFTP](https://github.com/jiaosiyu1999/MAFT-Plus)
+
+ [SAM](https://github.com/facebookresearch/segment-anything)
app.py ADDED
@@ -0,0 +1,188 @@
1
+ import multiprocessing as mp
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torch
5
+ try:
6
+ import detectron2
7
+ except ImportError:
8
+ import os
9
+ os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
10
+
11
+ from detectron2.config import get_cfg
12
+ from detectron2.projects.deeplab import add_deeplab_config
13
+ from detectron2.data.detection_utils import read_image
14
+ from mask_adapter import add_maskformer2_config, add_fcclip_config, add_mask_adapter_config
15
+ from mask_adapter.sam_maskadapter import SAMVisualizationDemo, SAMPointVisualizationDemo
16
+ import gradio as gr
17
+ import gdown
18
+ import open_clip
19
+ from sam2.build_sam import build_sam2
20
+ from mask_adapter.modeling.meta_arch.mask_adapter_head import build_mask_adapter
21
+
22
+ # ckpt_url = 'https://drive.google.com/uc?id=1cn-ohxgXDrDfkzC1QdO-fi8IjbjXmgKy'
23
+ # output = './ovseg_swinbase_vitL14_ft_mpt.pth'
24
+ # gdown.download(ckpt_url, output, quiet=False)
25
+
26
+
27
+ def setup_cfg(config_file):
28
+ # load config from file and command-line arguments
29
+ cfg = get_cfg()
30
+ add_deeplab_config(cfg)
31
+ add_maskformer2_config(cfg)
32
+ add_fcclip_config(cfg)
33
+ add_mask_adapter_config(cfg)
34
+ cfg.merge_from_file(config_file)
35
+ cfg.freeze()
36
+ return cfg
37
+
38
+
39
+ def inference_automatic(input_img, class_names):
40
+ mp.set_start_method("spawn", force=True)
41
+ config_file = '/home/yongkangli/Mask-Adapter/configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
42
+ cfg = setup_cfg(config_file)
43
+
44
+ demo = SAMVisualizationDemo(cfg, 0.8, sam2_model, clip_model,mask_adapter)
45
+
46
+ class_names = class_names.split(',')
47
+ img = read_image(input_img, format="BGR")
48
+ _, visualized_output = demo.run_on_image(img, class_names)
49
+
50
+ return Image.fromarray(np.uint8(visualized_output.get_image())).convert('RGB')
51
+
52
+
53
+ def inference_point(input_img, evt: gr.SelectData,):
54
+ # In point mode, implement the logic to process points from the user click (x, y)
55
+ # You can adjust your segmentation logic based on clicked points.
56
+ x, y = evt.index[0], evt.index[1]
57
+ points = [[x, y]]  # assume a single clicked point is used as input
58
+ print(f"Selected point: {points}")
59
+ import time
60
+ start_time = time.time()
61
+ mp.set_start_method("spawn", force=True)
62
+ config_file = '/home/yongkangli/Mask-Adapter/configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
63
+ cfg = setup_cfg(config_file)
64
+
65
+ demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model,mask_adapter)
66
+ end_time = time.time()
67
+ print("init time",end_time - start_time)
68
+
69
+ start_time = time.time()
70
+ img = read_image(input_img, format="BGR")
71
+
72
+ # Assume 'points' is a list of (x, y) coordinates to specify where the user clicks
73
+ # Process the image and points to create a segmentation map accordingly
74
+ _, visualized_output = demo.run_on_image_with_points(img, points)
75
+ end_time = time.time()
76
+ print("inf time",end_time - start_time)
77
+ return visualized_output
78
+
79
+
80
+ sam2_model = None
81
+ clip_model = None
82
+ mask_adapter = None
83
+
84
+ # Model loading and initialization
85
+ def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
86
+ cfg = setup_cfg(cfg)
87
+ global sam2_model, clip_model, mask_adapter
88
+
89
+ # Initialize SAM2
90
+ if sam2_model is None:
91
+ sam2_model = build_sam2(model_cfg, sam_path, device="cuda", apply_postprocessing=False)
92
+ print("SAM2 model initialized.")
93
+
94
+ # Initialize the CLIP model
95
+ if clip_model is None:
96
+ clip_model, _, _ = open_clip.create_model_and_transforms("convnext_large_d_320", pretrained="laion2b_s29b_b131k_ft_soup")
97
+ print("CLIP model initialized.")
98
+
99
+ # Initialize the Mask-Adapter model
100
+ if mask_adapter is None:
101
+ mask_adapter = build_mask_adapter(cfg, "MASKAdapterHead").cuda()
102
+ # Load the adapter state dict, keeping only adapter weights and stripping module prefixes
103
+ adapter_state_dict = torch.load(adapter_pth)
104
+ adapter_state_dict = {k.replace('mask_adapter.', '').replace('adapter.', ''): v
105
+ for k, v in adapter_state_dict["model"].items()
106
+ if k.startswith('adapter') or k.startswith('mask_adapter')}
107
+ mask_adapter.load_state_dict(adapter_state_dict)
108
+ print("Mask Adapter model initialized.")
109
+
110
+ # Paths and config used to initialize the models
111
+ model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
112
+ sam_path = '/home/yongkangli/segment-anything-2/checkpoints/sam2.1_hiera_large.pt'
113
+ adapter_pth = './model_0279999_with_sem_new.pth'
114
+ cfg = '/home/yongkangli/Mask-Adapter/configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
115
+
116
+ # Initialize the models once at startup
117
+ initialize_models(sam_path, adapter_pth, model_cfg, cfg)
118
+
119
+ # Examples for testing
120
+ examples = [
121
+ ['./demo/images/000000001025.jpg', 'dog, beach, trees, sea, sky, snow, person, rocks, buildings, birds, beach umbrella, beach chair'],
122
+ ['./demo/images/ADE_val_00000979.jpg', 'sky,sea,mountain,pier,beach,island,,landscape,horizon'],
123
+ ['./demo/images/ADE_val_00001200.jpg', 'bridge, mountains, trees, water, sky, buildings, boats, animals, flowers, waterfalls, grasslands, rocks'],
124
+ ]
125
+
126
+ output_labels = ['segmentation map']
127
+
128
+ title = '<center><h2>Mask-Adapter + Segment Anything-2</h2></center>'
129
+
130
+ description = """
131
+ <b>Mask-Adapter: The Devil is in the Masks for Open-Vocabulary Segmentation</b><br>
132
+ Mask-Adapter effectively extends to SAM or SAM-2 without additional training, achieving impressive results across multiple open-vocabulary segmentation benchmarks.<br>
133
+ <div style="display: flex; gap: 20px;">
134
+ <a href="https://arxiv.org/abs/2406.20076">
135
+ <img src="https://img.shields.io/badge/arXiv-Paper-red" alt="arXiv Paper">
136
+ </a>
137
+ <a href="https://github.com/hustvl/MaskAdapter">
138
+ <img src="https://img.shields.io/badge/GitHub-Code-blue" alt="GitHub Code">
139
+ </a>
140
+ </div>
141
+ """
142
+
143
+ # Interface with mode selection using Tabs
144
+ with gr.Blocks() as demo:
145
+ gr.Markdown(title) # Title
146
+ gr.Markdown(description) # Description
147
+
148
+ with gr.Tabs():
149
+ with gr.TabItem("Automatic Mode"):
150
+ with gr.Row():
151
+ with gr.Column():
152
+ input_image = gr.Image(type='filepath', label="Input Image")
153
+ class_names = gr.Textbox(lines=1, placeholder=None, label='Class Names')
154
+ with gr.Column():
155
+ output_image = gr.Image(type="pil", label='Segmentation Map')
156
+
157
+ # Buttons below segmentation map (now placed under segmentation map)
158
+ run_button = gr.Button("Run Automatic Segmentation")
159
+ run_button.click(inference_automatic, inputs=[input_image, class_names], outputs=output_image)
160
+
161
+ clear_button = gr.Button("Clear")
162
+ clear_button.click(lambda: None, inputs=None, outputs=output_image)
163
+
164
+ with gr.Row():
165
+ gr.Examples(examples=examples, inputs=[input_image, class_names], outputs=output_image)
166
+
167
+ with gr.TabItem("Point Mode"):
168
+ with gr.Row():  # horizontal layout
169
+ with gr.Column():
170
+ input_image = gr.Image(type='filepath', label="Upload Image", interactive=True)  # uploaded image, clickable for point selection
171
+ points_input = gr.State(value=[])  # stores the clicked points
172
+
173
+ with gr.Column():  # second column: segmentation map output
174
+ output_image_point = gr.Image(type="pil", label='Segmentation Map')  # segmentation map output
175
+
176
+ # Trigger `inference_point` directly from the image `SelectData` (click) event
177
+ input_image.select(inference_point, inputs=[input_image], outputs=output_image_point)
178
+
179
+ # Button to clear the segmentation map
180
+ clear_button_point = gr.Button("Clear Segmentation Map")
181
+ clear_button_point.click(lambda: None, inputs=None, outputs=output_image_point)
182
+
183
+
184
+
185
+
186
+ # Example images below buttons
187
+
188
+ demo.launch()
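For completeness, `demo.launch()` also accepts the standard Gradio server options; the call below is only a suggested variant for local hosting, not part of the committed `app.py`:

```python
# Optional: standard Gradio launch options (not in the committed app.py).
# Binds to all interfaces on a fixed port; set share=True for a public link.
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
```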
requirements.txt ADDED
@@ -0,0 +1,26 @@
1
+ cython
2
+ scipy
3
+ shapely
4
+ timm
5
+ h5py
6
+ submitit
7
+ scikit-image
8
+ Pillow==8.4.0
9
+ opencv-python
10
+ pycocotools~=2.0.4
11
+ open_clip_torch==2.16.0
12
+
13
+ # Torch
14
+ --find-links https://download.pytorch.org/whl/cu118/torch_stable.html
15
+
16
+ torch==2.3.1+cu118
17
+ torchvision==0.18.1+cu118
18
+
19
+ # Detectron
20
+ --find-links https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
21
+ detectron2
22
+
23
+ # Segment-anything
24
+ git+https://github.com/facebookresearch/sam2.git
25
+
26
+ # open_clip
train_net_fcclip.py ADDED
@@ -0,0 +1,349 @@
1
+ """
2
+ This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
3
+ All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
4
+
5
+ Reference: https://github.com/facebookresearch/Mask2Former/blob/main/train_net.py
6
+
7
+ FCCLIP Training Script.
8
+
9
+ This script is a simplified version of the training script in detectron2/tools.
10
+ """
11
+ try:
12
+ # ignore ShapelyDeprecationWarning from fvcore
13
+ from shapely.errors import ShapelyDeprecationWarning
14
+ import warnings
15
+ warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning)
16
+ except:
17
+ pass
18
+
19
+ import copy
20
+ import itertools
21
+ import logging
22
+ import os
23
+
24
+ from collections import OrderedDict
25
+ from typing import Any, Dict, List, Set
26
+
27
+ import torch
28
+
29
+ import detectron2.utils.comm as comm
30
+ from detectron2.checkpoint import DetectionCheckpointer
31
+ from detectron2.config import get_cfg
32
+ from detectron2.data import MetadataCatalog, build_detection_train_loader
33
+ from detectron2.engine import (
34
+ DefaultTrainer,
35
+ default_argument_parser,
36
+ default_setup,
37
+ launch,
38
+ )
39
+ from detectron2.evaluation import (
40
+ CityscapesInstanceEvaluator,
41
+ CityscapesSemSegEvaluator,
42
+ COCOEvaluator,
43
+ COCOPanopticEvaluator,
44
+ DatasetEvaluators,
45
+ LVISEvaluator,
46
+ SemSegEvaluator,
47
+ verify_results,
48
+ )
49
+ from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
50
+ from detectron2.solver.build import maybe_add_gradient_clipping
51
+ from detectron2.utils.logger import setup_logger
52
+
53
+ from fcclip import (
54
+ COCOInstanceNewBaselineDatasetMapper,
55
+ COCOPanopticNewBaselineDatasetMapper,
56
+ InstanceSegEvaluator,
57
+ MaskFormerInstanceDatasetMapper,
58
+ MaskFormerPanopticDatasetMapper,
59
+ MaskFormerSemanticDatasetMapper,
60
+ SemanticSegmentorWithTTA,
61
+ add_maskformer2_config,
62
+ add_fcclip_config,
63
+ add_mask_adapter_config,
64
+ )
65
+
66
+
67
+ class Trainer(DefaultTrainer):
68
+ """
69
+ Extension of the Trainer class adapted to FCCLIP.
70
+ """
71
+
72
+ @classmethod
73
+ def build_evaluator(cls, cfg, dataset_name, output_folder=None):
74
+ """
75
+ Create evaluator(s) for a given dataset.
76
+ This uses the special metadata "evaluator_type" associated with each
77
+ builtin dataset. For your own dataset, you can simply create an
78
+ evaluator manually in your script and do not have to worry about the
79
+ hacky if-else logic here.
80
+ """
81
+ if output_folder is None:
82
+ output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
83
+ evaluator_list = []
84
+ evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
85
+ # semantic segmentation
86
+ if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]:
87
+ evaluator_list.append(
88
+ SemSegEvaluator(
89
+ dataset_name,
90
+ distributed=True,
91
+ output_dir=output_folder,
92
+ )
93
+ )
94
+ # instance segmentation
95
+ if evaluator_type == "coco":
96
+ evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
97
+ # panoptic segmentation
98
+ if evaluator_type in [
99
+ "coco_panoptic_seg",
100
+ "ade20k_panoptic_seg",
101
+ "cityscapes_panoptic_seg",
102
+ "mapillary_vistas_panoptic_seg",
103
+ ]:
104
+ if cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON:
105
+ evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
106
+ # COCO
107
+ if evaluator_type == "coco_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
108
+ evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
109
+ if evaluator_type == "coco_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON:
110
+ evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
111
+ # Mapillary Vistas
112
+ if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
113
+ evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
114
+ if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON:
115
+ evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
116
+ # Cityscapes
117
+ if evaluator_type == "cityscapes_instance":
118
+ assert (
119
+ torch.cuda.device_count() > comm.get_rank()
120
+ ), "CityscapesEvaluator currently do not work with multiple machines."
121
+ return CityscapesInstanceEvaluator(dataset_name)
122
+ if evaluator_type == "cityscapes_sem_seg":
123
+ assert (
124
+ torch.cuda.device_count() > comm.get_rank()
125
+ ), "CityscapesEvaluator currently do not work with multiple machines."
126
+ return CityscapesSemSegEvaluator(dataset_name)
127
+ if evaluator_type == "cityscapes_panoptic_seg":
128
+ if cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON:
129
+ assert (
130
+ torch.cuda.device_count() > comm.get_rank()
131
+ ), "CityscapesEvaluator currently do not work with multiple machines."
132
+ evaluator_list.append(CityscapesSemSegEvaluator(dataset_name))
133
+ if cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
134
+ assert (
135
+ torch.cuda.device_count() > comm.get_rank()
136
+ ), "CityscapesEvaluator currently do not work with multiple machines."
137
+ evaluator_list.append(CityscapesInstanceEvaluator(dataset_name))
138
+ # ADE20K
139
+ if evaluator_type == "ade20k_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
140
+ evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
141
+ # LVIS
142
+ if evaluator_type == "lvis":
143
+ return LVISEvaluator(dataset_name, output_dir=output_folder)
144
+ if len(evaluator_list) == 0:
145
+ raise NotImplementedError(
146
+ "no Evaluator for the dataset {} with the type {}".format(
147
+ dataset_name, evaluator_type
148
+ )
149
+ )
150
+ elif len(evaluator_list) == 1:
151
+ return evaluator_list[0]
152
+ return DatasetEvaluators(evaluator_list)
153
+
154
+ @classmethod
155
+ def build_train_loader(cls, cfg):
156
+ # Semantic segmentation dataset mapper
157
+ if cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_semantic":
158
+ mapper = MaskFormerSemanticDatasetMapper(cfg, True)
159
+ return build_detection_train_loader(cfg, mapper=mapper)
160
+ # Panoptic segmentation dataset mapper
161
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_panoptic":
162
+ mapper = MaskFormerPanopticDatasetMapper(cfg, True)
163
+ return build_detection_train_loader(cfg, mapper=mapper)
164
+ # Instance segmentation dataset mapper
165
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_instance":
166
+ mapper = MaskFormerInstanceDatasetMapper(cfg, True)
167
+ return build_detection_train_loader(cfg, mapper=mapper)
168
+ # coco instance segmentation lsj new baseline
169
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_instance_lsj":
170
+ mapper = COCOInstanceNewBaselineDatasetMapper(cfg, True)
171
+ return build_detection_train_loader(cfg, mapper=mapper)
172
+ # coco panoptic segmentation lsj new baseline
173
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_panoptic_lsj":
174
+ mapper = COCOPanopticNewBaselineDatasetMapper(cfg, True)
175
+ return build_detection_train_loader(cfg, mapper=mapper)
176
+ else:
177
+ mapper = None
178
+ return build_detection_train_loader(cfg, mapper=mapper)
179
+
180
+ @classmethod
181
+ def build_lr_scheduler(cls, cfg, optimizer):
182
+ """
183
+ It now calls :func:`detectron2.solver.build_lr_scheduler`.
184
+ Overwrite it if you'd like a different scheduler.
185
+ """
186
+ return build_lr_scheduler(cfg, optimizer)
187
+
188
+ @classmethod
189
+ def build_optimizer(cls, cfg, model):
190
+ weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM
191
+ weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED
192
+
193
+ defaults = {}
194
+ defaults["lr"] = cfg.SOLVER.BASE_LR
195
+ defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY
196
+
197
+ norm_module_types = (
198
+ torch.nn.BatchNorm1d,
199
+ torch.nn.BatchNorm2d,
200
+ torch.nn.BatchNorm3d,
201
+ torch.nn.SyncBatchNorm,
202
+ # NaiveSyncBatchNorm inherits from BatchNorm2d
203
+ torch.nn.GroupNorm,
204
+ torch.nn.InstanceNorm1d,
205
+ torch.nn.InstanceNorm2d,
206
+ torch.nn.InstanceNorm3d,
207
+ torch.nn.LayerNorm,
208
+ torch.nn.LocalResponseNorm,
209
+ )
210
+
211
+ params: List[Dict[str, Any]] = []
212
+ memo: Set[torch.nn.parameter.Parameter] = set()
213
+ for module_name, module in model.named_modules():
214
+ for module_param_name, value in module.named_parameters(recurse=False):
215
+ if not value.requires_grad:
216
+ continue
217
+ # Avoid duplicating parameters
218
+ if value in memo:
219
+ continue
220
+ memo.add(value)
221
+
222
+ hyperparams = copy.copy(defaults)
223
+ if "backbone" in module_name:
224
+ hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER
225
+ if (
226
+ "relative_position_bias_table" in module_param_name
227
+ or "absolute_pos_embed" in module_param_name
228
+ ):
229
+ print(module_param_name)
230
+ hyperparams["weight_decay"] = 0.0
231
+ if isinstance(module, norm_module_types):
232
+ hyperparams["weight_decay"] = weight_decay_norm
233
+ if isinstance(module, torch.nn.Embedding):
234
+ hyperparams["weight_decay"] = weight_decay_embed
235
+ params.append({"params": [value], **hyperparams})
236
+
237
+ def maybe_add_full_model_gradient_clipping(optim):
238
+ # detectron2 doesn't have full model gradient clipping now
239
+ clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
240
+ enable = (
241
+ cfg.SOLVER.CLIP_GRADIENTS.ENABLED
242
+ and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
243
+ and clip_norm_val > 0.0
244
+ )
245
+
246
+ class FullModelGradientClippingOptimizer(optim):
247
+ def step(self, closure=None):
248
+ all_params = itertools.chain(*[x["params"] for x in self.param_groups])
249
+ torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
250
+ super().step(closure=closure)
251
+
252
+ return FullModelGradientClippingOptimizer if enable else optim
253
+
254
+ optimizer_type = cfg.SOLVER.OPTIMIZER
255
+ if optimizer_type == "SGD":
256
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
257
+ params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
258
+ )
259
+ elif optimizer_type == "ADAMW":
260
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
261
+ params, cfg.SOLVER.BASE_LR
262
+ )
263
+ else:
264
+ raise NotImplementedError(f"no optimizer type {optimizer_type}")
265
+ if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
266
+ optimizer = maybe_add_gradient_clipping(cfg, optimizer)
267
+ return optimizer
268
+
269
+ @classmethod
270
+ def test_with_TTA(cls, cfg, model):
271
+ logger = logging.getLogger("detectron2.trainer")
272
+ # In the end of training, run an evaluation with TTA.
273
+ logger.info("Running inference with test-time augmentation ...")
274
+ model = SemanticSegmentorWithTTA(cfg, model)
275
+ evaluators = [
276
+ cls.build_evaluator(
277
+ cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
278
+ )
279
+ for name in cfg.DATASETS.TEST
280
+ ]
281
+ res = cls.test(cfg, model, evaluators)
282
+ res = OrderedDict({k + "_TTA": v for k, v in res.items()})
283
+ return res
284
+
285
+
286
+ def setup(args):
287
+ """
288
+ Create configs and perform basic setups.
289
+ """
290
+ cfg = get_cfg()
291
+ # for poly lr schedule
292
+ add_deeplab_config(cfg)
293
+ add_maskformer2_config(cfg)
294
+ add_fcclip_config(cfg)
295
+ add_mask_adapter_config(cfg)
296
+ cfg.merge_from_file(args.config_file)
297
+ cfg.merge_from_list(args.opts)
298
+ cfg.freeze()
299
+ default_setup(cfg, args)
300
+ # Setup logger for "fcclip" module
301
+ setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="fcclip")
302
+ return cfg
303
+
304
+
305
+ def main(args):
306
+ cfg = setup(args)
307
+
308
+ if args.eval_only:
309
+ model = Trainer.build_model(cfg)
310
+
311
+ total_params = sum(p.numel() for p in model.parameters())
312
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
313
+ frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
314
+ frozen_params_exclude_text = 0
315
+ for n, p in model.named_parameters():
316
+ if p.requires_grad:
317
+ continue
318
+ # ignore text tower
319
+ if 'clip_model.token_embedding' in n or 'clip_model.positional_embedding' in n or 'clip_model.transformer' in n or 'clip_model.ln_final' in n or 'clip_model.text_projection' in n:
320
+ continue
321
+ frozen_params_exclude_text += p.numel()
322
+ print(f"total_params: {total_params}, trainable_params: {trainable_params}, frozen_params: {frozen_params}, frozen_params_exclude_text: {frozen_params_exclude_text}")
323
+
324
+ DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
325
+ cfg.MODEL.WEIGHTS, resume=args.resume
326
+ )
327
+ res = Trainer.test(cfg, model)
328
+ if cfg.TEST.AUG.ENABLED:
329
+ res.update(Trainer.test_with_TTA(cfg, model))
330
+ if comm.is_main_process():
331
+ verify_results(cfg, res)
332
+ return res
333
+
334
+ trainer = Trainer(cfg)
335
+ trainer.resume_or_load(resume=args.resume)
336
+ return trainer.train()
337
+
338
+
339
+ if __name__ == "__main__":
340
+ args = default_argument_parser().parse_args()
341
+ print("Command Line Args:", args)
342
+ launch(
343
+ main,
344
+ args.num_gpus,
345
+ num_machines=args.num_machines,
346
+ machine_rank=args.machine_rank,
347
+ dist_url=args.dist_url,
348
+ args=(args,),
349
+ )
train_net_maftp.py ADDED
@@ -0,0 +1,298 @@
1
+ """
2
+ This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
3
+ All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
4
+
5
+ Reference: https://github.com/facebookresearch/Mask2Former/blob/main/train_net.py
6
+
7
+ MAFT-Plus Training Script.
8
+
9
+ This script is a simplified version of the training script in detectron2/tools.
10
+ """
11
+ try:
12
+ # ignore ShapelyDeprecationWarning from fvcore
13
+ from shapely.errors import ShapelyDeprecationWarning
14
+ import warnings
15
+ warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning)
16
+ except:
17
+ pass
18
+
19
+ import copy
20
+ import itertools
21
+ import logging
22
+ import os
23
+ # os.environ['CUDA_VISIBLE_DEVICES'] = '2,4,6'
24
+
25
+ from collections import OrderedDict
26
+ from typing import Any, Dict, List, Set
27
+
28
+ import torch
29
+
30
+ import detectron2.utils.comm as comm
31
+ from detectron2.checkpoint import DetectionCheckpointer
32
+ from detectron2.config import get_cfg
33
+ from detectron2.data import MetadataCatalog, build_detection_train_loader
34
+ from detectron2.engine import (
35
+ DefaultTrainer,
36
+ default_argument_parser,
37
+ default_setup,
38
+ launch,
39
+ )
40
+ from detectron2.evaluation import (
41
+ CityscapesInstanceEvaluator,
42
+ CityscapesSemSegEvaluator,
43
+ COCOEvaluator,
44
+ COCOPanopticEvaluator,
45
+ DatasetEvaluators,
46
+ LVISEvaluator,
47
+ SemSegEvaluator,
48
+ verify_results,
49
+ )
50
+ from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
51
+ from detectron2.solver.build import maybe_add_gradient_clipping
52
+ from detectron2.utils.logger import setup_logger
53
+
54
+ from maft import (
55
+ COCOInstanceNewBaselineDatasetMapper,
56
+ COCOPanopticNewBaselineDatasetMapper,
57
+ COCOSemanticNewBaselineDatasetMapper,
58
+ InstanceSegEvaluator,
59
+ #SemSegEvaluator,
60
+ MaskFormerInstanceDatasetMapper,
61
+ MaskFormerPanopticDatasetMapper,
62
+ MaskFormerSemanticDatasetMapper,
63
+ SemanticSegmentorWithTTA,
64
+ add_maskformer2_config,
65
+ add_fcclip_config,
66
+ add_mask_adapter_config,
67
+ )
68
+
69
+
70
+ class Trainer(DefaultTrainer):
71
+ """
72
+ Extension of the Trainer class adapted to FCCLIP.
73
+ """
74
+
75
+ @classmethod
76
+ def build_evaluator(cls, cfg, dataset_name, output_folder=None):
77
+ """
78
+ Create evaluator(s) for a given dataset.
79
+ This uses the special metadata "evaluator_type" associated with each
80
+ builtin dataset. For your own dataset, you can simply create an
81
+ evaluator manually in your script and do not have to worry about the
82
+ hacky if-else logic here.
83
+ """
84
+ if output_folder is None:
85
+ output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
86
+ evaluator_list = []
87
+ evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
88
+ # semantic segmentation
89
+ if cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_semantic":
90
+ evaluator_list.append(
91
+ SemSegEvaluator(
92
+ dataset_name,
93
+ distributed=True,
94
+ output_dir=output_folder,
95
+ )
96
+ )
97
+ # panoptic segmentation
98
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_panoptic_lsj":
99
+ evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_dir=output_folder))
100
+ if len(evaluator_list) == 0:
101
+ raise NotImplementedError(
102
+ "no Evaluator for the dataset {} with the type {}".format(
103
+ dataset_name, evaluator_type
104
+ )
105
+ )
106
+ elif len(evaluator_list) == 1:
107
+ return evaluator_list[0]
108
+ return DatasetEvaluators(evaluator_list)
109
+
110
+ @classmethod
111
+ def build_train_loader(cls, cfg):
112
+ # Semantic segmentation dataset mapper
113
+ if cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_semantic":
114
+ mapper = MaskFormerSemanticDatasetMapper(cfg, True)
115
+ return build_detection_train_loader(cfg, mapper=mapper)
116
+ # Panoptic segmentation dataset mapper
117
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_panoptic":
118
+ mapper = MaskFormerPanopticDatasetMapper(cfg, True)
119
+ return build_detection_train_loader(cfg, mapper=mapper)
120
+ # Instance segmentation dataset mapper
121
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_instance":
122
+ mapper = MaskFormerInstanceDatasetMapper(cfg, True)
123
+ return build_detection_train_loader(cfg, mapper=mapper)
124
+ # coco instance segmentation lsj new baseline
125
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_instance_lsj":
126
+ mapper = COCOInstanceNewBaselineDatasetMapper(cfg, True)
127
+ return build_detection_train_loader(cfg, mapper=mapper)
128
+ # coco panoptic segmentation lsj new baseline
129
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_panoptic_lsj":
130
+ mapper = COCOPanopticNewBaselineDatasetMapper(cfg, True)
131
+ return build_detection_train_loader(cfg, mapper=mapper)
132
+ # coco semantic segmentation lsj new baseline
133
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_semantic_lsj":
134
+ mapper = COCOSemanticNewBaselineDatasetMapper(cfg, True)
135
+ return build_detection_train_loader(cfg, mapper=mapper)
136
+
137
+ else:
138
+ mapper = None
139
+ return build_detection_train_loader(cfg, mapper=mapper)
140
+
141
+ @classmethod
142
+ def build_lr_scheduler(cls, cfg, optimizer):
143
+ """
144
+ It now calls :func:`detectron2.solver.build_lr_scheduler`.
145
+ Overwrite it if you'd like a different scheduler.
146
+ """
147
+ return build_lr_scheduler(cfg, optimizer)
148
+
149
+ @classmethod
150
+ def build_optimizer(cls, cfg, model):
151
+ weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM
152
+ weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED
153
+
154
+ defaults = {}
155
+ defaults["lr"] = cfg.SOLVER.BASE_LR
156
+ defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY
157
+
158
+ norm_module_types = (
159
+ torch.nn.BatchNorm1d,
160
+ torch.nn.BatchNorm2d,
161
+ torch.nn.BatchNorm3d,
162
+ torch.nn.SyncBatchNorm,
163
+ # NaiveSyncBatchNorm inherits from BatchNorm2d
164
+ torch.nn.GroupNorm,
165
+ torch.nn.InstanceNorm1d,
166
+ torch.nn.InstanceNorm2d,
167
+ torch.nn.InstanceNorm3d,
168
+ torch.nn.LayerNorm,
169
+ torch.nn.LocalResponseNorm,
170
+ )
171
+
172
+ params: List[Dict[str, Any]] = []
173
+ memo: Set[torch.nn.parameter.Parameter] = set()
174
+ for module_name, module in model.named_modules():
175
+ for module_param_name, value in module.named_parameters(recurse=False):
176
+ if not value.requires_grad:
177
+ continue
178
+ # Avoid duplicating parameters
179
+ if value in memo:
180
+ continue
181
+ memo.add(value)
182
+
183
+ hyperparams = copy.copy(defaults)
184
+ if "backbone" in module_name:
185
+ hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER
186
+ if (
187
+ "relative_position_bias_table" in module_param_name
188
+ or "absolute_pos_embed" in module_param_name
189
+ ):
190
+ print(module_param_name)
191
+ hyperparams["weight_decay"] = 0.0
192
+ if isinstance(module, norm_module_types):
193
+ hyperparams["weight_decay"] = weight_decay_norm
194
+ if isinstance(module, torch.nn.Embedding):
195
+ hyperparams["weight_decay"] = weight_decay_embed
196
+ params.append({"params": [value], **hyperparams})
197
+
198
+ def maybe_add_full_model_gradient_clipping(optim):
199
+ # detectron2 doesn't have full model gradient clipping now
200
+ clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
201
+ enable = (
202
+ cfg.SOLVER.CLIP_GRADIENTS.ENABLED
203
+ and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
204
+ and clip_norm_val > 0.0
205
+ )
206
+
207
+ class FullModelGradientClippingOptimizer(optim):
208
+ def step(self, closure=None):
209
+ all_params = itertools.chain(*[x["params"] for x in self.param_groups])
210
+ torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
211
+ super().step(closure=closure)
212
+
213
+ return FullModelGradientClippingOptimizer if enable else optim
214
+
215
+ optimizer_type = cfg.SOLVER.OPTIMIZER
216
+ if optimizer_type == "SGD":
217
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
218
+ params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
219
+ )
220
+ elif optimizer_type == "ADAMW":
221
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
222
+ params, cfg.SOLVER.BASE_LR
223
+ )
224
+ else:
225
+ raise NotImplementedError(f"no optimizer type {optimizer_type}")
226
+ if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
227
+ optimizer = maybe_add_gradient_clipping(cfg, optimizer)
228
+ return optimizer
229
+
230
+ @classmethod
231
+ def test_with_TTA(cls, cfg, model):
232
+ logger = logging.getLogger("detectron2.trainer")
233
+ # In the end of training, run an evaluation with TTA.
234
+ logger.info("Running inference with test-time augmentation ...")
235
+ model = SemanticSegmentorWithTTA(cfg, model)
236
+ evaluators = [
237
+ cls.build_evaluator(
238
+ cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
239
+ )
240
+ for name in cfg.DATASETS.TEST
241
+ ]
242
+ res = cls.test(cfg, model, evaluators)
243
+ res = OrderedDict({k + "_TTA": v for k, v in res.items()})
244
+ return res
245
+
246
+
247
+ def setup(args):
248
+ """
249
+ Create configs and perform basic setups.
250
+ """
251
+ cfg = get_cfg()
252
+ # for poly lr schedule
253
+ add_deeplab_config(cfg)
254
+ add_maskformer2_config(cfg)
255
+ add_fcclip_config(cfg)
256
+ add_mask_adapter_config(cfg)
257
+ cfg.merge_from_file(args.config_file)
258
+ cfg.merge_from_list(args.opts)
259
+ cfg.merge_from_list(['SEED', 123])
260
+ cfg.freeze()
261
+ default_setup(cfg, args)
262
+ # Setup logger for "maft-plus" module
263
+ setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="maft-plus")
264
+ return cfg
265
+
266
+
267
+ def main(args):
268
+ # torch.multiprocessing.set_start_method('spawn')
269
+ cfg = setup(args)
270
+
271
+ if args.eval_only:
272
+ model = Trainer.build_model(cfg)
273
+ DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
274
+ cfg.MODEL.WEIGHTS, resume=args.resume
275
+ )
276
+ res = Trainer.test(cfg, model)
277
+ if cfg.TEST.AUG.ENABLED:
278
+ res.update(Trainer.test_with_TTA(cfg, model))
279
+ if comm.is_main_process():
280
+ verify_results(cfg, res)
281
+ return res
282
+
283
+ trainer = Trainer(cfg)
284
+ trainer.resume_or_load(resume=args.resume)
285
+ return trainer.train()
286
+
287
+
288
+ if __name__ == "__main__":
289
+ args = default_argument_parser().parse_args()
290
+ print("Command Line Args:", args)
291
+ launch(
292
+ main,
293
+ args.num_gpus,
294
+ num_machines=args.num_machines,
295
+ machine_rank=args.machine_rank,
296
+ dist_url=args.dist_url,
297
+ args=(args,),
298
+ )
train_net_maskadapter.py ADDED
@@ -0,0 +1,360 @@
1
+ """
2
+ This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
3
+ All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
4
+
5
+ Reference: https://github.com/facebookresearch/Mask2Former/blob/main/train_net.py
6
+
7
+ FCCLIP Training Script.
8
+
9
+ This script is a simplified version of the training script in detectron2/tools.
10
+ """
11
+ try:
12
+ # ignore ShapelyDeprecationWarning from fvcore
13
+ from shapely.errors import ShapelyDeprecationWarning
14
+ import warnings
15
+ warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning)
16
+ except:
17
+ pass
18
+
19
+ import copy
20
+ import itertools
21
+ import logging
22
+ import os
23
+
24
+ from collections import OrderedDict
25
+ from typing import Any, Dict, List, Set
26
+
27
+ import torch
28
+
29
+ import detectron2.utils.comm as comm
30
+ from detectron2.checkpoint import DetectionCheckpointer
31
+ from detectron2.config import get_cfg
32
+ from detectron2.data import MetadataCatalog, build_detection_train_loader
33
+ from detectron2.engine import (
34
+ DefaultTrainer,
35
+ default_argument_parser,
36
+ default_setup,
37
+ launch,
38
+ )
39
+ from detectron2.evaluation import (
40
+ CityscapesInstanceEvaluator,
41
+ CityscapesSemSegEvaluator,
42
+ COCOEvaluator,
43
+ COCOPanopticEvaluator,
44
+ DatasetEvaluators,
45
+ LVISEvaluator,
46
+ SemSegEvaluator,
47
+ verify_results,
48
+ )
49
+ from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
50
+ from detectron2.solver.build import maybe_add_gradient_clipping
51
+ from detectron2.utils.logger import setup_logger
52
+
53
+ from mask_adapter import (
54
+ COCOInstanceNewBaselineDatasetMapper,
55
+ COCOPanopticNewBaselineDatasetMapper,
56
+ InstanceSegEvaluator,
57
+ MaskFormerInstanceDatasetMapper,
58
+ MaskFormerPanopticDatasetMapper,
59
+ MaskFormerSemanticDatasetMapper,
60
+ SemanticSegmentorWithTTA,
61
+ add_maskformer2_config,
62
+ add_fcclip_config,
63
+ add_mask_adapter_config
64
+ )
65
+
66
+
67
+ class Trainer(DefaultTrainer):
68
+ """
69
+ Extension of the Trainer class adapted to FCCLIP.
70
+ """
71
+
72
+ @classmethod
73
+ def build_evaluator(cls, cfg, dataset_name, output_folder=None):
74
+ """
75
+ Create evaluator(s) for a given dataset.
76
+ This uses the special metadata "evaluator_type" associated with each
77
+ builtin dataset. For your own dataset, you can simply create an
78
+ evaluator manually in your script and do not have to worry about the
79
+ hacky if-else logic here.
80
+ """
81
+ if output_folder is None:
82
+ output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
83
+ evaluator_list = []
84
+ evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
85
+ # semantic segmentation
86
+ if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]:
87
+ evaluator_list.append(
88
+ SemSegEvaluator(
89
+ dataset_name,
90
+ distributed=True,
91
+ output_dir=output_folder,
92
+ )
93
+ )
94
+ # instance segmentation
95
+ if evaluator_type == "coco":
96
+ evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
97
+ # panoptic segmentation
98
+ if evaluator_type in [
99
+ "coco_panoptic_seg",
100
+ "ade20k_panoptic_seg",
101
+ "cityscapes_panoptic_seg",
102
+ "mapillary_vistas_panoptic_seg",
103
+ ]:
104
+ if cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON:
105
+ evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
106
+ # COCO
107
+ if evaluator_type == "coco_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
108
+ evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
109
+ if evaluator_type == "coco_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON:
110
+ evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
111
+ # Mapillary Vistas
112
+ if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
113
+ evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
114
+ if evaluator_type == "mapillary_vistas_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON:
115
+ evaluator_list.append(SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder))
116
+ # Cityscapes
117
+ if evaluator_type == "cityscapes_instance":
118
+ assert (
119
+ torch.cuda.device_count() > comm.get_rank()
120
+ ), "CityscapesEvaluator currently do not work with multiple machines."
121
+ return CityscapesInstanceEvaluator(dataset_name)
122
+ if evaluator_type == "cityscapes_sem_seg":
123
+ assert (
124
+ torch.cuda.device_count() > comm.get_rank()
125
+ ), "CityscapesEvaluator currently do not work with multiple machines."
126
+ return CityscapesSemSegEvaluator(dataset_name)
127
+ if evaluator_type == "cityscapes_panoptic_seg":
128
+ if cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON:
129
+ assert (
130
+ torch.cuda.device_count() > comm.get_rank()
131
+ ), "CityscapesEvaluator currently do not work with multiple machines."
132
+ evaluator_list.append(CityscapesSemSegEvaluator(dataset_name))
133
+ if cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
134
+ assert (
135
+ torch.cuda.device_count() > comm.get_rank()
136
+ ), "CityscapesEvaluator currently do not work with multiple machines."
137
+ evaluator_list.append(CityscapesInstanceEvaluator(dataset_name))
138
+ # ADE20K
139
+ if evaluator_type == "ade20k_panoptic_seg" and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
140
+ evaluator_list.append(InstanceSegEvaluator(dataset_name, output_dir=output_folder))
141
+ # LVIS
142
+ if evaluator_type == "lvis":
143
+ return LVISEvaluator(dataset_name, output_dir=output_folder)
144
+ if len(evaluator_list) == 0:
145
+ raise NotImplementedError(
146
+ "no Evaluator for the dataset {} with the type {}".format(
147
+ dataset_name, evaluator_type
148
+ )
149
+ )
150
+ elif len(evaluator_list) == 1:
151
+ return evaluator_list[0]
152
+ return DatasetEvaluators(evaluator_list)
153
+
154
+ @classmethod
155
+ def build_train_loader(cls, cfg):
156
+ # Semantic segmentation dataset mapper
157
+ if cfg.DATALOADER.SAMPLER_TRAIN == "MultiDatasetSampler":
158
+ mapper = COCOCombineNewBaselineDatasetMapper(cfg, True)
159
+ data_loader = build_custom_train_loader(cfg, mapper=mapper)
160
+ return data_loader
161
+ else:
162
+ if cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_semantic":
163
+ mapper = MaskFormerSemanticDatasetMapper(cfg, True)
164
+ return build_detection_train_loader(cfg, mapper=mapper)
165
+ # Panoptic segmentation dataset mapper
166
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_panoptic":
167
+ mapper = MaskFormerPanopticDatasetMapper(cfg, True)
168
+ return build_detection_train_loader(cfg, mapper=mapper)
169
+ # Instance segmentation dataset mapper
170
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_instance":
171
+ mapper = MaskFormerInstanceDatasetMapper(cfg, True)
172
+ return build_detection_train_loader(cfg, mapper=mapper)
173
+ # coco instance segmentation lsj new baseline
174
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_instance_lsj":
175
+ mapper = COCOInstanceNewBaselineDatasetMapper(cfg, True)
176
+ return build_detection_train_loader(cfg, mapper=mapper)
177
+ # coco panoptic segmentation lsj new baseline
178
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_panoptic_lsj":
179
+ mapper = COCOPanopticNewBaselineDatasetMapper(cfg, True)
180
+ return build_detection_train_loader(cfg, mapper=mapper)
181
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_combine_lsj":
182
+ mapper = COCOCombineNewBaselineDatasetMapper(cfg, True)
183
+ return build_detection_train_loader(cfg, mapper=mapper)
184
+ # elif cfg.INPUT.DATASET_MAPPER_NAME == "grand_panoptic_lsj":
185
+ # mapper = GrandNewBaselineDatasetMapper(cfg, True)
186
+ # return build_detection_train_loader(cfg, mapper=mapper)
187
+ else:
188
+ mapper = None
189
+ return build_detection_train_loader(cfg, mapper=mapper)
190
+
191
+ @classmethod
192
+ def build_lr_scheduler(cls, cfg, optimizer):
193
+ """
194
+ It now calls :func:`detectron2.solver.build_lr_scheduler`.
195
+ Overwrite it if you'd like a different scheduler.
196
+ """
197
+ return build_lr_scheduler(cfg, optimizer)
198
+
199
+ @classmethod
200
+ def build_optimizer(cls, cfg, model):
201
+ weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM
202
+ weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED
203
+
204
+ defaults = {}
205
+ defaults["lr"] = cfg.SOLVER.BASE_LR
206
+ defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY
207
+
208
+ norm_module_types = (
209
+ torch.nn.BatchNorm1d,
210
+ torch.nn.BatchNorm2d,
211
+ torch.nn.BatchNorm3d,
212
+ torch.nn.SyncBatchNorm,
213
+ # NaiveSyncBatchNorm inherits from BatchNorm2d
214
+ torch.nn.GroupNorm,
215
+ torch.nn.InstanceNorm1d,
216
+ torch.nn.InstanceNorm2d,
217
+ torch.nn.InstanceNorm3d,
218
+ torch.nn.LayerNorm,
219
+ torch.nn.LocalResponseNorm,
220
+ )
221
+
222
+ params: List[Dict[str, Any]] = []
223
+ memo: Set[torch.nn.parameter.Parameter] = set()
224
+ for module_name, module in model.named_modules():
225
+ for module_param_name, value in module.named_parameters(recurse=False):
226
+ if not value.requires_grad:
227
+ continue
228
+ # Avoid duplicating parameters
229
+ if value in memo:
230
+ continue
231
+ memo.add(value)
232
+
233
+ hyperparams = copy.copy(defaults)
234
+ if "backbone" in module_name:
235
+ hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER
236
+ if (
237
+ "relative_position_bias_table" in module_param_name
238
+ or "absolute_pos_embed" in module_param_name
239
+ ):
240
+ print(module_param_name)
241
+ hyperparams["weight_decay"] = 0.0
242
+ if isinstance(module, norm_module_types):
243
+ hyperparams["weight_decay"] = weight_decay_norm
244
+ if isinstance(module, torch.nn.Embedding):
245
+ hyperparams["weight_decay"] = weight_decay_embed
246
+ params.append({"params": [value], **hyperparams})
247
+
248
+ def maybe_add_full_model_gradient_clipping(optim):
249
+ # detectron2 doesn't have full model gradient clipping now
250
+ clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
251
+ enable = (
252
+ cfg.SOLVER.CLIP_GRADIENTS.ENABLED
253
+ and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
254
+ and clip_norm_val > 0.0
255
+ )
256
+
257
+ class FullModelGradientClippingOptimizer(optim):
258
+ def step(self, closure=None):
259
+ all_params = itertools.chain(*[x["params"] for x in self.param_groups])
260
+ torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
261
+ super().step(closure=closure)
262
+
263
+ return FullModelGradientClippingOptimizer if enable else optim
264
+
265
+ optimizer_type = cfg.SOLVER.OPTIMIZER
266
+ if optimizer_type == "SGD":
267
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
268
+ params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
269
+ )
270
+ elif optimizer_type == "ADAMW":
271
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
272
+ params, cfg.SOLVER.BASE_LR
273
+ )
274
+ else:
275
+ raise NotImplementedError(f"no optimizer type {optimizer_type}")
276
+ if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
277
+ optimizer = maybe_add_gradient_clipping(cfg, optimizer)
278
+ return optimizer
279
+
280
+ @classmethod
281
+ def test_with_TTA(cls, cfg, model):
282
+ logger = logging.getLogger("detectron2.trainer")
283
+ # In the end of training, run an evaluation with TTA.
284
+ logger.info("Running inference with test-time augmentation ...")
285
+ model = SemanticSegmentorWithTTA(cfg, model)
286
+ evaluators = [
287
+ cls.build_evaluator(
288
+ cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
289
+ )
290
+ for name in cfg.DATASETS.TEST
291
+ ]
292
+ res = cls.test(cfg, model, evaluators)
293
+ res = OrderedDict({k + "_TTA": v for k, v in res.items()})
294
+ return res
295
+
296
+
297
+ def setup(args):
298
+ """
299
+ Create configs and perform basic setups.
300
+ """
301
+ cfg = get_cfg()
302
+ # for poly lr schedule
303
+ add_deeplab_config(cfg)
304
+ add_maskformer2_config(cfg)
305
+ add_fcclip_config(cfg)
306
+ add_mask_adapter_config(cfg)
307
+ cfg.merge_from_file(args.config_file)
308
+ cfg.merge_from_list(args.opts)
309
+ cfg.freeze()
310
+ default_setup(cfg, args)
311
+ # Setup logger for "fcclip" module
312
+ setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="fcclip")
313
+ return cfg
314
+
315
+
316
+ def main(args):
317
+ cfg = setup(args)
318
+
319
+ if args.eval_only:
320
+ model = Trainer.build_model(cfg)
321
+
322
+ total_params = sum(p.numel() for p in model.parameters())
323
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
324
+ frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
325
+ frozen_params_exclude_text = 0
326
+ for n, p in model.named_parameters():
327
+ if p.requires_grad:
328
+ continue
329
+ # ignore text tower
330
+ if 'clip_model.token_embedding' in n or 'clip_model.positional_embedding' in n or 'clip_model.transformer' in n or 'clip_model.ln_final' in n or 'clip_model.text_projection' in n:
331
+ continue
332
+ frozen_params_exclude_text += p.numel()
333
+ print(f"total_params: {total_params}, trainable_params: {trainable_params}, frozen_params: {frozen_params}, frozen_params_exclude_text: {frozen_params_exclude_text}")
334
+
335
+ DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
336
+ cfg.MODEL.WEIGHTS, resume=args.resume
337
+ )
338
+ res = Trainer.test(cfg, model)
339
+ if cfg.TEST.AUG.ENABLED:
340
+ res.update(Trainer.test_with_TTA(cfg, model))
341
+ if comm.is_main_process():
342
+ verify_results(cfg, res)
343
+ return res
344
+
345
+ trainer = Trainer(cfg)
346
+ trainer.resume_or_load(resume=args.resume)
347
+ return trainer.train()
348
+
349
+
350
+ if __name__ == "__main__":
351
+ args = default_argument_parser().parse_args()
352
+ print("Command Line Args:", args)
353
+ launch(
354
+ main,
355
+ args.num_gpus,
356
+ num_machines=args.num_machines,
357
+ machine_rank=args.machine_rank,
358
+ dist_url=args.dist_url,
359
+ args=(args,),
360
+ )
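
Both training scripts build the optimizer from per-parameter groups: parameters under a module whose name contains `backbone` get `BASE_LR * BACKBONE_MULTIPLIER`, while normalization layers, embeddings, and position-bias tables receive their own weight decay. The snippet below is a simplified sketch of that grouping on a toy model; the module names and hyperparameter values are invented for illustration, and it omits the `memo` de-duplication of shared parameters used in `build_optimizer` above.

```python
import copy
import torch

# toy model with a "backbone" and a "head", mirroring the name-based grouping above
model = torch.nn.ModuleDict({
    "backbone": torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8)),
    "head": torch.nn.Linear(8, 2),
})

defaults = {"lr": 1e-4, "weight_decay": 0.05}  # hypothetical BASE_LR / WEIGHT_DECAY
backbone_multiplier = 0.1                      # hypothetical BACKBONE_MULTIPLIER
weight_decay_norm = 0.0                        # hypothetical WEIGHT_DECAY_NORM

params = []
for module_name, module in model.named_modules():
    for _, value in module.named_parameters(recurse=False):
        hyperparams = copy.copy(defaults)
        if "backbone" in module_name:
            hyperparams["lr"] *= backbone_multiplier         # lower LR for the backbone
        if isinstance(module, torch.nn.LayerNorm):
            hyperparams["weight_decay"] = weight_decay_norm  # no decay on norm layers
        params.append({"params": [value], **hyperparams})

optimizer = torch.optim.AdamW(params, lr=defaults["lr"])
for group in optimizer.param_groups:
    print(group["lr"], group["weight_decay"])
```

Each parameter ends up in its own group, so learning rate and weight decay can be inspected or adjusted per group after the optimizer is built.
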