Commit 4a3ad95 · diaoqishuai committed · 0 parent(s)
first commit
Files changed:
- .gitignore +145 -0
- LICENSE +21 -0
- README.md +117 -0
- config.py +273 -0
- configs/MetaFG_0_224.yaml +5 -0
- configs/MetaFG_1_224.yaml +5 -0
- configs/MetaFG_2_224.yaml +5 -0
- configs/MetaFG_meta_0_224.yaml +8 -0
- configs/MetaFG_meta_1_224.yaml +8 -0
- configs/MetaFG_meta_2_224.yaml +8 -0
- configs/MetaFG_meta_attribute_1_224.yaml +8 -0
- configs/MetaFG_meta_bert_0_224.yaml +8 -0
- configs/MetaFG_meta_bert_1_224.yaml +8 -0
- data/__init__.py +1 -0
- data/build.py +169 -0
- data/cached_image_folder.py +251 -0
- data/dataset_fg.py +457 -0
- data/samplers.py +29 -0
- data/zipreader.py +103 -0
- figs/overview.png +0 -0
- get_flops.py +62 -0
- logger.py +47 -0
- lr_scheduler.py +102 -0
- main.py +403 -0
- models/MBConv.py +169 -0
- models/MHSA.py +161 -0
- models/MetaFG.py +213 -0
- models/MetaFG_meta.py +268 -0
- models/__init__.py +1 -0
- models/build.py +20 -0
- models/meta_encoder.py +21 -0
- optimizer.py +70 -0
- utils.py +173 -0
.gitignore
ADDED
@@ -0,0 +1,145 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
/output
/imagenet
/imagenet_raw
/pretrained_model
/inaturelist2021
/inaturelist2021_mini
/save_model
/inaturelist2017
/inaturelist2018
/cub-200
/stanfordcars
/oxfordflower
/stanforddogs
/nabirds
/aircraft
/datasets
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 dqshuai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,117 @@
# MetaFG
A repository for the code used to create and train the model defined in "MetaFormer: A Unified Meta Framework for Fine-Grained Recognition".

## Model zoo
| name | resolution | 1k model | 21k model | iNat21 model |
| :--------: | :----------: | :--------: | :----------: | :------------: |
| MetaFormer-0 | 384x384 | [metafg_0_1k_384](https://drive.google.com/file/d/1r62S3CJFRWV_qA5udC9MOFOJYwRf8mE2/view?usp=sharing) | [metafg_0_21k_384](https://drive.google.com/file/d/1wVmlPjNTA6JKHcF3ROGorEVPxKVO83Ss/view?usp=sharing) | [metafg_0_inat21_384](https://drive.google.com/file/d/11gCk_IuSN7krdkOUSWSM4xlf8GGknmxc/view?usp=sharing) |
| MetaFormer-1 | 384x384 | [metafg_1_1k_384](https://drive.google.com/file/d/12OTmZg4J6fMGvs-colOTDfmhdA5EMMvo/view?usp=sharing) | [metafg_1_21k_384](https://drive.google.com/file/d/13dsarbtsNrkhpG5XpCRlN5ogXDGXO3Z_/view?usp=sharing) | [metafg_1_inat21_384](https://drive.google.com/file/d/1ATUIrDxaQaGqx4lJ8HE2IwX_evMhblPu/view?usp=sharing) |
| MetaFormer-2 | 384x384 | [metafg_2_1k_384](https://drive.google.com/file/d/167oBaseORq32aFA3Ex6lpHuasvu2PMb8/view?usp=sharing) | [metafg_2_21k_384](https://drive.google.com/file/d/1PnpntloQaYduEokFGQ6y79G7DdyjD_u3/view?usp=sharing) | [metafg_2_inat21_384](https://drive.google.com/file/d/17sUNST7ivQhonBAfZEiTOLAgtaHa4F3e/view?usp=sharing) |
## Usage
#### Python modules
* install PyTorch and torchvision
```
pip install torch==1.5.1 torchvision==0.6.1
```
* install `timm`
```
pip install timm==0.4.5
```
* install `Apex`
```
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
* install other requirements
```
pip install opencv-python==4.5.1.48 yacs==0.1.8
```
#### Data preparation
Download [iNat21/18/17](https://github.com/visipedia/inat_comp), [CUB](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), [NABirds](https://dl.allaboutbirds.org/nabirds), [Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html), and [Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/), put them in their respective folders (\<root\>/datasets/<dataset_name>) and unzip the files. The folder structure is as follows:
```
datasets
|————inaturelist2021
| └——————train
| └——————val
| └——————train.json
| └——————val.json
|————inaturelist2018
| └——————train_val_images
| └——————train2018.json
| └——————val2018.json
| └——————train2018_locations.json
| └——————val2018_locations.json
| └——————categories.json
|————inaturelist2017
| └——————train_val_images
| └——————train2017.json
| └——————val2017.json
| └——————train2017_locations.json
| └——————val2017_locations.json
|————cub-200
| └——————...
|————nabirds
| └——————...
|————stanfordcars
| └——————car_ims
| └——————cars_annos.mat
|————aircraft
| └——————...
```
#### Training
You can download pre-trained models from the model zoo and put them under \<root\>/pretrained_model.
To train MetaFG on a dataset, run:
```
python3 -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --cfg <config-file> --dataset <dataset-name> --pretrain <pretrained-model-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
```
\<dataset-name\>: inaturelist2021, inaturelist2018, inaturelist2017, cub-200, nabirds, stanfordcars, aircraft
For CUB-200-2011, run:
```
python3 -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py --cfg ./configs/MetaFG_1_224.yaml --batch-size 32 --tag cub-200_v1 --lr 5e-5 --min-lr 5e-7 --warmup-lr 5e-8 --epochs 300 --warmup-epochs 20 --dataset cub-200 --pretrain ./pretrained_model/<xxxx>.pth --accumulation-steps 2 --opts DATA.IMG_SIZE 384
```
Note that the given learning rate is scaled linearly with the total batch size: effective_lr = lr * total_batch_size / 512.
#### Eval
To evaluate a model on a dataset, run:
```
python3 -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval --cfg <config-file> --dataset <dataset-name> --resume <checkpoint> [--batch-size <batch-size-per-gpu>]
```
## Main Results
#### ImageNet-1k
| Name | Resolution | #Param | #FLOPs | Throughput | Top-1 acc |
| :--------: | :----------: | :--------: | :----------: | :------------: | :------------: |
| MetaFormer-0 | 224x224 | 28M | 4.6G | 840.1 | 82.9 |
| MetaFormer-1 | 224x224 | 45M | 8.5G | 444.8 | 83.9 |
| MetaFormer-2 | 224x224 | 81M | 16.9G | 438.9 | 84.1 |
| MetaFormer-0 | 384x384 | 28M | 13.4G | 349.4 | 84.2 |
| MetaFormer-1 | 384x384 | 45M | 24.7G | 165.3 | 84.4 |
| MetaFormer-2 | 384x384 | 81M | 49.7G | 132.7 | 84.6 |
#### Fine-grained datasets
Results on fine-grained datasets with different pre-trained models.
| Name | Pretrain | CUB | NABirds | iNat2017 | iNat2018 | Cars | Aircraft |
| :--------: | :----------: | :--------: | :----------: | :------------: | :------------: | :--------: | :--------: |
| MetaFormer-0 | ImageNet-1k | 89.6 | 89.1 | 75.7 | 79.5 | 95.0 | 91.2 |
| MetaFormer-0 | ImageNet-21k | 89.7 | 89.5 | 75.8 | 79.9 | 94.6 | 91.2 |
| MetaFormer-0 | iNaturalist 2021 | 91.8 | 91.5 | 78.3 | 82.9 | 95.1 | 87.4 |
| MetaFormer-1 | ImageNet-1k | 89.7 | 89.4 | 78.2 | 81.9 | 94.9 | 90.8 |
| MetaFormer-1 | ImageNet-21k | 91.3 | 91.6 | 79.4 | 83.2 | 95.0 | 92.6 |
| MetaFormer-1 | iNaturalist 2021 | 92.3 | 92.7 | 82.0 | 87.5 | 95.0 | 92.5 |
| MetaFormer-2 | ImageNet-1k | 89.7 | 89.7 | 79.0 | 82.6 | 95.0 | 92.4 |
| MetaFormer-2 | ImageNet-21k | 91.8 | 92.2 | 80.4 | 84.3 | 95.1 | 92.9 |
| MetaFormer-2 | iNaturalist 2021 | 92.9 | 93.0 | 82.8 | 87.7 | 95.4 | 92.8 |

Results on iNaturalist 2017, iNaturalist 2018, and iNaturalist 2021 with meta-information.
| Name | Pretrain | Meta added | iNat2017 | iNat2018 | iNat2021 |
| :--------: | :----------: | :--------: | :----------: | :------------: | :------------: |
| MetaFormer-0 | ImageNet-1k | N | 75.7 | 79.5 | 88.4 |
| MetaFormer-0 | ImageNet-1k | Y | 79.8 (+4.1) | 85.4 (+5.9) | 92.6 (+4.2) |
| MetaFormer-1 | ImageNet-1k | N | 78.2 | 81.9 | 90.2 |
| MetaFormer-1 | ImageNet-1k | Y | 81.3 (+3.1) | 86.5 (+4.6) | 93.4 (+3.2) |
| MetaFormer-2 | ImageNet-1k | N | 79.0 | 82.6 | 89.8 |
| MetaFormer-2 | ImageNet-1k | Y | 82.0 (+3.0) | 86.8 (+4.2) | 93.2 (+3.4) |
| MetaFormer-2 | ImageNet-21k | N | 80.4 | 84.3 | 90.3 |
| MetaFormer-2 | ImageNet-21k | Y | 83.4 (+3.0) | 88.7 (+4.4) | 93.6 (+3.3) |
## Citation
## Acknowledgement
Many thanks to [Swin Transformer](https://github.com/microsoft/Swin-Transformer). Part of the code is borrowed from it.
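As a quick illustration of the batch-size scaling rule noted in the README above, here is a minimal sketch. The exact formula lives in main.py (not shown in this diff), so treat the scaling as an assumption based on the Swin Transformer convention this repo borrows from:

```python
# Hypothetical illustration of linear LR scaling (Swin Transformer convention).
# Assumed values mirror the CUB-200 command above: 8 GPUs, batch 32 per GPU,
# gradient accumulation 2, base lr 5e-5.
base_lr = 5e-5
batch_size_per_gpu = 32
num_gpus = 8
accumulation_steps = 2

total_batch_size = batch_size_per_gpu * num_gpus * accumulation_steps  # 512
effective_lr = base_lr * total_batch_size / 512.0                      # 5e-5

print(f"total batch {total_batch_size}, effective lr {effective_lr:.2e}")
```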
config.py
ADDED
@@ -0,0 +1,273 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import os
import yaml
from yacs.config import CfgNode as CN

_C = CN()

# Base config files
_C.BASE = ['']

# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 128
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = ''
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
_C.DATA.TRAIN_INTERPOLATION = 'bicubic'
# Use zipped dataset instead of folder dataset
# could be overwritten by command line argument
_C.DATA.ZIP_MODE = False
# Cache data in memory, could be overwritten by command line argument
_C.DATA.CACHE_MODE = 'part'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8
# hdfs data dir
_C.DATA.TRAIN_PATH = None
_C.DATA.VAL_PATH = None
# arnold dataset parallel
_C.DATA.NUM_READERS = 4

# meta info
_C.DATA.ADD_META = False
_C.DATA.FUSION = 'early'
_C.DATA.MASK_PROB = 0.0
_C.DATA.MASK_TYPE = 'constant'
_C.DATA.LATE_FUSION_LAYER = -1

# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = ''
# Model name
_C.MODEL.NAME = ''
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1
# Label smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1
# Pretrain
_C.MODEL.PRETRAINED = None
_C.MODEL.DORP_HEAD = True
_C.MODEL.DORP_META = True

_C.MODEL.ONLY_LAST_CLS = False
_C.MODEL.EXTRA_TOKEN_NUM = 1
_C.MODEL.META_DIMS = []

# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 0
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False

# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9

# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True

# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
# overwritten by command line argument
_C.AMP_OPT_LEVEL = ''
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency to log info
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False
# Local rank for DistributedDataParallel, given by command line argument
_C.LOCAL_RANK = 0


def _update_config_from_file(config, cfg_file):
    config.defrost()
    with open(cfg_file, 'r') as f:
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)

    for cfg in yaml_cfg.setdefault('BASE', ['']):
        if cfg:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), cfg)
            )
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()


def update_config(config, args):
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    # merge from specific arguments
    if args.batch_size:
        config.DATA.BATCH_SIZE = args.batch_size
    if args.data_path:
        config.DATA.DATA_PATH = args.data_path
    if args.zip:
        config.DATA.ZIP_MODE = True
    if args.cache_mode:
        config.DATA.CACHE_MODE = args.cache_mode
    if args.resume:
        config.MODEL.RESUME = args.resume
    if args.accumulation_steps:
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if args.use_checkpoint:
        config.TRAIN.USE_CHECKPOINT = True
    if args.amp_opt_level:
        config.AMP_OPT_LEVEL = args.amp_opt_level
    if args.output:
        config.OUTPUT = args.output
    if args.tag:
        config.TAG = args.tag
    if args.eval:
        config.EVAL_MODE = True
    if args.throughput:
        config.THROUGHPUT_MODE = True

    if args.num_workers is not None:
        config.DATA.NUM_WORKERS = args.num_workers

    # set lr and weight decay
    if args.lr is not None:
        config.TRAIN.BASE_LR = args.lr
    if args.min_lr is not None:
        config.TRAIN.MIN_LR = args.min_lr
    if args.warmup_lr is not None:
        config.TRAIN.WARMUP_LR = args.warmup_lr
    if args.warmup_epochs is not None:
        config.TRAIN.WARMUP_EPOCHS = args.warmup_epochs
    if args.weight_decay is not None:
        config.TRAIN.WEIGHT_DECAY = args.weight_decay

    if args.epochs is not None:
        config.TRAIN.EPOCHS = args.epochs
    if args.dataset is not None:
        config.DATA.DATASET = args.dataset
    if args.lr_scheduler_name is not None:
        config.TRAIN.LR_SCHEDULER.NAME = args.lr_scheduler_name
    if args.pretrain is not None:
        config.MODEL.PRETRAINED = args.pretrain

    # set local rank for distributed training
    config.LOCAL_RANK = args.local_rank

    # output folder
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()


def get_config(args):
    """Get a yacs CfgNode object with default values."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    config = _C.clone()
    update_config(config, args)

    return config
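For orientation, here is a minimal sketch of how get_config() composes defaults, yaml (including BASE inheritance), and command-line overrides. The argparse namespace below is hypothetical; the real parser lives in main.py, which is not part of this section:

```python
# Hypothetical driver for config.py's get_config(); the attribute names mirror
# what update_config() reads, but the actual CLI flags are defined in main.py.
import argparse
from config import get_config

args = argparse.Namespace(
    cfg='configs/MetaFG_0_224.yaml', opts=None, batch_size=64, data_path=None,
    zip=False, cache_mode=None, resume=None, accumulation_steps=None,
    use_checkpoint=False, amp_opt_level=None, output='output', tag='demo',
    eval=False, throughput=False, num_workers=None, lr=None, min_lr=None,
    warmup_lr=None, warmup_epochs=None, weight_decay=None, epochs=None,
    dataset='cub-200', lr_scheduler_name=None, pretrain=None, local_rank=0,
)
config = get_config(args)   # defaults -> yaml (BASE chain) -> CLI overrides
print(config.MODEL.NAME)    # 'MetaFG_0'
print(config.OUTPUT)        # 'output/MetaFG_0/demo'
```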
configs/MetaFG_0_224.yaml
ADDED
@@ -0,0 +1,5 @@
DATA:
  IMG_SIZE: 224
MODEL:
  TYPE: MetaFG
  NAME: MetaFG_0
configs/MetaFG_1_224.yaml
ADDED
@@ -0,0 +1,5 @@
DATA:
  IMG_SIZE: 224
MODEL:
  TYPE: MetaFG
  NAME: MetaFG_1
configs/MetaFG_2_224.yaml
ADDED
@@ -0,0 +1,5 @@
DATA:
  IMG_SIZE: 224
MODEL:
  TYPE: MetaFG
  NAME: MetaFG_2
configs/MetaFG_meta_0_224.yaml
ADDED
@@ -0,0 +1,8 @@
DATA:
  IMG_SIZE: 224
  ADD_META: True
MODEL:
  TYPE: MetaFG
  NAME: MetaFG_meta_0
  EXTRA_TOKEN_NUM: 3
  META_DIMS: [ 4, 3 ]
configs/MetaFG_meta_1_224.yaml
ADDED
@@ -0,0 +1,8 @@
DATA:
  IMG_SIZE: 224
  ADD_META: True
MODEL:
  TYPE: MetaFG
  NAME: MetaFG_meta_1
  EXTRA_TOKEN_NUM: 3
  META_DIMS: [ 4, 3 ]
configs/MetaFG_meta_2_224.yaml
ADDED
@@ -0,0 +1,8 @@
DATA:
  IMG_SIZE: 224
  ADD_META: True
MODEL:
  TYPE: MetaFG
  NAME: MetaFG_meta_2
  EXTRA_TOKEN_NUM: 3
  META_DIMS: [ 4, 3 ]
configs/MetaFG_meta_attribute_1_224.yaml
ADDED
@@ -0,0 +1,8 @@
DATA:
  IMG_SIZE: 224
  ADD_META: True
MODEL:
  TYPE: MetaFG
  NAME: MetaFG_meta_1
  EXTRA_TOKEN_NUM: 2
  META_DIMS: [ 312, ]
configs/MetaFG_meta_bert_0_224.yaml
ADDED
@@ -0,0 +1,8 @@
DATA:
  IMG_SIZE: 224
  ADD_META: True
MODEL:
  TYPE: MetaFG
  NAME: MetaFG_meta_0
  EXTRA_TOKEN_NUM: 33
  META_DIMS: [ 768, ]
configs/MetaFG_meta_bert_1_224.yaml
ADDED
@@ -0,0 +1,8 @@
DATA:
  IMG_SIZE: 224
  ADD_META: True
MODEL:
  TYPE: MetaFG
  NAME: MetaFG_meta_1
  EXTRA_TOKEN_NUM: 33
  META_DIMS: [ 768, ]
data/__init__.py
ADDED
@@ -0,0 +1 @@
from .build import build_loader
data/build.py
ADDED
@@ -0,0 +1,169 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import os
import torch
import numpy as np
import torch.distributed as dist
from torchvision import datasets, transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
from timm.data import create_transform
from timm.data.transforms import _pil_interp

from .cached_image_folder import CachedImageFolder
from .samplers import SubsetRandomSampler
from .dataset_fg import DatasetMeta


def build_loader(config):
    config.defrost()
    dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
    config.freeze()
    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
    dataset_val, _ = build_dataset(is_train=False, config=config)
    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")

    num_tasks = dist.get_world_size()
    global_rank = dist.get_rank()
    if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
        indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
        sampler_train = SubsetRandomSampler(indices)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )

    indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
    sampler_val = SubsetRandomSampler(indices)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=config.DATA.BATCH_SIZE,
        shuffle=False,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False
    )

    # setup mixup / cutmix
    mixup_fn = None
    mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
            prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
            label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)

    return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn


def build_dataset(is_train, config):
    transform = build_transform(is_train, config)
    if config.DATA.DATASET == 'imagenet':
        prefix = 'train' if is_train else 'val'
        if config.DATA.ZIP_MODE:
            ann_file = prefix + "_map.txt"
            prefix = prefix + ".zip@/"
            dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
                                        cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
        else:
            # root = os.path.join(config.DATA.DATA_PATH, prefix)
            root = './datasets/imagenet'
            dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif config.DATA.DATASET == 'inaturelist2021':
        root = './datasets/inaturelist2021'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train, aux_info=config.DATA.ADD_META,
                              dataset=config.DATA.DATASET,
                              class_ratio=config.DATA.CLASS_RATIO, per_sample=config.DATA.PER_SAMPLE)
        nb_classes = 10000
    elif config.DATA.DATASET == 'inaturelist2021_mini':
        root = './datasets/inaturelist2021_mini'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train, aux_info=config.DATA.ADD_META,
                              dataset=config.DATA.DATASET)
        nb_classes = 10000
    elif config.DATA.DATASET == 'inaturelist2017':
        root = './datasets/inaturelist2017'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train, aux_info=config.DATA.ADD_META,
                              dataset=config.DATA.DATASET)
        nb_classes = 5089
    elif config.DATA.DATASET == 'inaturelist2018':
        root = './datasets/inaturelist2018'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train, aux_info=config.DATA.ADD_META,
                              dataset=config.DATA.DATASET)
        nb_classes = 8142
    elif config.DATA.DATASET == 'cub-200':
        root = './datasets/cub-200'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train, aux_info=config.DATA.ADD_META,
                              dataset=config.DATA.DATASET)
        nb_classes = 200
    elif config.DATA.DATASET == 'stanfordcars':
        root = './datasets/stanfordcars'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train, aux_info=config.DATA.ADD_META,
                              dataset=config.DATA.DATASET)
        nb_classes = 196
    elif config.DATA.DATASET == 'oxfordflower':
        root = './datasets/oxfordflower'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train, aux_info=config.DATA.ADD_META,
                              dataset=config.DATA.DATASET)
        nb_classes = 102
    elif config.DATA.DATASET == 'stanforddogs':
        root = './datasets/stanforddogs'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train, aux_info=config.DATA.ADD_META,
                              dataset=config.DATA.DATASET)
        nb_classes = 120
    elif config.DATA.DATASET == 'nabirds':
        root = './datasets/nabirds'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train, aux_info=config.DATA.ADD_META,
                              dataset=config.DATA.DATASET)
        nb_classes = 555
    elif config.DATA.DATASET == 'aircraft':
        root = './datasets/aircraft'
        dataset = DatasetMeta(root=root, transform=transform, train=is_train, aux_info=config.DATA.ADD_META,
                              dataset=config.DATA.DATASET)
        nb_classes = 100
    else:
        raise NotImplementedError("We only support ImageNet and inaturelist.")

    return dataset, nb_classes


def build_transform(is_train, config):
    resize_im = config.DATA.IMG_SIZE > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=config.DATA.IMG_SIZE,
            is_training=True,
            color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
            auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
            re_prob=config.AUG.REPROB,
            re_mode=config.AUG.REMODE,
            re_count=config.AUG.RECOUNT,
            interpolation=config.DATA.TRAIN_INTERPOLATION,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with RandomCrop
            transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
        return transform

    t = []
    if resize_im:
        if config.TEST.CROP:
            size = int((256 / 224) * config.DATA.IMG_SIZE)
            t.append(
                transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
                # to maintain same ratio w.r.t. 224 images
            )
            t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
        else:
            t.append(
                transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
                                  interpolation=_pil_interp(config.DATA.INTERPOLATION))
            )

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
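The validation sampler above partitions indices by striding over ranks rather than using DistributedSampler. A standalone sketch of that behavior (4 ranks and 10 samples assumed for illustration; no process group needed). Note also that the inaturelist2021 branch references config.DATA.CLASS_RATIO and config.DATA.PER_SAMPLE, which have no defaults in config.py in this commit:

```python
# Standalone sketch of the rank-strided index split used for the val sampler
# above: np.arange(rank, len(dataset), world_size). Each rank takes every
# world_size-th sample, so together the ranks cover the dataset exactly once.
import numpy as np

world_size, n_samples = 4, 10
for rank in range(world_size):
    indices = np.arange(rank, n_samples, world_size)
    print(f"rank {rank}: {indices.tolist()}")
# rank 0: [0, 4, 8]
# rank 1: [1, 5, 9]
# rank 2: [2, 6]
# rank 3: [3, 7]
```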
data/cached_image_folder.py
ADDED
@@ -0,0 +1,251 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import io
import os
import time
import torch.distributed as dist
import torch.utils.data as data
from PIL import Image

from .zipreader import is_zip_path, ZipReader


def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.
    Args:
        filename (string): path to a file
    Returns:
        bool: True if the filename ends with a known image extension
    """
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in extensions)


def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


def make_dataset(dir, class_to_idx, extensions):
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue

        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if has_file_allowed_extension(fname, extensions):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images


def make_dataset_with_ann(ann_file, img_prefix, extensions):
    images = []
    with open(ann_file, "r") as f:
        contents = f.readlines()
        for line_str in contents:
            path_contents = [c for c in line_str.split('\t')]
            im_file_name = path_contents[0]
            class_index = int(path_contents[1])

            assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
            item = (os.path.join(img_prefix, im_file_name), class_index)

            images.append(item)

    return images


class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::
        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext
        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext
    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (list[string]): A list of allowed extensions.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
    Attributes:
        samples (list): List of (sample path, class_index) tuples
    """

    def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
                 cache_mode="no"):
        # image folder mode
        if ann_file == '':
            _, class_to_idx = find_classes(root)
            samples = make_dataset(root, class_to_idx, extensions)
        # zip mode
        else:
            samples = make_dataset_with_ann(os.path.join(root, ann_file),
                                            os.path.join(root, img_prefix),
                                            extensions)

        if len(samples) == 0:
            raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
                                "Supported extensions are: " + ",".join(extensions)))

        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.samples = samples
        self.labels = [y_1k for _, y_1k in samples]
        self.classes = list(set(self.labels))

        self.transform = transform
        self.target_transform = target_transform

        self.cache_mode = cache_mode
        if self.cache_mode != "no":
            self.init_cache()

    def init_cache(self):
        assert self.cache_mode in ["part", "full"]
        n_sample = len(self.samples)
        global_rank = dist.get_rank()
        world_size = dist.get_world_size()

        samples_bytes = [None for _ in range(n_sample)]
        start_time = time.time()
        for index in range(n_sample):
            if index % (n_sample // 10) == 0:
                t = time.time() - start_time
                print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
                start_time = time.time()
            path, target = self.samples[index]
            if self.cache_mode == "full":
                samples_bytes[index] = (ZipReader.read(path), target)
            elif self.cache_mode == "part" and index % world_size == global_rank:
                samples_bytes[index] = (ZipReader.read(path), target)
            else:
                samples_bytes[index] = (path, target)
        self.samples = samples_bytes

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']


def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    if isinstance(path, bytes):
        img = Image.open(io.BytesIO(path))
    elif is_zip_path(path):
        data = ZipReader.read(path)
        img = Image.open(io.BytesIO(data))
    else:
        with open(path, 'rb') as f:
            img = Image.open(f)
    return img.convert('RGB')


def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_img_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


class CachedImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way: ::
        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png
    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
    Attributes:
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
                 loader=default_img_loader, cache_mode="no"):
        super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
                                                ann_file=ann_file, img_prefix=img_prefix,
                                                transform=transform, target_transform=target_transform,
                                                cache_mode=cache_mode)
        self.imgs = self.samples

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        image = self.loader(path)
        if self.transform is not None:
            img = self.transform(image)
        else:
            img = image
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
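A minimal usage sketch for the zip-mode path above. The file names are hypothetical; the layout mirrors how build.py wires it up for ImageNet, with a `<prefix>_map.txt` annotation list and a `<prefix>.zip@/` image prefix:

```python
# Hypothetical zip-mode usage of CachedImageFolder, following build.py.
# Assumes ./datasets/imagenet contains train.zip plus a tab-separated
# train_map.txt of "<relative/path.jpg>\t<class_index>" lines.
from torchvision import transforms
from data.cached_image_folder import CachedImageFolder

dataset = CachedImageFolder(
    root='./datasets/imagenet',
    ann_file='train_map.txt',   # annotation list, one "path\tlabel" per line
    img_prefix='train.zip@/',   # '@/' marks a path inside the zip archive
    transform=transforms.ToTensor(),
    cache_mode='no',            # 'part'/'full' pre-read bytes via ZipReader
)
img, target = dataset[0]
```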
data/dataset_fg.py
ADDED
@@ -0,0 +1,457 @@
import torch.utils.data as data

import os
import re
import csv
import json
import torch
import tarfile
import pickle
import numpy as np
import pandas as pd
import random
random.seed(2021)
from PIL import Image
from scipy import io as scio
from math import radians, cos, sin, asin, sqrt, pi

IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg']


def get_spatial_info(latitude, longitude):
    if latitude and longitude:
        latitude = radians(latitude)
        longitude = radians(longitude)
        x = cos(latitude)*cos(longitude)
        y = cos(latitude)*sin(longitude)
        z = sin(latitude)
        return [x, y, z]
    else:
        return [0, 0, 0]


def get_temporal_info(date, miss_hour=False):
    try:
        if date:
            if miss_hour:
                pattern = re.compile(r'(\d*)-(\d*)-(\d*)', re.I)
            else:
                pattern = re.compile(r'(\d*)-(\d*)-(\d*) (\d*):(\d*):(\d*)', re.I)
            m = pattern.match(date.strip())

            if m:
                year = int(m.group(1))
                month = int(m.group(2))
                day = int(m.group(3))
                x_month = sin(2*pi*month/12)
                y_month = cos(2*pi*month/12)
                if miss_hour:
                    x_hour = 0
                    y_hour = 0
                else:
                    hour = int(m.group(4))
                    x_hour = sin(2*pi*hour/24)
                    y_hour = cos(2*pi*hour/24)
                return [x_month, y_month, x_hour, y_hour]
            else:
                return [0, 0, 0, 0]
        else:
            return [0, 0, 0, 0]
    except:
        return [0, 0, 0, 0]


def load_file(root, dataset):
    if dataset == 'inaturelist2017':
        year_flag = 7
    elif dataset == 'inaturelist2018':
        year_flag = 8

    if dataset == 'inaturelist2018':
        with open(os.path.join(root, 'categories.json'), 'r') as f:
            map_label = json.load(f)
        map_2018 = dict()
        for _map in map_label:
            map_2018[int(_map['id'])] = _map['name'].strip().lower()
    with open(os.path.join(root, f'val201{year_flag}_locations.json'), 'r') as f:
        val_location = json.load(f)
    val_id2meta = dict()
    for meta_info in val_location:
        val_id2meta[meta_info['id']] = meta_info
    with open(os.path.join(root, f'train201{year_flag}_locations.json'), 'r') as f:
        train_location = json.load(f)
    train_id2meta = dict()
    for meta_info in train_location:
        train_id2meta[meta_info['id']] = meta_info
    with open(os.path.join(root, f'val201{year_flag}.json'), 'r') as f:
        val_class_info = json.load(f)
    with open(os.path.join(root, f'train201{year_flag}.json'), 'r') as f:
        train_class_info = json.load(f)

    if dataset == 'inaturelist2017':
        categories_2017 = [x['name'].strip().lower() for x in val_class_info['categories']]
        class_to_idx = {c: idx for idx, c in enumerate(categories_2017)}
        id2label = dict()
        for categorie in val_class_info['categories']:
            id2label[int(categorie['id'])] = categorie['name'].strip().lower()
    elif dataset == 'inaturelist2018':
        categories_2018 = [x['name'].strip().lower() for x in map_label]
        class_to_idx = {c: idx for idx, c in enumerate(categories_2018)}
        id2label = dict()
        for categorie in val_class_info['categories']:
            name = map_2018[int(categorie['name'])]
            id2label[int(categorie['id'])] = name.strip().lower()

    return train_class_info, train_id2meta, val_class_info, val_id2meta, class_to_idx, id2label


def find_images_and_targets_cub200(root, dataset, istrain=False, aux_info=False):
    imageid2label = {}
    with open(os.path.join(os.path.join(root, 'CUB_200_2011'), 'image_class_labels.txt'), 'r') as f:
        for line in f:
            image_id, label = line.split()
            imageid2label[int(image_id)] = int(label)-1
    imageid2split = {}
    with open(os.path.join(os.path.join(root, 'CUB_200_2011'), 'train_test_split.txt'), 'r') as f:
        for line in f:
            image_id, split = line.split()
            imageid2split[int(image_id)] = int(split)
    images_and_targets = []
    images_info = []
    images_root = os.path.join(os.path.join(root, 'CUB_200_2011'), 'images')
    bert_embedding_root = os.path.join(root, 'bert_embedding_cub')
    text_root = os.path.join(root, 'text_c10')
    with open(os.path.join(os.path.join(root, 'CUB_200_2011'), 'images.txt'), 'r') as f:
        for line in f:
            image_id, file_name = line.split()
            file_path = os.path.join(images_root, file_name)
            target = imageid2label[int(image_id)]
            if aux_info:
                with open(os.path.join(bert_embedding_root, file_name.replace('.jpg', '.pickle')), 'rb') as f_bert:
                    bert_embedding = pickle.load(f_bert)
                    bert_embedding = bert_embedding['embedding_words']
                text_list = []
                with open(os.path.join(text_root, file_name.replace('.jpg', '.txt')), 'r') as f_text:
                    for line in f_text:
                        line = line.encode(encoding='UTF-8', errors='strict')
                        line = line.replace(b'\xef\xbf\xbd\xef\xbf\xbd', b' ')
                        line = line.decode('UTF-8', 'strict')
                        text_list.append(line)
            if istrain and imageid2split[int(image_id)] == 1:
                if aux_info:
                    images_and_targets.append([file_path, target, bert_embedding])
                    images_info.append({'text_list': text_list})
                else:
                    images_and_targets.append([file_path, target])
            elif not istrain and imageid2split[int(image_id)] == 0:
                if aux_info:
                    images_and_targets.append([file_path, target, bert_embedding])
                    images_info.append({'text_list': text_list})
                else:
                    images_and_targets.append([file_path, target])
    return images_and_targets, None, images_info


def find_images_and_targets_cub200_attribute(root, dataset, istrain=False, aux_info=False):
    imageid2label = {}
    with open(os.path.join(os.path.join(root, 'CUB_200_2011'), 'image_class_labels.txt'), 'r') as f:
        for line in f:
            image_id, label = line.split()
            imageid2label[int(image_id)] = int(label)-1
    imageid2split = {}
    with open(os.path.join(os.path.join(root, 'CUB_200_2011'), 'train_test_split.txt'), 'r') as f:
        for line in f:
            image_id, split = line.split()
            imageid2split[int(image_id)] = int(split)
    images_and_targets = []
    images_info = []
    images_root = os.path.join(os.path.join(root, 'CUB_200_2011'), 'images')
    attributes_root = os.path.join(os.path.join(root, 'CUB_200_2011'), 'attributes')
    imageid2attribute = {}
    with open(os.path.join(attributes_root, 'image_attribute_labels.txt'), 'r') as f:
        for line in f:
            if len(line.split()) == 6:
                image_id, attribute_id, is_present, _, _, _ = line.split()
            else:
                image_id, attribute_id, is_present, certainty_id, time = line.split()
            if int(image_id) not in imageid2attribute:
                imageid2attribute[int(image_id)] = [0 for i in range(312)]
            imageid2attribute[int(image_id)][int(attribute_id)-1] = int(is_present)
    with open(os.path.join(os.path.join(root, 'CUB_200_2011'), 'images.txt'), 'r') as f:
        for line in f:
            image_id, file_name = line.split()
            file_path = os.path.join(images_root, file_name)
            target = imageid2label[int(image_id)]
            if aux_info:
                pass
            if istrain and imageid2split[int(image_id)] == 1:
                if aux_info:
                    images_and_targets.append([file_path, target, imageid2attribute[int(image_id)]])
                    images_info.append({'attributes': imageid2attribute[int(image_id)]})
                else:
                    images_and_targets.append([file_path, target])
            elif not istrain and imageid2split[int(image_id)] == 0:
                if aux_info:
                    images_and_targets.append([file_path, target, imageid2attribute[int(image_id)]])
                    images_info.append({'attributes': imageid2attribute[int(image_id)]})
                else:
                    images_and_targets.append([file_path, target])
    return images_and_targets, None, images_info


def find_images_and_targets_oxfordflower(root, dataset, istrain=False, aux_info=False):
    imagelabels = scio.loadmat(os.path.join(root, 'imagelabels.mat'))
    imagelabels = imagelabels['labels'][0]
    train_val_split = scio.loadmat(os.path.join(root, 'setid.mat'))
    train_data = train_val_split['trnid'][0].tolist()
    val_data = train_val_split['valid'][0].tolist()
    test_data = train_val_split['tstid'][0].tolist()
    images_and_targets = []
    images_info = []
    images_root = os.path.join(root, 'jpg')
    bert_embedding_root = os.path.join(root, 'bert_embedding_flower')
    if istrain:
        all_data = train_data + val_data
    else:
        all_data = test_data
    for data in all_data:
        file_path = os.path.join(images_root, f'image_{str(data).zfill(5)}.jpg')
        target = int(imagelabels[int(data)-1])-1
        if aux_info:
            with open(os.path.join(bert_embedding_root, f'image_{str(data).zfill(5)}.pickle'), 'rb') as f_bert:
                bert_embedding = pickle.load(f_bert)
                bert_embedding = bert_embedding['embedding_full']
            images_and_targets.append([file_path, target, bert_embedding])
        else:
            images_and_targets.append([file_path, target])
    return images_and_targets, None, images_info


def find_images_and_targets_stanforddogs(root, dataset, istrain=False, aux_info=False):
    if istrain:
        anno_data = scio.loadmat(os.path.join(root, 'train_list.mat'))
    else:
        anno_data = scio.loadmat(os.path.join(root, 'test_list.mat'))
    images_and_targets = []
    images_info = []
    for file, label in zip(anno_data['file_list'], anno_data['labels']):
        file_path = os.path.join(os.path.join(root, 'Images'), file[0][0])
        target = int(label[0])-1
        images_and_targets.append([file_path, target])
    return images_and_targets, None, images_info


def find_images_and_targets_nabirds(root, dataset, istrain=False, aux_info=False):
    root = os.path.join(root, 'nabirds')
    image_paths = pd.read_csv(os.path.join(root, 'images.txt'), sep=' ', names=['img_id', 'filepath'])
    image_class_labels = pd.read_csv(os.path.join(root, 'image_class_labels.txt'), sep=' ', names=['img_id', 'target'])
    label_list = list(set(image_class_labels['target']))
|
232 |
+
label_list = sorted(label_list)
|
233 |
+
label_map = {k: i for i, k in enumerate(label_list)}
|
234 |
+
train_test_split = pd.read_csv(os.path.join(root, 'train_test_split.txt'), sep=' ', names=['img_id', 'is_training_img'])
|
235 |
+
data = image_paths.merge(image_class_labels, on='img_id')
|
236 |
+
data = data.merge(train_test_split, on='img_id')
|
237 |
+
if istrain:
|
238 |
+
data = data[data.is_training_img == 1]
|
239 |
+
else:
|
240 |
+
data = data[data.is_training_img == 0]
|
241 |
+
images_and_targets = []
|
242 |
+
images_info = []
|
243 |
+
for index,row in data.iterrows():
|
244 |
+
file_path = os.path.join(os.path.join(root,'images'),row['filepath'])
|
245 |
+
target = int(label_map[row['target']])
|
246 |
+
images_and_targets.append([file_path,target])
|
247 |
+
return images_and_targets,None,images_info
|
248 |
+
def find_images_and_targets_stanfordcars_v1(root,dataset,istrain=False,aux_info=False):
|
249 |
+
if istrain:
|
250 |
+
flag = 'train'
|
251 |
+
else:
|
252 |
+
flag = 'test'
|
253 |
+
if istrain:
|
254 |
+
anno_data = scio.loadmat(os.path.join(os.path.join(root,'devkit'),f'cars_{flag}_annos.mat'))
|
255 |
+
else:
|
256 |
+
anno_data = scio.loadmat(os.path.join(os.path.join(root,'devkit'),f'cars_{flag}_annos_withlabels.mat'))
|
257 |
+
annotation = anno_data['annotations']
|
258 |
+
images_and_targets = []
|
259 |
+
images_info = []
|
260 |
+
for r in annotation[0]:
|
261 |
+
_,_,_,_,label,name = r
|
262 |
+
file_path = os.path.join(os.path.join(root,f'cars_{flag}'),name[0])
|
263 |
+
target = int(label[0][0])-1
|
264 |
+
images_and_targets.append([file_path,target])
|
265 |
+
return images_and_targets,None,images_info
|
266 |
+
def find_images_and_targets_stanfordcars(root,dataset,istrain=False,aux_info=False):
|
267 |
+
anno_data = scio.loadmat(os.path.join(root,'cars_annos.mat'))
|
268 |
+
annotation = anno_data['annotations']
|
269 |
+
images_and_targets = []
|
270 |
+
images_info = []
|
271 |
+
for r in annotation[0]:
|
272 |
+
name,_,_,_,_,label,split = r
|
273 |
+
file_path = os.path.join(root,name[0])
|
274 |
+
target = int(label[0][0])-1
|
275 |
+
if istrain and int(split[0][0])==0:
|
276 |
+
images_and_targets.append([file_path,target])
|
277 |
+
elif not istrain and int(split[0][0])==1:
|
278 |
+
images_and_targets.append([file_path,target])
|
279 |
+
return images_and_targets,None,images_info
|
280 |
+
def find_images_and_targets_aircraft(root,dataset,istrain=False,aux_info=False):
|
281 |
+
file_root = os.path.join(root,'fgvc-aircraft-2013b','data')
|
282 |
+
if istrain:
|
283 |
+
data_file = os.path.join(file_root,'images_variant_trainval.txt')
|
284 |
+
else:
|
285 |
+
data_file = os.path.join(file_root,'images_variant_test.txt')
|
286 |
+
classes = set()
|
287 |
+
with open(data_file,'r') as f:
|
288 |
+
for line in f:
|
289 |
+
class_name = '_'.join(line.split()[1:])
|
290 |
+
classes.add(class_name)
|
291 |
+
classes = sorted(list(classes))
|
292 |
+
class_to_idx = {name:ind for ind,name in enumerate(classes)}
|
293 |
+
|
294 |
+
images_and_targets = []
|
295 |
+
images_info = []
|
296 |
+
with open(data_file,'r') as f:
|
297 |
+
images_root = os.path.join(file_root,'images')
|
298 |
+
for line in f:
|
299 |
+
image_file = line.split()[0]
|
300 |
+
class_name = '_'.join(line.split()[1:])
|
301 |
+
file_path = os.path.join(images_root,f'{image_file}.jpg')
|
302 |
+
target = class_to_idx[class_name]
|
303 |
+
images_and_targets.append([file_path,target])
|
304 |
+
return images_and_targets,class_to_idx,images_info
|
305 |
+
|
306 |
+
def find_images_and_targets_2017_2018(root,dataset,istrain=False,aux_info=False):
|
307 |
+
train_class_info,train_id2meta,val_class_info,val_id2meta,class_to_idx,id2label = load_file(root,dataset)
|
308 |
+
miss_hour = (dataset == 'inaturelist2017')
|
309 |
+
|
310 |
+
class_info = train_class_info if istrain else val_class_info
|
311 |
+
id2meta = train_id2meta if istrain else val_id2meta
|
312 |
+
images_and_targets = []
|
313 |
+
images_info = []
|
314 |
+
if aux_info:
|
315 |
+
temporal_info = []
|
316 |
+
spatial_info = []
|
317 |
+
for image,annotation in zip(class_info['images'],class_info['annotations']):
|
318 |
+
file_path = os.path.join(root,image['file_name'])
|
319 |
+
id_name = id2label[int(annotation['category_id'])]
|
320 |
+
target = class_to_idx[id_name]
|
321 |
+
image_id = image['id']
|
322 |
+
date = id2meta[image_id]['date']
|
323 |
+
latitude = id2meta[image_id]['lat']
|
324 |
+
longitude = id2meta[image_id]['lon']
|
325 |
+
location_uncertainty = id2meta[image_id]['loc_uncert']
|
326 |
+
images_info.append({'date':date,
|
327 |
+
'latitude':latitude,
|
328 |
+
'longitude':longitude,
|
329 |
+
'location_uncertainty':location_uncertainty,
|
330 |
+
'target':target})
|
331 |
+
if aux_info:
|
332 |
+
temporal_info = get_temporal_info(date,miss_hour=miss_hour)
|
333 |
+
spatial_info = get_spatial_info(latitude,longitude)
|
334 |
+
images_and_targets.append((file_path,target,temporal_info+spatial_info))
|
335 |
+
else:
|
336 |
+
images_and_targets.append((file_path,target))
|
337 |
+
return images_and_targets,class_to_idx,images_info
|
338 |
+
def find_images_and_targets(root,istrain=False,aux_info=False):
|
339 |
+
if os.path.exists(os.path.join(root,'train.json')):
|
340 |
+
with open(os.path.join(root,'train.json'),'r') as f:
|
341 |
+
train_class_info = json.load(f)
|
342 |
+
elif os.path.exists(os.path.join(root,'train_mini.json')):
|
343 |
+
with open(os.path.join(root,'train_mini.json'),'r') as f:
|
344 |
+
train_class_info = json.load(f)
|
345 |
+
else:
|
346 |
+
raise ValueError(f'not eixst file {root}/train.json or {root}/train_mini.json')
|
347 |
+
with open(os.path.join(root,'val.json'),'r') as f:
|
348 |
+
val_class_info = json.load(f)
|
349 |
+
categories_2021 = [x['name'].strip().lower() for x in val_class_info['categories']]
|
350 |
+
class_to_idx = {c: idx for idx, c in enumerate(categories_2021)}
|
351 |
+
id2label = dict()
|
352 |
+
for categorie in train_class_info['categories']:
|
353 |
+
id2label[int(categorie['id'])] = categorie['name'].strip().lower()
|
354 |
+
class_info = train_class_info if istrain else val_class_info
|
355 |
+
|
356 |
+
images_and_targets = []
|
357 |
+
images_info = []
|
358 |
+
if aux_info:
|
359 |
+
temporal_info = []
|
360 |
+
spatial_info = []
|
361 |
+
|
362 |
+
for image,annotation in zip(class_info['images'],class_info['annotations']):
|
363 |
+
file_path = os.path.join(root,image['file_name'])
|
364 |
+
id_name = id2label[int(annotation['category_id'])]
|
365 |
+
target = class_to_idx[id_name]
|
366 |
+
date = image['date']
|
367 |
+
latitude = image['latitude']
|
368 |
+
longitude = image['longitude']
|
369 |
+
location_uncertainty = image['location_uncertainty']
|
370 |
+
images_info.append({'date':date,
|
371 |
+
'latitude':latitude,
|
372 |
+
'longitude':longitude,
|
373 |
+
'location_uncertainty':location_uncertainty,
|
374 |
+
'target':target})
|
375 |
+
if aux_info:
|
376 |
+
temporal_info = get_temporal_info(date)
|
377 |
+
spatial_info = get_spatial_info(latitude,longitude)
|
378 |
+
images_and_targets.append((file_path,target,temporal_info+spatial_info))
|
379 |
+
else:
|
380 |
+
images_and_targets.append((file_path,target))
|
381 |
+
return images_and_targets,class_to_idx,images_info
|
382 |
+
|
383 |
+
|
384 |
+
class DatasetMeta(data.Dataset):
|
385 |
+
def __init__(
|
386 |
+
self,
|
387 |
+
root,
|
388 |
+
load_bytes=False,
|
389 |
+
transform=None,
|
390 |
+
train=False,
|
391 |
+
aux_info=False,
|
392 |
+
dataset='inaturelist2021',
|
393 |
+
class_ratio=1.0,
|
394 |
+
per_sample=1.0):
|
395 |
+
self.aux_info = aux_info
|
396 |
+
self.dataset = dataset
|
397 |
+
if dataset in ['inaturelist2021','inaturelist2021_mini']:
|
398 |
+
images, class_to_idx,images_info = find_images_and_targets(root,train,aux_info)
|
399 |
+
elif dataset in ['inaturelist2017','inaturelist2018']:
|
400 |
+
images, class_to_idx,images_info = find_images_and_targets_2017_2018(root,dataset,train,aux_info)
|
401 |
+
elif dataset == 'cub-200':
|
402 |
+
images, class_to_idx,images_info = find_images_and_targets_cub200(root,dataset,train,aux_info)
|
403 |
+
elif dataset == 'stanfordcars':
|
404 |
+
images, class_to_idx,images_info = find_images_and_targets_stanfordcars(root,dataset,train)
|
405 |
+
elif dataset == 'oxfordflower':
|
406 |
+
images, class_to_idx,images_info = find_images_and_targets_oxfordflower(root,dataset,train,aux_info)
|
407 |
+
elif dataset == 'stanforddogs':
|
408 |
+
images,class_to_idx,images_info = find_images_and_targets_stanforddogs(root,dataset,train)
|
409 |
+
elif dataset == 'nabirds':
|
410 |
+
images,class_to_idx,images_info = find_images_and_targets_nabirds(root,dataset,train)
|
411 |
+
elif dataset == 'aircraft':
|
412 |
+
images,class_to_idx,images_info = find_images_and_targets_aircraft(root,dataset,train)
|
413 |
+
if len(images) == 0:
|
414 |
+
raise RuntimeError(f'Found 0 images in subfolders of {root}. '
|
415 |
+
f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
|
416 |
+
self.root = root
|
417 |
+
self.samples = images
|
418 |
+
self.imgs = self.samples # torchvision ImageFolder compat
|
419 |
+
self.class_to_idx = class_to_idx
|
420 |
+
self.images_info = images_info
|
421 |
+
self.load_bytes = load_bytes
|
422 |
+
self.transform = transform
|
423 |
+
|
424 |
+
|
425 |
+
def __getitem__(self, index):
|
426 |
+
if self.aux_info:
|
427 |
+
path, target,aux_info = self.samples[index]
|
428 |
+
else:
|
429 |
+
path, target = self.samples[index]
|
430 |
+
img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
|
431 |
+
if self.transform is not None:
|
432 |
+
img = self.transform(img)
|
433 |
+
if self.aux_info:
|
434 |
+
if type(aux_info) is np.ndarray:
|
435 |
+
select_index = np.random.randint(aux_info.shape[0])
|
436 |
+
return img, target, aux_info[select_index,:]
|
437 |
+
else:
|
438 |
+
return img, target, np.asarray(aux_info).astype(np.float64)
|
439 |
+
else:
|
440 |
+
return img, target
|
441 |
+
|
442 |
+
def __len__(self):
|
443 |
+
return len(self.samples)
|
444 |
+
if __name__ == '__main__':
|
445 |
+
# train_dataset = DatasetPre('./fgvc_previous','./fgvc_previous',train=True,aux_info=True)
|
446 |
+
# import ipdb;ipdb.set_trace()
|
447 |
+
# train_dataset = DatasetMeta('./nabirds',train=True,aux_info=False,dataset='nabirds')
|
448 |
+
# find_images_and_targets_stanforddogs('./stanforddogs',None,istrain=True)
|
449 |
+
# find_images_and_targets_oxfordflower('./oxfordflower',None,istrain=True)
|
450 |
+
find_images_and_targets_ablation('./inaturelist2021',True,True,0.5,1.0)
|
451 |
+
# find_images_and_targets_cub200('./cub-200','cub-200',True,True)
|
452 |
+
# find_images_and_targets_aircraft('./aircraft','aircraft',True)
|
453 |
+
# train_dataset = DatasetMeta('./aircraft',train=False,aux_info=False,dataset='aircraft')
|
454 |
+
import ipdb;ipdb.set_trace()
|
455 |
+
# find_images_and_targets_2017('')
|
456 |
+
|
457 |
+
|
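A quick usage sketch for DatasetMeta (assumptions: an iNaturalist-2021-style layout with train.json/val.json already sits under ./inaturelist2021, as in the __main__ block above, and the transform here is a stand-in for the real pipeline built in data/build.py):

    from torchvision import transforms
    from data.dataset_fg import DatasetMeta

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    # With aux_info=True, __getitem__ returns (img, target, meta): for iNaturalist the
    # meta vector is get_temporal_info(date) + get_spatial_info(lat, lon); for CUB it is
    # a BERT embedding, from which one sentence row is sampled at random per access.
    train_set = DatasetMeta('./inaturelist2021', transform=transform, train=True,
                            aux_info=True, dataset='inaturelist2021')
    img, target, meta = train_set[0]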
data/samplers.py
ADDED
@@ -0,0 +1,29 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import torch
+
+
+class SubsetRandomSampler(torch.utils.data.Sampler):
+    r"""Samples elements randomly from a given list of indices, without replacement.
+
+    Arguments:
+        indices (sequence): a sequence of indices
+    """
+
+    def __init__(self, indices):
+        self.epoch = 0
+        self.indices = indices
+
+    def __iter__(self):
+        return (self.indices[i] for i in torch.randperm(len(self.indices)))
+
+    def __len__(self):
+        return len(self.indices)
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
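Unlike torch.utils.data.DistributedSampler, set_epoch here only records the epoch number; the permutation in __iter__ comes from the global RNG, so it already differs across epochs. A minimal sketch of how main.py-style code drives it (the toy integer dataset is illustrative only):

    import torch
    from data.samplers import SubsetRandomSampler

    sampler = SubsetRandomSampler(list(range(8)))
    loader = torch.utils.data.DataLoader(list(range(8)), batch_size=4, sampler=sampler)
    for epoch in range(2):
        sampler.set_epoch(epoch)   # kept for API parity with DistributedSampler
        for batch in loader:
            print(epoch, batch)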
data/zipreader.py
ADDED
@@ -0,0 +1,103 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import os
+import zipfile
+import io
+import numpy as np
+from PIL import Image
+from PIL import ImageFile
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+def is_zip_path(img_or_path):
+    """judge if this is a zip path"""
+    return '.zip@' in img_or_path
+
+
+class ZipReader(object):
+    """A class to read zipped files"""
+    zip_bank = dict()
+
+    def __init__(self):
+        super(ZipReader, self).__init__()
+
+    @staticmethod
+    def get_zipfile(path):
+        zip_bank = ZipReader.zip_bank
+        if path not in zip_bank:
+            zfile = zipfile.ZipFile(path, 'r')
+            zip_bank[path] = zfile
+        return zip_bank[path]
+
+    @staticmethod
+    def split_zip_style_path(path):
+        # use find() rather than index() so the assert below can actually fire
+        pos_at = path.find('@')
+        assert pos_at != -1, "character '@' is not found from the given path '%s'" % path
+
+        zip_path = path[0: pos_at]
+        folder_path = path[pos_at + 1:]
+        folder_path = str.strip(folder_path, '/')
+        return zip_path, folder_path
+
+    @staticmethod
+    def list_folder(path):
+        zip_path, folder_path = ZipReader.split_zip_style_path(path)
+
+        zfile = ZipReader.get_zipfile(zip_path)
+        folder_list = []
+        for file_folder_name in zfile.namelist():
+            file_folder_name = str.strip(file_folder_name, '/')
+            if file_folder_name.startswith(folder_path) and \
+               len(os.path.splitext(file_folder_name)[-1]) == 0 and \
+               file_folder_name != folder_path:
+                if len(folder_path) == 0:
+                    folder_list.append(file_folder_name)
+                else:
+                    folder_list.append(file_folder_name[len(folder_path) + 1:])
+
+        return folder_list
+
+    @staticmethod
+    def list_files(path, extension=None):
+        if extension is None:
+            extension = ['.*']
+        zip_path, folder_path = ZipReader.split_zip_style_path(path)
+
+        zfile = ZipReader.get_zipfile(zip_path)
+        file_lists = []
+        for file_folder_name in zfile.namelist():
+            file_folder_name = str.strip(file_folder_name, '/')
+            if file_folder_name.startswith(folder_path) and \
+               str.lower(os.path.splitext(file_folder_name)[-1]) in extension:
+                if len(folder_path) == 0:
+                    file_lists.append(file_folder_name)
+                else:
+                    file_lists.append(file_folder_name[len(folder_path) + 1:])
+
+        return file_lists
+
+    @staticmethod
+    def read(path):
+        zip_path, path_img = ZipReader.split_zip_style_path(path)
+        zfile = ZipReader.get_zipfile(zip_path)
+        data = zfile.read(path_img)
+        return data
+
+    @staticmethod
+    def imread(path):
+        zip_path, path_img = ZipReader.split_zip_style_path(path)
+        zfile = ZipReader.get_zipfile(zip_path)
+        data = zfile.read(path_img)
+        try:
+            im = Image.open(io.BytesIO(data))
+        except Exception:
+            print("ERROR IMG LOADED: ", path_img)
+            random_img = np.random.rand(224, 224, 3) * 255
+            im = Image.fromarray(np.uint8(random_img))
+        return im
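The '@' convention splits an archive path from a member path, e.g. 'train.zip@/class_a/img_0.jpg' (the archive and member names here are made up for illustration). A short sketch under that assumption:

    from data.zipreader import ZipReader, is_zip_path

    path = 'train.zip@/class_a/img_0.jpg'
    assert is_zip_path(path)
    zip_path, member = ZipReader.split_zip_style_path(path)   # ('train.zip', 'class_a/img_0.jpg')
    img = ZipReader.imread(path)   # PIL image; decode failures fall back to a random 224x224 image

Opened archives are cached in ZipReader.zip_bank, so each zip is opened only once per process.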
figs/overview.png
ADDED
(binary file: model overview figure)
get_flops.py
ADDED
@@ -0,0 +1,62 @@
+import argparse
+import torch
+from timm.models import create_model
+from models.CoAt import *
+
+try:
+    from mmcv.cnn import get_model_complexity_info
+    from mmcv.cnn.utils.flops_counter import get_model_complexity_info, flops_to_string, params_to_string
+except ImportError:
+    raise ImportError('Please upgrade mmcv to >0.6.2')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Get FLOPS of a classification model')
+    parser.add_argument('model', help='model name (resolvable by timm create_model)')
+    parser.add_argument(
+        '--shape',
+        type=int,
+        nargs='+',
+        default=[224,],
+        help='input image size')
+    args = parser.parse_args()
+    return args
+
+def get_flops(model, input_shape):
+    flops, params = get_model_complexity_info(model, input_shape, as_strings=False)
+    return flops_to_string(flops), params_to_string(params)
+
+
+def main():
+    args = parse_args()
+
+    if len(args.shape) == 1:
+        input_shape = (3, args.shape[0], args.shape[0])
+    elif len(args.shape) == 2:
+        input_shape = (3,) + tuple(args.shape)
+    else:
+        raise ValueError('invalid input shape')
+
+    model = create_model(
+        args.model,
+        pretrained=False,
+        num_classes=1000,
+        img_size=args.shape[0],
+    )
+    model.name = args.model
+    if torch.cuda.is_available():
+        model.cuda()
+    model.eval()
+
+    flops, params = get_flops(model, input_shape)
+
+    split_line = '=' * 30
+    print(f'{split_line}\nInput shape: {input_shape}\n'
+          f'Flops: {flops}\nParams: {params}\n{split_line}')
+    print('!!!Please be cautious if you use the results in papers. '
+          'You may need to check if all ops are supported and verify that the '
+          'flops computation is correct.')
+
+
+if __name__ == '__main__':
+    main()
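The script expects a model name that timm's create_model can resolve and is run as, e.g., `python get_flops.py <model_name> --shape 224` (a single int gives a square (3, H, H) input; two ints give (3, H, W)). The core call is mmcv's complexity counter; a stand-alone sketch with a torchvision model as a stand-in:

    from mmcv.cnn import get_model_complexity_info
    from torchvision.models import resnet18   # stand-in model; the script builds its model via timm

    model = resnet18().eval()
    # as_strings=False returns raw floats, matching get_flops() above
    flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=False)
    print(f'{flops / 1e9:.2f} GFLOPs, {params / 1e6:.2f} M params')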
logger.py
ADDED
@@ -0,0 +1,47 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import os
+import sys
+import logging
+import functools
+from termcolor import colored
+
+
+@functools.lru_cache()
+def create_logger(output_dir, dist_rank=0, name='', local_rank=0):
+    # create logger
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.DEBUG)
+    logger.propagate = False
+
+    # create formatter
+    fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
+    color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
+        colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
+
+    # create console handlers for master process
+    # if dist_rank == 0:
+    #     console_handler = logging.StreamHandler(sys.stdout)
+    #     console_handler.setLevel(logging.DEBUG)
+    #     console_handler.setFormatter(
+    #         logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+    #     logger.addHandler(console_handler)
+
+    if local_rank == 0:
+        console_handler = logging.StreamHandler(sys.stdout)
+        console_handler.setLevel(logging.DEBUG)
+        console_handler.setFormatter(
+            logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+        logger.addHandler(console_handler)
+    # create file handlers
+    file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
+    file_handler.setLevel(logging.DEBUG)
+    file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+    logger.addHandler(file_handler)
+
+    return logger
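Because create_logger is wrapped in functools.lru_cache, repeated calls with the same arguments return the same logger without stacking duplicate handlers. A minimal sketch (the output directory name is arbitrary; main.py creates config.OUTPUT before logging):

    import os
    from logger import create_logger

    os.makedirs('output', exist_ok=True)
    logger = create_logger(output_dir='output', dist_rank=0, name='demo', local_rank=0)
    logger.info('echoed to stdout (local_rank 0) and appended to output/log_rank0.txt')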
lr_scheduler.py
ADDED
@@ -0,0 +1,102 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import torch
+from timm.scheduler.cosine_lr import CosineLRScheduler
+from timm.scheduler.step_lr import StepLRScheduler
+from timm.scheduler.scheduler import Scheduler
+
+
+def build_scheduler(config, optimizer, n_iter_per_epoch):
+    num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
+    warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
+    decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
+
+    lr_scheduler = None
+    if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
+        lr_scheduler = CosineLRScheduler(
+            optimizer,
+            t_initial=num_steps,
+            t_mul=1.,
+            lr_min=config.TRAIN.MIN_LR,
+            warmup_lr_init=config.TRAIN.WARMUP_LR,
+            warmup_t=warmup_steps,
+            cycle_limit=1,
+            t_in_epochs=False,
+        )
+    elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
+        lr_scheduler = LinearLRScheduler(
+            optimizer,
+            t_initial=num_steps,
+            lr_min_rate=0.01,
+            warmup_lr_init=config.TRAIN.WARMUP_LR,
+            warmup_t=warmup_steps,
+            t_in_epochs=False,
+        )
+    elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
+        lr_scheduler = StepLRScheduler(
+            optimizer,
+            decay_t=decay_steps,
+            decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
+            warmup_lr_init=config.TRAIN.WARMUP_LR,
+            warmup_t=warmup_steps,
+            t_in_epochs=False,
+        )
+
+    return lr_scheduler
+
+
+class LinearLRScheduler(Scheduler):
+    def __init__(self,
+                 optimizer: torch.optim.Optimizer,
+                 t_initial: int,
+                 lr_min_rate: float,
+                 warmup_t=0,
+                 warmup_lr_init=0.,
+                 t_in_epochs=True,
+                 noise_range_t=None,
+                 noise_pct=0.67,
+                 noise_std=1.0,
+                 noise_seed=42,
+                 initialize=True,
+                 ) -> None:
+        super().__init__(
+            optimizer, param_group_field="lr",
+            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+            initialize=initialize)
+
+        self.t_initial = t_initial
+        self.lr_min_rate = lr_min_rate
+        self.warmup_t = warmup_t
+        self.warmup_lr_init = warmup_lr_init
+        self.t_in_epochs = t_in_epochs
+        if self.warmup_t:
+            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
+            super().update_groups(self.warmup_lr_init)
+        else:
+            self.warmup_steps = [1 for _ in self.base_values]
+
+    def _get_lr(self, t):
+        if t < self.warmup_t:
+            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
+        else:
+            t = t - self.warmup_t
+            total_t = self.t_initial - self.warmup_t
+            lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
+        return lrs
+
+    def get_epoch_values(self, epoch: int):
+        if self.t_in_epochs:
+            return self._get_lr(epoch)
+        else:
+            return None
+
+    def get_update_values(self, num_updates: int):
+        if not self.t_in_epochs:
+            return self._get_lr(num_updates)
+        else:
+            return None
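After warmup, LinearLRScheduler decays each base value v linearly from v down to v * lr_min_rate over the remaining t_initial - warmup_t steps. A hedged numeric check of _get_lr, assuming warmup_lr_init = 0, v = 1e-3, lr_min_rate = 0.01, warmup_t = 100 and t_initial = 1100:

    v, lr_min_rate, warmup_t, t_initial = 1e-3, 0.01, 100, 1100
    for t in (0, 100, 600, 1100):
        if t < warmup_t:
            lr = 0.0 + t * (v - 0.0) / warmup_t            # linear warmup from 0
        else:
            frac = (t - warmup_t) / (t_initial - warmup_t)
            lr = v - (v - v * lr_min_rate) * frac          # linear decay to v * lr_min_rate
        print(t, lr)
    # 0 -> 0.0, 100 -> 1e-3, 600 -> 5.05e-4, 1100 -> 1e-5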
main.py
ADDED
@@ -0,0 +1,403 @@
+import os
+import time
+import argparse
+import datetime
+import numpy as np
+
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.utils import accuracy, AverageMeter
+
+from config import get_config
+from models import build_model
+from data import build_loader
+from lr_scheduler import build_scheduler
+from optimizer import build_optimizer
+from logger import create_logger
+from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor, load_pretained
+from torch.utils.tensorboard import SummaryWriter
+try:
+    # noinspection PyUnresolvedReferences
+    from apex import amp
+except ImportError:
+    amp = None
+
+
+def parse_option():
+    parser = argparse.ArgumentParser('MetaFG training and evaluation script', add_help=False)
+    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
+    parser.add_argument(
+        "--opts",
+        help="Modify config options by adding 'KEY VALUE' pairs. ",
+        default=None,
+        nargs='+',
+    )
+
+    # easy config modification
+    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
+    parser.add_argument('--data-path', default='./imagenet', type=str, help='path to dataset')
+    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
+    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
+                        help='no: no cache, '
+                             'full: cache all data, '
+                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
+    parser.add_argument('--resume', help='resume from checkpoint')
+    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
+    parser.add_argument('--use-checkpoint', action='store_true',
+                        help="whether to use gradient checkpointing to save memory")
+    parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
+                        help='mixed precision opt level, if O0, no amp is used')
+    parser.add_argument('--output', default='output', type=str, metavar='PATH',
+                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
+    parser.add_argument('--tag', help='tag of experiment')
+    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+    parser.add_argument('--throughput', action='store_true', help='Test throughput only')
+
+    parser.add_argument('--num-workers', type=int,
+                        help='num of workers for the dataloader')
+
+    parser.add_argument('--lr', type=float, metavar='LR',
+                        help='learning rate')
+    parser.add_argument('--weight-decay', type=float,
+                        help='weight decay (default: 0.05 for adamw)')
+
+    parser.add_argument('--min-lr', type=float,
+                        help='minimum learning rate')
+    parser.add_argument('--warmup-lr', type=float,
+                        help='warmup learning rate')
+    parser.add_argument('--epochs', type=int,
+                        help='number of training epochs')
+    parser.add_argument('--warmup-epochs', type=int,
+                        help='number of warmup epochs')
+
+    parser.add_argument('--dataset', type=str,
+                        help='dataset name')
+    parser.add_argument('--lr-scheduler-name', type=str,
+                        help='lr scheduler name: cosine, linear or step')
+
+    parser.add_argument('--pretrain', type=str,
+                        help='path to pretrained weights')
+
+    parser.add_argument('--tensorboard', action='store_true', help='using tensorboard')
+
+    # distributed training
+    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+
+    args, unparsed = parser.parse_known_args()
+
+    config = get_config(args)
+
+    return args, config
+
+
+def main(config):
+    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
+    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
+    model = build_model(config)
+    model.cuda()
+    logger.info(str(model))
+
+    optimizer = build_optimizer(config, model)
+    if config.AMP_OPT_LEVEL != "O0":
+        model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL)
+    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
+    model_without_ddp = model.module
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    logger.info(f"number of params: {n_parameters}")
+    if hasattr(model_without_ddp, 'flops'):
+        flops = model_without_ddp.flops()
+        logger.info(f"number of GFLOPs: {flops / 1e9}")
+    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
+    if config.AUG.MIXUP > 0.:
+        # smoothing is handled with mixup label transform
+        criterion = SoftTargetCrossEntropy()
+    elif config.MODEL.LABEL_SMOOTHING > 0.:
+        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
+    else:
+        criterion = torch.nn.CrossEntropyLoss()
+
+    max_accuracy = 0.0
+    if config.MODEL.PRETRAINED:
+        load_pretained(config, model_without_ddp, logger)
+        if config.EVAL_MODE:
+            acc1, acc5, loss = validate(config, data_loader_val, model)
+            logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+            return
+
+    if config.TRAIN.AUTO_RESUME:
+        resume_file = auto_resume_helper(config.OUTPUT)
+        if resume_file:
+            if config.MODEL.RESUME:
+                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
+            config.defrost()
+            config.MODEL.RESUME = resume_file
+            config.freeze()
+            logger.info(f'auto resuming from {resume_file}')
+        else:
+            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
+
+    if config.MODEL.RESUME:
+        logger.info(f"**********normal test***********")
+        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger)
+        acc1, acc5, loss = validate(config, data_loader_val, model)
+        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+        if config.DATA.ADD_META:
+            logger.info(f"**********mask meta test***********")
+            acc1, acc5, loss = validate(config, data_loader_val, model, mask_meta=True)
+            logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+        if config.EVAL_MODE:
+            return
+
+    if config.THROUGHPUT_MODE:
+        throughput(data_loader_val, model, logger)
+        return
+
+    logger.info("Start training")
+    start_time = time.time()
+    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
+        data_loader_train.sampler.set_epoch(epoch)
+        train_one_epoch_local_data(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler)
+        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
+            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger)
+
+        logger.info(f"**********normal test***********")
+        acc1, acc5, loss = validate(config, data_loader_val, model)
+        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+        max_accuracy = max(max_accuracy, acc1)
+        logger.info(f'Max accuracy: {max_accuracy:.2f}%')
+        if config.DATA.ADD_META:
+            logger.info(f"**********mask meta test***********")
+            acc1, acc5, loss = validate(config, data_loader_val, model, mask_meta=True)
+            logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+    # data_loader_train.terminate()
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    logger.info('Training time {}'.format(total_time_str))
+def train_one_epoch_local_data(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, tb_logger=None):
+    model.train()
+    if hasattr(model.module, 'cur_epoch'):
+        model.module.cur_epoch = epoch
+        model.module.total_epoch = config.TRAIN.EPOCHS
+    optimizer.zero_grad()
+
+    num_steps = len(data_loader)
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+    norm_meter = AverageMeter()
+
+    start = time.time()
+    end = time.time()
+    for idx, data in enumerate(data_loader):
+        if config.DATA.ADD_META:
+            samples, targets, meta = data
+            meta = [m.float() for m in meta]
+            meta = torch.stack(meta, dim=0)
+            meta = meta.cuda(non_blocking=True)
+        else:
+            samples, targets = data
+            meta = None
+
+        samples = samples.cuda(non_blocking=True)
+        targets = targets.cuda(non_blocking=True)
+
+        if mixup_fn is not None:
+            samples, targets = mixup_fn(samples, targets)
+        if config.DATA.ADD_META:
+            outputs = model(samples, meta)
+        else:
+            outputs = model(samples)
+
+        if config.TRAIN.ACCUMULATION_STEPS > 1:
+            loss = criterion(outputs, targets)
+            loss = loss / config.TRAIN.ACCUMULATION_STEPS
+            if config.AMP_OPT_LEVEL != "O0":
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+                if config.TRAIN.CLIP_GRAD:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
+                else:
+                    grad_norm = get_grad_norm(amp.master_params(optimizer))
+            else:
+                loss.backward()
+                if config.TRAIN.CLIP_GRAD:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
+                else:
+                    grad_norm = get_grad_norm(model.parameters())
+            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
+                optimizer.step()
+                optimizer.zero_grad()
+                lr_scheduler.step_update(epoch * num_steps + idx)
+        else:
+            loss = criterion(outputs, targets)
+            optimizer.zero_grad()
+            if config.AMP_OPT_LEVEL != "O0":
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+                if config.TRAIN.CLIP_GRAD:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
+                else:
+                    grad_norm = get_grad_norm(amp.master_params(optimizer))
+            else:
+                loss.backward()
+                if config.TRAIN.CLIP_GRAD:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
+                else:
+                    grad_norm = get_grad_norm(model.parameters())
+            optimizer.step()
+            lr_scheduler.step_update(epoch * num_steps + idx)
+
+        torch.cuda.synchronize()
+
+        loss_meter.update(loss.item(), targets.size(0))
+        norm_meter.update(grad_norm)
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.PRINT_FREQ == 0:
+            lr = optimizer.param_groups[0]['lr']
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            etas = batch_time.avg * (num_steps - idx)
+            logger.info(
+                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
+                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
+                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
+                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
+                f'mem {memory_used:.0f}MB')
+    epoch_time = time.time() - start
+    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
+@torch.no_grad()
+def validate(config, data_loader, model, mask_meta=False):
+    criterion = torch.nn.CrossEntropyLoss()
+    model.eval()
+
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+    acc1_meter = AverageMeter()
+    acc5_meter = AverageMeter()
+
+    end = time.time()
+    for idx, data in enumerate(data_loader):
+        if config.DATA.ADD_META:
+            images, target, meta = data
+            meta = [m.float() for m in meta]
+            meta = torch.stack(meta, dim=0)
+            if mask_meta:
+                meta = torch.zeros_like(meta)
+            meta = meta.cuda(non_blocking=True)
+        else:
+            images, target = data
+            meta = None
+
+        images = images.cuda(non_blocking=True)
+        target = target.cuda(non_blocking=True)
+
+        # compute output
+        if config.DATA.ADD_META:
+            output = model(images, meta)
+        else:
+            output = model(images)
+
+        # measure accuracy and record loss
+        loss = criterion(output, target)
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+        acc1 = reduce_tensor(acc1)
+        acc5 = reduce_tensor(acc5)
+        loss = reduce_tensor(loss)
+
+        loss_meter.update(loss.item(), target.size(0))
+        acc1_meter.update(acc1.item(), target.size(0))
+        acc5_meter.update(acc5.item(), target.size(0))
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.PRINT_FREQ == 0:
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            logger.info(
+                f'Test: [{idx}/{len(data_loader)}]\t'
+                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
+                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
+                f'Mem {memory_used:.0f}MB')
+    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
+    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
+
+
+@torch.no_grad()
+def throughput(data_loader, model, logger):
+    model.eval()
+
+    for idx, (images, _) in enumerate(data_loader):
+        images = images.cuda(non_blocking=True)
+        batch_size = images.shape[0]
+        for i in range(50):
+            model(images)
+        torch.cuda.synchronize()
+        logger.info("throughput averaged over 30 runs")
+        tic1 = time.time()
+        for i in range(30):
+            model(images)
+        torch.cuda.synchronize()
+        tic2 = time.time()
+        logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
+        return
+
+
+if __name__ == '__main__':
+    _, config = parse_option()
+
+    if config.AMP_OPT_LEVEL != "O0":
+        assert amp is not None, "amp not installed!"
+
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        rank = int(os.environ["RANK"])
+        world_size = int(os.environ['WORLD_SIZE'])
+        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
+    else:
+        rank = -1
+        world_size = -1
+    torch.cuda.set_device(config.LOCAL_RANK)
+    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
+    torch.distributed.barrier()
+
+    seed = config.SEED + dist.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    cudnn.benchmark = True
+
+    # linear scale the learning rate according to total batch size, may not be optimal
+    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    # gradient accumulation also needs to scale the learning rate
+    if config.TRAIN.ACCUMULATION_STEPS > 1:
+        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
+        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
+        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
+    config.defrost()
+    config.TRAIN.BASE_LR = linear_scaled_lr
+    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
+    config.TRAIN.MIN_LR = linear_scaled_min_lr
+    config.freeze()
+
+    os.makedirs(config.OUTPUT, exist_ok=True)
+    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}", local_rank=config.LOCAL_RANK)
+
+    if dist.get_rank() == 0:
+        path = os.path.join(config.OUTPUT, "config.json")
+        with open(path, "w") as f:
+            f.write(config.dump())
+        logger.info(f"Full config saved to {path}")
+
+    # print config
+    logger.info(config.dump())
+
+    main(config)
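The `--local_rank` argument is supplied by the launcher, so a typical (illustrative) invocation is `python -m torch.distributed.launch --nproc_per_node 8 main.py --cfg configs/MetaFG_0_224.yaml --batch-size 32 --data-path ./imagenet`. The linear LR scaling in `__main__` normalizes to a reference total batch of 512; for example, assuming BASE_LR = 5e-4 (a placeholder value, not shown in this commit), 32 images per GPU, 8 GPUs and no gradient accumulation:

    base_lr, batch_per_gpu, world_size, accum_steps = 5e-4, 32, 8, 1
    scaled_lr = base_lr * batch_per_gpu * world_size / 512.0 * accum_steps
    print(scaled_lr)   # 2.5e-4: the total batch of 256 is half of 512, so the LR is halved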
models/MBConv.py
ADDED
@@ -0,0 +1,169 @@
+import math
+from functools import partial
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class SwishImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, i):
+        result = i * torch.sigmoid(i)
+        ctx.save_for_backward(i)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        i = ctx.saved_tensors[0]
+        sigmoid_i = torch.sigmoid(i)
+        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+class MemoryEfficientSwish(nn.Module):
+    def forward(self, x):
+        return SwishImplementation.apply(x)
+
+
+def drop_connect(inputs, p, training):
+    """ Drop connect: randomly drop whole samples and rescale the survivors. """
+    if not training: return inputs
+    batch_size = inputs.shape[0]
+    keep_prob = 1 - p
+    random_tensor = keep_prob
+    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
+    binary_tensor = torch.floor(random_tensor)
+    output = inputs / keep_prob * binary_tensor
+    return output
+
+
+def get_same_padding_conv2d(image_size=None):
+    return partial(Conv2dStaticSamePadding, image_size=image_size)
+
+def get_width_and_height_from_size(x):
+    """ Obtains width and height from an int or tuple """
+    if isinstance(x, int): return x, x
+    if isinstance(x, list) or isinstance(x, tuple): return x
+    else: raise TypeError()
+
+def calculate_output_image_size(input_image_size, stride):
+    """
+    Calculates the output image size of a Conv2dSamePadding with a given stride.
+    """
+    if input_image_size is None: return None
+    image_height, image_width = get_width_and_height_from_size(input_image_size)
+    stride = stride if isinstance(stride, int) else stride[0]
+    image_height = int(math.ceil(image_height / stride))
+    image_width = int(math.ceil(image_width / stride))
+    return [image_height, image_width]
+
+
+class Conv2dStaticSamePadding(nn.Conv2d):
+    """ 2D convolution with TensorFlow-style 'same' padding, for a fixed image size """
+
+    def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
+        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
+        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+        # Calculate padding based on image size and save it
+        assert image_size is not None
+        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
+        kh, kw = self.weight.size()[-2:]
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
+        else:
+            self.static_padding = Identity()
+
+    def forward(self, x):
+        x = self.static_padding(x)
+        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+        return x
+
+class Identity(nn.Module):
+    def __init__(self, ):
+        super(Identity, self).__init__()
+
+    def forward(self, input):
+        return input
+
+# MBConvBlock
+class MBConvBlock(nn.Module):
+    '''
+    Mobile inverted bottleneck block, e.g. ksize 3*3, 32 input channels, 16 output channels, stride 1.
+    '''
+    def __init__(self, ksize, input_filters, output_filters, expand_ratio=1, stride=1, image_size=224, drop_connect_rate=0.):
+        super().__init__()
+        self._bn_mom = 0.1
+        self._bn_eps = 0.01
+        self._se_ratio = 0.25
+        self._input_filters = input_filters
+        self._output_filters = output_filters
+        self._expand_ratio = expand_ratio
+        self._kernel_size = ksize
+        self._stride = stride
+        self._drop_connect_rate = drop_connect_rate
+        inp = self._input_filters
+        oup = self._input_filters * self._expand_ratio
+        if self._expand_ratio != 1:
+            self._expand_conv = nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
+            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+
+        # Depthwise convolution (padding=1 assumes a 3x3 kernel)
+        k = self._kernel_size
+        s = self._stride
+        self._depthwise_conv = nn.Conv2d(in_channels=oup, out_channels=oup, groups=oup,
+                                         kernel_size=k, stride=s, padding=1, bias=False)
+
+        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+        # Squeeze and Excitation layer, if desired
+        num_squeezed_channels = max(1, int(self._input_filters * self._se_ratio))
+        self._se_reduce = nn.Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
+        self._se_expand = nn.Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
+
+        # Output phase
+        final_oup = self._output_filters
+        self._project_conv = nn.Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
+        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
+        self._swish = MemoryEfficientSwish()
+
+    def forward(self, inputs):
+        """
+        :param inputs: input tensor
+        :return: output of block
+        """
+
+        # Expansion and Depthwise Convolution
+        x = inputs
+        if self._expand_ratio != 1:
+            expand = self._expand_conv(inputs)
+            bn0 = self._bn0(expand)
+            x = self._swish(bn0)
+        depthwise = self._depthwise_conv(x)
+        bn1 = self._bn1(depthwise)
+        x = self._swish(bn1)
+        # Squeeze and Excitation
+        x_squeezed = F.adaptive_avg_pool2d(x, 1)
+        x_squeezed = self._se_reduce(x_squeezed)
+        x_squeezed = self._swish(x_squeezed)
+        x_squeezed = self._se_expand(x_squeezed)
+        x = torch.sigmoid(x_squeezed) * x
+
+        x = self._bn2(self._project_conv(x))
+
+        # Skip connection and drop connect
+        input_filters, output_filters = self._input_filters, self._output_filters
+        if self._stride == 1 and input_filters == output_filters:
+            if self._drop_connect_rate != 0:
+                x = drop_connect(x, p=self._drop_connect_rate, training=self.training)
+            x = x + inputs  # skip connection
+        return x
+if __name__ == '__main__':
+    input = torch.randn(1, 3, 112, 112)
+    mbconv = MBConvBlock(ksize=3, input_filters=3, output_filters=3, expand_ratio=4, stride=1)
+    print(mbconv)
+    out = mbconv(input)
+    print(out.shape)
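drop_connect keeps each sample with probability 1 - p (the floor of keep_prob + U[0,1) is a Bernoulli(keep_prob) draw) and rescales survivors by 1/keep_prob, so the expected activation is unchanged. A quick sanity check of that invariant:

    import torch
    from models.MBConv import drop_connect

    x = torch.ones(10000, 1, 1, 1)
    y = drop_connect(x, p=0.2, training=True)
    print(y.mean().item())                    # ~1.0: expectation is preserved
    print((y == 0).float().mean().item())     # ~0.2: fraction of dropped samples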
models/MHSA.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H=None, W=None):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x

class Relative_Attention(nn.Module):
    def __init__(self, dim, img_size, extra_token_num=1, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.extra_token_num = extra_token_num
        head_dim = dim // num_heads
        self.img_size = img_size  # h, w
        self.scale = qk_scale or head_dim ** -0.5
        # define a parameter table of relative position bias, plus one entry for the cls-token bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * img_size[0] - 1) * (2 * img_size[1] - 1) + 1, num_heads))  # (2h-1)*(2w-1)+1, nH

        # get pair-wise relative position index for each token
        coords_h = torch.arange(self.img_size[0])
        coords_w = torch.arange(self.img_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, h, w
        coords_flatten = torch.flatten(coords, 1)  # 2, h*w
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, h*w, h*w
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # h*w, h*w, 2
        relative_coords[:, :, 0] += self.img_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.img_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.img_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # h*w, h*w
        relative_position_index = F.pad(relative_position_index, (extra_token_num, 0, extra_token_num, 0))
        relative_position_index = relative_position_index.long()
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, N, C)
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.img_size[0] * self.img_size[1] + self.extra_token_num,
            self.img_size[0] * self.img_size[1] + self.extra_token_num, -1)  # h*w+extra, h*w+extra, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, h*w+extra, h*w+extra
        attn = attn + relative_position_bias.unsqueeze(0)

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

class MHSABlock(nn.Module):
    def __init__(self, input_dim, output_dim, image_size, stride, num_heads, extra_token_num=1, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        if stride != 1:
            self.patch_embed = OverlapPatchEmbed(patch_size=3, stride=stride, in_chans=input_dim, embed_dim=output_dim)
            self.img_size = image_size // 2
        else:
            self.patch_embed = None
            self.img_size = image_size
        self.img_size = to_2tuple(self.img_size)

        self.norm1 = norm_layer(output_dim)
        self.attn = Relative_Attention(
            output_dim, self.img_size, extra_token_num=extra_token_num, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(output_dim)
        mlp_hidden_dim = int(output_dim * mlp_ratio)
        self.mlp = Mlp(in_features=output_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, H, W, extra_tokens=None):
        # The extra tokens (cls/meta) are prepended only in the first block of a
        # stage, i.e. when the patch embedding downsamples the 4-D feature map;
        # later blocks receive the token sequence unchanged.
        if self.patch_embed is not None:
            x, _, _ = self.patch_embed(x)
            extra_tokens = [token.expand(x.shape[0], -1, -1) for token in extra_tokens]
            extra_tokens.append(x)
            x = torch.cat(extra_tokens, dim=1)
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x), H // 2, W // 2))
        return x
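For orientation, a minimal smoke test for MHSABlock with hypothetical shapes (this file ships no __main__ of its own). It mirrors how MetaFG.py drives the first block of an attention stage: a 4-D conv feature map goes in, a class token rides along, and the block downsamples by 2:

import torch
import torch.nn as nn

# Hypothetical shapes: a 28x28 feature map with 192 channels entering an
# attention stage that downsamples to 14x14 with 384 channels.
block = MHSABlock(input_dim=192, output_dim=384, image_size=28, stride=2,
                  num_heads=8, extra_token_num=1)
cls_token = nn.Parameter(torch.zeros(1, 1, 384))
feat = torch.randn(2, 192, 28, 28)        # B, C, H, W from the conv stages
out = block(feat, 28, 28, [cls_token])    # H, W refer to the input resolution
print(out.shape)                          # torch.Size([2, 197, 384]) = 1 cls + 14*14 tokens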
models/MetaFG.py
ADDED
@@ -0,0 +1,213 @@
import math
import torch
import torch.nn as nn

from timm.models.helpers import load_pretrained
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
import numpy as np
from .MBConv import MBConvBlock
from .MHSA import MHSABlock, Mlp

def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
        'classifier': 'head',
        **kwargs
    }

default_cfgs = {
    'MetaFG_0': _cfg(),
    'MetaFG_1': _cfg(),
    'MetaFG_2': _cfg(),
}

def make_blocks(stage_index, depths, embed_dims, img_size, dpr, extra_token_num=1, num_heads=8, mlp_ratio=4., stage_type='conv'):
    stage_name = f'stage_{stage_index}'
    blocks = []
    for block_idx in range(depths[stage_index]):
        stride = 2 if block_idx == 0 and stage_index != 1 else 1
        in_chans = embed_dims[stage_index] if block_idx != 0 else embed_dims[stage_index - 1]
        out_chans = embed_dims[stage_index]
        image_size = img_size if block_idx == 0 or stage_index == 1 else img_size // 2
        drop_path_rate = dpr[sum(depths[1:stage_index]) + block_idx]
        if stage_type == 'conv':
            blocks.append(MBConvBlock(ksize=3, input_filters=in_chans, output_filters=out_chans,
                                      image_size=image_size, expand_ratio=int(mlp_ratio), stride=stride,
                                      drop_connect_rate=drop_path_rate))
        elif stage_type == 'mhsa':
            blocks.append(MHSABlock(input_dim=in_chans, output_dim=out_chans,
                                    image_size=image_size, stride=stride, num_heads=num_heads,
                                    extra_token_num=extra_token_num, mlp_ratio=mlp_ratio,
                                    drop_path=drop_path_rate))
        else:
            raise NotImplementedError("We only support conv and mhsa")
    return blocks


class MetaFG(nn.Module):
    def __init__(self, img_size=224, in_chans=3, num_classes=1000,
                 conv_embed_dims=[64, 96, 192], attn_embed_dims=[384, 768],
                 conv_depths=[2, 2, 3], attn_depths=[5, 2], num_heads=32, extra_token_num=1, mlp_ratio=4.,
                 conv_norm_layer=nn.BatchNorm2d, attn_norm_layer=nn.LayerNorm,
                 conv_act_layer=nn.ReLU, attn_act_layer=nn.GELU,
                 qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 meta_dims=[],
                 only_last_cls=False,
                 use_checkpoint=False):
        super().__init__()
        self.only_last_cls = only_last_cls
        self.img_size = img_size
        self.num_classes = num_classes
        stem_chs = (3 * (conv_embed_dims[0] // 4), conv_embed_dims[0])
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(conv_depths[1:] + attn_depths))]
        # stage_0
        self.stage_0 = nn.Sequential(*[
            nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
            conv_norm_layer(stem_chs[0]),
            conv_act_layer(inplace=True),
            nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
            conv_norm_layer(stem_chs[1]),
            conv_act_layer(inplace=True),
            nn.Conv2d(stem_chs[1], conv_embed_dims[0], 3, stride=1, padding=1, bias=False)])
        self.bn1 = conv_norm_layer(conv_embed_dims[0])
        self.act1 = conv_act_layer(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # stage_1
        self.stage_1 = nn.ModuleList(make_blocks(1, conv_depths + attn_depths, conv_embed_dims + attn_embed_dims, img_size // 4,
                                                 dpr=dpr, num_heads=num_heads, extra_token_num=extra_token_num, mlp_ratio=mlp_ratio, stage_type='conv'))
        # stage_2
        self.stage_2 = nn.ModuleList(make_blocks(2, conv_depths + attn_depths, conv_embed_dims + attn_embed_dims, img_size // 4,
                                                 dpr=dpr, num_heads=num_heads, extra_token_num=extra_token_num, mlp_ratio=mlp_ratio, stage_type='conv'))

        # stage_3
        self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, attn_embed_dims[0]))
        self.stage_3 = nn.ModuleList(make_blocks(3, conv_depths + attn_depths, conv_embed_dims + attn_embed_dims, img_size // 8,
                                                 dpr=dpr, num_heads=num_heads, extra_token_num=extra_token_num, mlp_ratio=mlp_ratio, stage_type='mhsa'))

        # stage_4
        self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, attn_embed_dims[1]))
        self.stage_4 = nn.ModuleList(make_blocks(4, conv_depths + attn_depths, conv_embed_dims + attn_embed_dims, img_size // 16,
                                                 dpr=dpr, num_heads=num_heads, extra_token_num=extra_token_num, mlp_ratio=mlp_ratio, stage_type='mhsa'))
        self.norm_2 = attn_norm_layer(attn_embed_dims[1])
        # Aggregate
        if not self.only_last_cls:
            self.cl_1_fc = nn.Sequential(*[Mlp(in_features=attn_embed_dims[0], out_features=attn_embed_dims[1]),
                                           attn_norm_layer(attn_embed_dims[1])])
            self.aggregate = torch.nn.Conv1d(in_channels=2, out_channels=1, kernel_size=1)
            self.norm_1 = attn_norm_layer(attn_embed_dims[0])
            self.norm = attn_norm_layer(attn_embed_dims[1])

        # Classifier head
        self.head = nn.Linear(attn_embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.cls_token_1, std=.02)
        trunc_normal_(self.cls_token_2, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            # fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            # fan_out //= m.groups
            # m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            # if m.bias is not None:
            #     m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'cls_token_1', 'cls_token_2'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, meta=None):
        extra_tokens_1 = [self.cls_token_1]
        extra_tokens_2 = [self.cls_token_2]
        B = x.shape[0]
        x = self.stage_0(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)
        for blk in self.stage_1:
            x = blk(x)
        for blk in self.stage_2:
            x = blk(x)
        H0, W0 = self.img_size // 8, self.img_size // 8
        for ind, blk in enumerate(self.stage_3):
            if ind == 0:
                x = blk(x, H0, W0, extra_tokens_1)
            else:
                x = blk(x, H0, W0)
        if not self.only_last_cls:
            cls_1 = x[:, :1, :]
            cls_1 = self.norm_1(cls_1)
            cls_1 = self.cl_1_fc(cls_1)
        x = x[:, 1:, :]
        H1, W1 = self.img_size // 16, self.img_size // 16
        x = x.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
        for ind, blk in enumerate(self.stage_4):
            if ind == 0:
                x = blk(x, H1, W1, extra_tokens_2)
            else:
                x = blk(x, H1, W1)
        cls_2 = x[:, :1, :]
        cls_2 = self.norm_2(cls_2)
        if not self.only_last_cls:
            cls = torch.cat((cls_1, cls_2), dim=1)  # B, 2, C
            cls = self.aggregate(cls).squeeze(dim=1)  # B, C
            cls = self.norm(cls)
        else:
            cls = cls_2.squeeze(dim=1)
        return cls

    def forward(self, x, meta=None):
        x = self.forward_features(x, meta)
        x = self.head(x)
        return x

@register_model
def MetaFG_0(pretrained=False, **kwargs):
    model = MetaFG(conv_embed_dims=[64, 96, 192], attn_embed_dims=[384, 768],
                   conv_depths=[2, 2, 3], attn_depths=[5, 2], num_heads=8, mlp_ratio=4., **kwargs)
    model.default_cfg = default_cfgs['MetaFG_0']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model

@register_model
def MetaFG_1(pretrained=False, **kwargs):
    model = MetaFG(conv_embed_dims=[64, 96, 192], attn_embed_dims=[384, 768],
                   conv_depths=[2, 2, 6], attn_depths=[14, 2], num_heads=8, mlp_ratio=4., **kwargs)
    model.default_cfg = default_cfgs['MetaFG_1']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model

@register_model
def MetaFG_2(pretrained=False, **kwargs):
    model = MetaFG(conv_embed_dims=[128, 128, 256], attn_embed_dims=[512, 1024],
                   conv_depths=[2, 2, 6], attn_depths=[14, 2], num_heads=8, mlp_ratio=4., **kwargs)
    model.default_cfg = default_cfgs['MetaFG_2']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model

if __name__ == "__main__":
    x = torch.randn([2, 3, 224, 224])
    model = MetaFG()
    # import ipdb; ipdb.set_trace()  # debugging hook, disabled so the smoke test runs non-interactively
    output = model(x)
    print(output.shape)
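Since the variants are registered with timm via @register_model, they can also be built by name. A minimal sketch, assuming this module has been imported so timm's registry is populated (num_classes=200 is an illustrative value):

import torch
from timm.models import create_model

model = create_model('MetaFG_0', pretrained=False, num_classes=200, img_size=224)
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 200])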
models/MetaFG_meta.py
ADDED
@@ -0,0 +1,268 @@
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.helpers import load_pretrained
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
import numpy as np
from .MBConv import MBConvBlock
from .MHSA import MHSABlock, Mlp
from .meta_encoder import ResNormLayer

def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
        'classifier': 'head',
        **kwargs
    }

default_cfgs = {
    'MetaFG_0': _cfg(),
    'MetaFG_1': _cfg(),
    'MetaFG_2': _cfg(),
}

def make_blocks(stage_index, depths, embed_dims, img_size, dpr, extra_token_num=1, num_heads=8, mlp_ratio=4., stage_type='conv'):
    stage_name = f'stage_{stage_index}'
    blocks = []
    for block_idx in range(depths[stage_index]):
        stride = 2 if block_idx == 0 and stage_index != 1 else 1
        in_chans = embed_dims[stage_index] if block_idx != 0 else embed_dims[stage_index - 1]
        out_chans = embed_dims[stage_index]
        image_size = img_size if block_idx == 0 or stage_index == 1 else img_size // 2
        drop_path_rate = dpr[sum(depths[1:stage_index]) + block_idx]
        if stage_type == 'conv':
            blocks.append(MBConvBlock(ksize=3, input_filters=in_chans, output_filters=out_chans,
                                      image_size=image_size, expand_ratio=int(mlp_ratio), stride=stride,
                                      drop_connect_rate=drop_path_rate))
        elif stage_type == 'mhsa':
            blocks.append(MHSABlock(input_dim=in_chans, output_dim=out_chans,
                                    image_size=image_size, stride=stride, num_heads=num_heads,
                                    extra_token_num=extra_token_num, mlp_ratio=mlp_ratio,
                                    drop_path=drop_path_rate))
        else:
            raise NotImplementedError("We only support conv and mhsa")
    return blocks


class MetaFG_Meta(nn.Module):
    def __init__(self, img_size=224, in_chans=3, num_classes=1000,
                 conv_embed_dims=[64, 96, 192], attn_embed_dims=[384, 768],
                 conv_depths=[2, 2, 3], attn_depths=[5, 2], num_heads=32, extra_token_num=3, mlp_ratio=4.,
                 conv_norm_layer=nn.BatchNorm2d, attn_norm_layer=nn.LayerNorm,
                 conv_act_layer=nn.ReLU, attn_act_layer=nn.GELU,
                 qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 add_meta=True, meta_dims=[4, 3], mask_prob=1.0, mask_type='linear',
                 only_last_cls=False,
                 use_checkpoint=False):
        super().__init__()
        self.only_last_cls = only_last_cls
        self.img_size = img_size
        self.num_classes = num_classes
        self.add_meta = add_meta
        self.meta_dims = meta_dims
        self.cur_epoch = -1
        self.total_epoch = -1
        self.mask_prob = mask_prob
        self.mask_type = mask_type
        self.attn_embed_dims = attn_embed_dims
        self.extra_token_num = extra_token_num
        if self.add_meta:
            # assert len(meta_dims) == extra_token_num - 1
            for ind, meta_dim in enumerate(meta_dims):
                meta_head_1 = nn.Sequential(
                    nn.Linear(meta_dim, attn_embed_dims[0]),
                    nn.ReLU(inplace=True),
                    nn.LayerNorm(attn_embed_dims[0]),
                    ResNormLayer(attn_embed_dims[0]),
                ) if meta_dim > 0 else nn.Identity()
                meta_head_2 = nn.Sequential(
                    nn.Linear(meta_dim, attn_embed_dims[1]),
                    nn.ReLU(inplace=True),
                    nn.LayerNorm(attn_embed_dims[1]),
                    ResNormLayer(attn_embed_dims[1]),
                ) if meta_dim > 0 else nn.Identity()
                setattr(self, f"meta_{ind+1}_head_1", meta_head_1)
                setattr(self, f"meta_{ind+1}_head_2", meta_head_2)

        stem_chs = (3 * (conv_embed_dims[0] // 4), conv_embed_dims[0])
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(conv_depths[1:] + attn_depths))]
        # stage_0
        self.stage_0 = nn.Sequential(*[
            nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
            conv_norm_layer(stem_chs[0]),
            conv_act_layer(inplace=True),
            nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
            conv_norm_layer(stem_chs[1]),
            conv_act_layer(inplace=True),
            nn.Conv2d(stem_chs[1], conv_embed_dims[0], 3, stride=1, padding=1, bias=False)])
        self.bn1 = conv_norm_layer(conv_embed_dims[0])
        self.act1 = conv_act_layer(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # stage_1
        self.stage_1 = nn.ModuleList(make_blocks(1, conv_depths + attn_depths, conv_embed_dims + attn_embed_dims, img_size // 4,
                                                 dpr=dpr, num_heads=num_heads, extra_token_num=extra_token_num, mlp_ratio=mlp_ratio, stage_type='conv'))
        # stage_2
        self.stage_2 = nn.ModuleList(make_blocks(2, conv_depths + attn_depths, conv_embed_dims + attn_embed_dims, img_size // 4,
                                                 dpr=dpr, num_heads=num_heads, extra_token_num=extra_token_num, mlp_ratio=mlp_ratio, stage_type='conv'))

        # stage_3
        self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, attn_embed_dims[0]))
        self.stage_3 = nn.ModuleList(make_blocks(3, conv_depths + attn_depths, conv_embed_dims + attn_embed_dims, img_size // 8,
                                                 dpr=dpr, num_heads=num_heads, extra_token_num=extra_token_num, mlp_ratio=mlp_ratio, stage_type='mhsa'))
        # stage_4
        self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, attn_embed_dims[1]))
        self.stage_4 = nn.ModuleList(make_blocks(4, conv_depths + attn_depths, conv_embed_dims + attn_embed_dims, img_size // 16,
                                                 dpr=dpr, num_heads=num_heads, extra_token_num=extra_token_num, mlp_ratio=mlp_ratio, stage_type='mhsa'))
        self.norm_2 = attn_norm_layer(attn_embed_dims[1])

        # Aggregate
        if not self.only_last_cls:
            self.cl_1_fc = nn.Sequential(*[Mlp(in_features=attn_embed_dims[0], out_features=attn_embed_dims[1]),
                                           attn_norm_layer(attn_embed_dims[1])])
            self.aggregate = torch.nn.Conv1d(in_channels=2, out_channels=1, kernel_size=1)
            self.norm = attn_norm_layer(attn_embed_dims[1])
            self.norm_1 = attn_norm_layer(attn_embed_dims[0])
        # Classifier head
        self.head = nn.Linear(attn_embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.cls_token_1, std=.02)
        trunc_normal_(self.cls_token_2, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            # fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            # fan_out //= m.groups
            # m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            # if m.bias is not None:
            #     m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'cls_token_1', 'cls_token_2'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, meta=None):
        B = x.shape[0]
        extra_tokens_1 = [self.cls_token_1]
        extra_tokens_2 = [self.cls_token_2]
        if self.add_meta:
            assert meta is not None, 'meta is None'
            if len(self.meta_dims) > 1:
                metas = torch.split(meta, self.meta_dims, dim=1)
            else:
                metas = (meta,)
            for ind, cur_meta in enumerate(metas):
                meta_head_1 = getattr(self, f"meta_{ind+1}_head_1")
                meta_head_2 = getattr(self, f"meta_{ind+1}_head_2")
                meta_1 = meta_head_1(cur_meta)
                meta_1 = meta_1.reshape(B, -1, self.attn_embed_dims[0])
                meta_2 = meta_head_2(cur_meta)
                meta_2 = meta_2.reshape(B, -1, self.attn_embed_dims[1])
                extra_tokens_1.append(meta_1)
                extra_tokens_2.append(meta_2)

        x = self.stage_0(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)
        for blk in self.stage_1:
            x = blk(x)
        for blk in self.stage_2:
            x = blk(x)
        H0, W0 = self.img_size // 8, self.img_size // 8
        for ind, blk in enumerate(self.stage_3):
            if ind == 0:
                x = blk(x, H0, W0, extra_tokens_1)
            else:
                x = blk(x, H0, W0)
        if not self.only_last_cls:
            cls_1 = x[:, :1, :]
            cls_1 = self.norm_1(cls_1)
            cls_1 = self.cl_1_fc(cls_1)

        x = x[:, self.extra_token_num:, :]
        H1, W1 = self.img_size // 16, self.img_size // 16
        x = x.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
        for ind, blk in enumerate(self.stage_4):
            if ind == 0:
                x = blk(x, H1, W1, extra_tokens_2)
            else:
                x = blk(x, H1, W1)
        cls_2 = x[:, :1, :]
        cls_2 = self.norm_2(cls_2)
        if not self.only_last_cls:
            cls = torch.cat((cls_1, cls_2), dim=1)  # B, 2, C
            cls = self.aggregate(cls).squeeze(dim=1)  # B, C
            cls = self.norm(cls)
        else:
            cls = cls_2.squeeze(dim=1)
        return cls

    def forward(self, x, meta=None):
        if meta is not None:
            if self.mask_type == 'linear':
                cur_mask_prob = self.mask_prob - self.cur_epoch / self.total_epoch
            else:
                cur_mask_prob = self.mask_prob
            if cur_mask_prob != 0 and self.training:
                mask = torch.ones_like(meta)
                mask_index = torch.randperm(meta.size(0))[:int(meta.size(0) * cur_mask_prob)]
                mask[mask_index] = 0
                meta = mask * meta
        x = self.forward_features(x, meta)
        x = self.head(x)
        return x

@register_model
def MetaFG_meta_0(pretrained=False, **kwargs):
    model = MetaFG_Meta(conv_embed_dims=[64, 96, 192], attn_embed_dims=[384, 768],
                        conv_depths=[2, 2, 3], attn_depths=[5, 2], num_heads=8, mlp_ratio=4., **kwargs)
    model.default_cfg = default_cfgs['MetaFG_0']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model

@register_model
def MetaFG_meta_1(pretrained=False, **kwargs):
    model = MetaFG_Meta(conv_embed_dims=[64, 96, 192], attn_embed_dims=[384, 768],
                        conv_depths=[2, 2, 6], attn_depths=[14, 2], num_heads=8, mlp_ratio=4., **kwargs)
    model.default_cfg = default_cfgs['MetaFG_1']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model

@register_model
def MetaFG_meta_2(pretrained=False, **kwargs):
    model = MetaFG_Meta(conv_embed_dims=[128, 128, 256], attn_embed_dims=[512, 1024],
                        conv_depths=[2, 2, 6], attn_depths=[14, 2], num_heads=8, mlp_ratio=4., **kwargs)
    model.default_cfg = default_cfgs['MetaFG_2']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model

if __name__ == "__main__":
    x = torch.randn([2, 3, 224, 224])
    meta = torch.randn([2, 7])
    model = MetaFG_Meta()
    # import ipdb; ipdb.set_trace()  # debugging hook, disabled so the smoke test runs non-interactively
    output = model(x, meta)
    print(output.shape)
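A minimal sketch of the meta-data path, mirroring the __main__ above: with the default meta_dims=[4, 3] the 7-dim meta vector is split into two chunks, each projected into one extra token per attention stage. What the chunks encode is dataset-specific and not fixed by this file:

import torch

model = MetaFG_Meta(meta_dims=[4, 3], extra_token_num=3)  # 1 cls token + 2 meta tokens
x = torch.randn(2, 3, 224, 224)
meta = torch.randn(2, 7)                                  # 4 + 3 = 7 meta features
logits = model(x, meta)
print(logits.shape)  # torch.Size([2, 1000])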
models/__init__.py
ADDED
@@ -0,0 +1 @@
from .build import build_model
models/build.py
ADDED
@@ -0,0 +1,20 @@
from timm.models import create_model
from .MetaFG import *
from .MetaFG_meta import *

def build_model(config):
    model_type = config.MODEL.TYPE
    if model_type == 'MetaFG':
        model = create_model(
            config.MODEL.NAME,
            pretrained=False,
            num_classes=config.MODEL.NUM_CLASSES,
            drop_path_rate=config.MODEL.DROP_PATH_RATE,
            img_size=config.DATA.IMG_SIZE,
            only_last_cls=config.MODEL.ONLY_LAST_CLS,
            extra_token_num=config.MODEL.EXTRA_TOKEN_NUM,
            meta_dims=config.MODEL.META_DIMS
        )
    else:
        raise NotImplementedError(f"Unknown model: {model_type}")

    return model
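build_model only reads the handful of config fields shown above. A minimal stand-in for a quick sanity check, assuming the config.py from this commit exposes a yacs CfgNode (the field names below are exactly those read by build_model; the values are illustrative, and the real defaults live in config.py):

from yacs.config import CfgNode as CN

config = CN()
config.MODEL = CN()
config.MODEL.TYPE = 'MetaFG'
config.MODEL.NAME = 'MetaFG_meta_1'
config.MODEL.NUM_CLASSES = 200
config.MODEL.DROP_PATH_RATE = 0.1
config.MODEL.ONLY_LAST_CLS = False
config.MODEL.EXTRA_TOKEN_NUM = 3
config.MODEL.META_DIMS = [4, 3]
config.DATA = CN()
config.DATA.IMG_SIZE = 224

model = build_model(config)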
models/meta_encoder.py
ADDED
@@ -0,0 +1,21 @@
import torch.nn as nn

class ResNormLayer(nn.Module):
    def __init__(self, linear_size):
        super(ResNormLayer, self).__init__()
        self.l_size = linear_size
        self.nonlin1 = nn.ReLU(inplace=True)
        self.nonlin2 = nn.ReLU(inplace=True)
        self.norm_fn1 = nn.LayerNorm(self.l_size)
        self.norm_fn2 = nn.LayerNorm(self.l_size)
        self.w1 = nn.Linear(self.l_size, self.l_size)
        self.w2 = nn.Linear(self.l_size, self.l_size)

    def forward(self, x):
        y = self.w1(x)
        y = self.nonlin1(y)
        y = self.norm_fn1(y)
        y = self.w2(y)
        y = self.nonlin2(y)
        y = self.norm_fn2(y)
        out = x + y
        return out
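ResNormLayer is consumed by the meta heads in MetaFG_meta.py. A minimal sketch of that pipeline, with an illustrative 4-dim meta chunk projected to a 384-dim token:

import torch
import torch.nn as nn

meta_head = nn.Sequential(      # same structure as meta_1_head_1 in MetaFG_meta.py
    nn.Linear(4, 384),
    nn.ReLU(inplace=True),
    nn.LayerNorm(384),
    ResNormLayer(384),
)
token = meta_head(torch.randn(2, 4)).reshape(2, 1, 384)  # one extra token per sample
print(token.shape)  # torch.Size([2, 1, 384])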
optimizer.py
ADDED
@@ -0,0 +1,70 @@
from torch import optim as optim


def build_optimizer(config, model):
    """
    Build optimizer, set weight decay of normalization to 0 by default.
    """
    skip = {}
    skip_keywords = {}
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()
    parameters = set_weight_decay(model, skip, skip_keywords, config.TRAIN.BASE_LR)

    opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
    optimizer = None
    if opt_lower == 'sgd':
        optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
                              lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
                                lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)

    return optimizer

# def set_weight_decay(model, skip_list=(), skip_keywords=(), lr=0.0):
#     has_decay = []
#     no_decay = []
#     high_lr = []
#     for name, param in model.named_parameters():
#         if not param.requires_grad:
#             continue  # frozen weights
#         if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
#                 check_keywords_in_name(name, skip_keywords):
#             if 'meta' in name:
#                 high_lr.append(param)
#             else:
#                 no_decay.append(param)
#             # print(f"{name} has no weight decay")
#         else:
#             has_decay.append(param)
#     return [{'params': has_decay},
#             # {'params': high_lr, 'weight_decay': 0., 'lr': lr*10},
#             {'params': high_lr, 'lr': lr*20},
#             {'params': no_decay, 'weight_decay': 0.}]

def set_weight_decay(model, skip_list=(), skip_keywords=(), lr=0.0):
    has_decay = []
    no_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
                check_keywords_in_name(name, skip_keywords):
            no_decay.append(param)
            # print(f"{name} has no weight decay")
        else:
            has_decay.append(param)
    return [{'params': has_decay},
            {'params': no_decay, 'weight_decay': 0.}]


def check_keywords_in_name(name, keywords=()):
    isin = False
    for keyword in keywords:
        if keyword in name:
            isin = True
    return isin
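A quick demonstration of the grouping rule in set_weight_decay: 1-D parameters (norm weights and biases), names ending in ".bias", and anything in the skip list get weight_decay=0, while everything else keeps the global decay:

import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
groups = set_weight_decay(model)
# Linear.weight (2-D) is decayed; Linear.bias, LayerNorm.weight, LayerNorm.bias are not.
print(len(groups[0]['params']), len(groups[1]['params']))  # 1 3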
utils.py
ADDED
@@ -0,0 +1,173 @@
import os
import torch
import importlib
import torch.distributed as dist

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None

def relative_bias_interpolate(checkpoint, config):
    for k in list(checkpoint['model']):
        if 'relative_position_index' in k:
            del checkpoint['model'][k]
        if 'relative_position_bias_table' in k:
            relative_position_bias_table = checkpoint['model'][k]
            cls_bias = relative_position_bias_table[:1, :]
            relative_position_bias_table = relative_position_bias_table[1:, :]
            size = int(relative_position_bias_table.shape[0] ** 0.5)
            img_size = (size + 1) // 2
            if 'stage_3' in k:
                downsample_ratio = 16
            elif 'stage_4' in k:
                downsample_ratio = 32
            new_img_size = config.DATA.IMG_SIZE // downsample_ratio
            new_size = 2 * new_img_size - 1
            if new_size == size:
                continue
            relative_position_bias_table = relative_position_bias_table.reshape(size, size, -1)
            relative_position_bias_table = relative_position_bias_table.unsqueeze(0).permute(0, 3, 1, 2)  # bs, nhead, h, w
            relative_position_bias_table = torch.nn.functional.interpolate(
                relative_position_bias_table, size=(new_size, new_size), mode='bicubic', align_corners=False)
            relative_position_bias_table = relative_position_bias_table.permute(0, 2, 3, 1)
            relative_position_bias_table = relative_position_bias_table.squeeze(0).reshape(new_size * new_size, -1)
            relative_position_bias_table = torch.cat((cls_bias, relative_position_bias_table), dim=0)
            checkpoint['model'][k] = relative_position_bias_table
    return checkpoint


def load_pretained(config, model, logger=None, strict=False):
    if logger is not None:
        logger.info(f"==============> pretrain from {config.MODEL.PRETRAINED}....................")
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
    if 'model' not in checkpoint:
        if 'state_dict_ema' in checkpoint:
            checkpoint['model'] = checkpoint['state_dict_ema']
        else:
            checkpoint['model'] = checkpoint
    if config.MODEL.DORP_HEAD:
        if 'head.weight' in checkpoint['model'] and 'head.bias' in checkpoint['model']:
            if logger is not None:
                logger.info(f"==============> drop head....................")
            del checkpoint['model']['head.weight']
            del checkpoint['model']['head.bias']
        if 'head.fc.weight' in checkpoint['model'] and 'head.fc.bias' in checkpoint['model']:
            if logger is not None:
                logger.info(f"==============> drop head....................")
            del checkpoint['model']['head.fc.weight']
            del checkpoint['model']['head.fc.bias']
    if config.MODEL.DORP_META:
        if logger is not None:
            logger.info(f"==============> drop meta head....................")
        for k in list(checkpoint['model']):
            if 'meta' in k:
                del checkpoint['model'][k]

    checkpoint = relative_bias_interpolate(checkpoint, config)
    if 'point_coord' in checkpoint['model']:
        if logger is not None:
            logger.info(f"==============> drop point coord....................")
        del checkpoint['model']['point_coord']
    msg = model.load_state_dict(checkpoint['model'], strict=strict)
    del checkpoint
    torch.cuda.empty_cache()


def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
    logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    if 'model' not in checkpoint:
        if 'state_dict_ema' in checkpoint:
            checkpoint['model'] = checkpoint['state_dict_ema']
        else:
            checkpoint['model'] = checkpoint
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    max_accuracy = 0.0
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
            amp.load_state_dict(checkpoint['amp'])
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy


def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger):
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'epoch': epoch,
                  'config': config}
    if config.AMP_OPT_LEVEL != "O0":
        save_state['amp'] = amp.state_dict()

    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    logger.info(f"{save_path} saved !!!")

    latest_save_path = os.path.join(config.OUTPUT, 'latest.pth')
    logger.info(f"{latest_save_path} saving......")
    torch.save(save_state, latest_save_path)
    logger.info(f"{latest_save_path} saved !!!")


def get_grad_norm(parameters, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm


def auto_resume_helper(output_dir):
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
    print(f"All checkpoints found in {output_dir}: {checkpoints}")
    if len(checkpoints) > 0:
        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
        print(f"The latest checkpoint found: {latest_checkpoint}")
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt


def load_ext(name, funcs):
    ext = importlib.import_module(name)
    for fun in funcs:
        assert hasattr(ext, fun), f'{fun} miss in module {name}'
    return ext
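A shape check for relative_bias_interpolate with illustrative numbers: a bias table pre-trained at 224 (the stage_4 grid is 7x7, so the table has (2*7-1)**2 + 1 rows) is resized for fine-tuning at 384. The `_Cfg` class is a hypothetical stand-in for the yacs config that is passed in practice:

import torch

old_table = torch.zeros((2 * 7 - 1) ** 2 + 1, 8)  # pre-trained at 224: stage_4 grid is 7x7
ckpt = {'model': {'stage_4.0.attn.relative_position_bias_table': old_table}}

class _Cfg:               # hypothetical stand-in for the yacs config
    class DATA:
        IMG_SIZE = 384    # fine-tune resolution: stage_4 grid becomes 12x12

new_table = relative_bias_interpolate(ckpt, _Cfg)['model']['stage_4.0.attn.relative_position_bias_table']
print(new_table.shape)    # torch.Size([530, 8]) = (2*12-1)**2 + 1 rows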