diaoqishuai committed
Commit 4a3ad95 · 0 parents

first commit

.gitignore ADDED
@@ -0,0 +1,145 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+ /output
131
+ /imagenet
132
+ /imagenet_raw
133
+ /pretrained_model
134
+ /inaturelist2021
135
+ /inaturelist2021_mini
136
+ /save_model
137
+ /inaturelist2017
138
+ /inaturelist2018
139
+ /cub-200
140
+ /stanfordcars
141
+ /oxfordflower
142
+ /stanforddogs
143
+ /nabirds
144
+ /aircraft
145
+ /datasets
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 dqshuai
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,117 @@
1
+ # MetaFG
2
+ A repository for the code used to create and train the model defined in “MetaFormer: A Unified Meta Framework for Fine-Grained Recognition”
3
+
4
+ ## Model zoo
5
+ | name | resolution | 1k model | 21k model | iNat21 model |
6
+ | :--------: | :----------: | :--------: | :----------: | :------------: |
7
+ | MetaFormer-0 | 384x384 | [metafg_0_1k_384](https://drive.google.com/file/d/1r62S3CJFRWV_qA5udC9MOFOJYwRf8mE2/view?usp=sharing) | [metafg_0_21k_384](https://drive.google.com/file/d/1wVmlPjNTA6JKHcF3ROGorEVPxKVO83Ss/view?usp=sharing) | [metafg_0_inat21_384](https://drive.google.com/file/d/11gCk_IuSN7krdkOUSWSM4xlf8GGknmxc/view?usp=sharing) |
8
+ | MetaFormer-1 | 384x384 | [metafg_1_1k_384](https://drive.google.com/file/d/12OTmZg4J6fMGvs-colOTDfmhdA5EMMvo/view?usp=sharing) | [metafg_1_21k_384](https://drive.google.com/file/d/13dsarbtsNrkhpG5XpCRlN5ogXDGXO3Z_/view?usp=sharing) | [metafg_1_inat21_384](https://drive.google.com/file/d/1ATUIrDxaQaGqx4lJ8HE2IwX_evMhblPu/view?usp=sharing) |
9
+ | MetaFormer-2 | 384x384 | [metafg_2_1k_384](https://drive.google.com/file/d/167oBaseORq32aFA3Ex6lpHuasvu2PMb8/view?usp=sharing) | [metafg_2_21k_384](https://drive.google.com/file/d/1PnpntloQaYduEokFGQ6y79G7DdyjD_u3/view?usp=sharing) | [metafg_2_inat21_384](https://drive.google.com/file/d/17sUNST7ivQhonBAfZEiTOLAgtaHa4F3e/view?usp=sharing) |
10
+ ## Usage
11
+ #### Python modules
12
+ * install `PyTorch` and `torchvision`
13
+ ```
14
+ pip install torch==1.5.1 torchvision==0.6.1
15
+ ```
16
+ * install `timm`
17
+ ```
18
+ pip install timm==0.4.5
19
+ ```
20
+ * install `Apex`
21
+ ```
22
+ git clone https://github.com/NVIDIA/apex
23
+ cd apex
24
+ pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
25
+ ```
26
+ * install other requirements
27
+ ```
28
+ pip install opencv-python==4.5.1.48 yacs==0.1.8
29
+ ```
30
+ #### Data preparation
31
+ Download [inat21,18,17](https://github.com/visipedia/inat_comp), [CUB](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), [NABirds](https://dl.allaboutbirds.org/nabirds), [stanfordcars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html), and [aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/), put each one in its respective folder (\<root\>/datasets/<dataset_name>) and unzip it. The folder structure is as follows:
32
+ ```
33
+ datasets
34
+ |————inaturelist2021
35
+ | └——————train
36
+ | └——————val
37
+ | └——————train.json
38
+ | └——————val.json
39
+ |————inaturelist2018
40
+ | └——————train_val_images
41
+ | └——————train2018.json
42
+ | └——————val2018.json
43
+ | └——————train2018_locations.json
44
+ | └——————val2018_locations.json
45
+ | └——————categories.json
46
+ |————inaturelist2017
47
+ | └——————train_val_images
48
+ | └——————train2017.json
49
+ | └——————val2017.json
50
+ | └——————train2017_locations.json
51
+ | └——————val2017_locations.json
52
+ |————cub-200
53
+ | └——————...
54
+ |————nabirds
55
+ | └——————...
56
+ |————stanfordcars
57
+ | └——————car_ims
58
+ | └——————cars_annos.mat
59
+ |————aircraft
60
+ | └——————...
61
+ ```
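
As a quick sanity check before training, the layout above can be verified with a small script. This is an illustrative helper, not part of the repository; the folder names follow the tree above and the corresponding branches in `data/build.py`.

```
import os

# Hypothetical helper: check that the expected dataset folders exist under ./datasets.
# Folder names follow the layout above and the branches in data/build.py.
EXPECTED = [
    'inaturelist2021', 'inaturelist2018', 'inaturelist2017',
    'cub-200', 'nabirds', 'stanfordcars', 'aircraft',
]

def check_datasets(root='./datasets'):
    for name in EXPECTED:
        path = os.path.join(root, name)
        status = 'ok' if os.path.isdir(path) else 'missing'
        print(f'{path}: {status}')

if __name__ == '__main__':
    check_datasets()
```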
62
+ #### Training
63
+ You can download pre-trained models from the model zoo and put them under \<root\>/pretrained_model.
64
+ To train MetaFG on a dataset, run:
65
+ ```
66
+ python3 -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --cfg <config-file> --dataset <dataset-name> --pretrain <pretrained-model-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
67
+ ```
68
+ \<dataset-name\>: inaturelist2021, inaturelist2018, inaturelist2017, cub-200, nabirds, stanfordcars, aircraft
69
+ For CUB-200-2011, run:
70
+ ```
71
+ python3 -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py --cfg ./configs/MetaFG_1_224.yaml --batch-size 32 --tag cub-200_v1 --lr 5e-5 --min-lr 5e-7 --warmup-lr 5e-8 --epochs 300 --warmup-epochs 20 --dataset cub-200 --pretrain ./pretrained_model/<xxxx>.pth --accumulation-steps 2 --opts DATA.IMG_SIZE 384
72
+ ```
73
+ Note that the final learning rate is the specified learning rate scaled by total_batch_size / 512 (in the command above, 8 GPUs × batch size 32 × 2 accumulation steps = 512).
74
+ #### Eval
75
+ To evaluate a model on a dataset, run:
76
+ ```
77
+ python3 -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval --cfg <config-file> --dataset <dataset-name> --resume <checkpoint> [--batch-size <batch-size-per-gpu>]
78
+ ```
79
+ ## Main Results
80
+ #### ImageNet-1k
81
+ | Name | Resolution | #Params | #FLOPs | Throughput (images/s) | Top-1 acc (%) |
82
+ | :--------: | :----------: | :--------: | :----------: | :------------: | :------------: |
83
+ | MetaFormer-0 | 224x224 | 28M | 4.6G | 840.1 | 82.9 |
84
+ | MetaFormer-1 | 224x224 | 45M | 8.5G | 444.8 | 83.9 |
85
+ | MetaFormer-2 | 224x224 | 81M | 16.9G | 438.9 | 84.1 |
86
+ | MetaFormer-0 | 384x384 | 28M | 13.4G | 349.4 | 84.2 |
87
+ | MetaFormer-1 | 384x384 | 45M | 24.7G | 165.3 | 84.4 |
88
+ | MetaFormer-2 | 384x384 | 81M | 49.7G | 132.7 | 84.6 |
89
+ #### Fine-grained Datasets
90
+ Results on fine-grained datasets with different pre-trained models.
91
+ | Name | Pretrain | CUB | NABirds | iNat2017 | iNat2018 | Cars | Aircraft |
92
+ | :--------: | :----------: | :--------: | :----------: | :------------: | :------------: | :--------: |:--------: |
93
+ | MetaFormer-0|ImageNet-1k|89.6|89.1|75.7|79.5|95.0|91.2|
94
+ | MetaFormer-0|ImageNet-21k|89.7|89.5|75.8|79.9|94.6|91.2|
95
+ | MetaFormer-0|iNaturalist 2021|91.8|91.5|78.3|82.9|95.1|87.4|
96
+ | MetaFormer-1|ImageNet-1k|89.7|89.4|78.2|81.9|94.9|90.8|
97
+ | MetaFormer-1|ImageNet-21k|91.3|91.6|79.4|83.2|95.0|92.6|
98
+ | MetaFormer-1|iNaturalist 2021|92.3|92.7|82.0|87.5|95.0|92.5|
99
+ | MetaFormer-2|ImageNet-1k|89.7|89.7|79.0|82.6|95.0|92.4|
100
+ | MetaFormer-2|ImageNet-21k|91.8|92.2|80.4|84.3|95.1|92.9|
101
+ | MetaFormer-2|iNaturalist 2021|92.9|93.0|82.8|87.7|95.4|92.8|
102
+
103
+
104
+ Results on iNaturalist 2017, iNaturalist 2018, and iNaturalist 2021 with meta-information added.
105
+ | Name | Pretrain | Meta added | iNat2017 | iNat2018 | iNat2021 |
106
+ | :--------: | :----------: | :--------: | :---------- | :------------ |:------------ |
107
+ |MetaFormer-0|ImageNet-1k|N|75.7|79.5|88.4|
108
+ |MetaFormer-0|ImageNet-1k|Y|79.8(+4.1)|85.4(+5.9)|92.6(+4.2)|
109
+ |MetaFormer-1|ImageNet-1k|N|78.2|81.9|90.2|
110
+ |MetaFormer-1|ImageNet-1k|Y|81.3(+3.1)|86.5(+4.6)|93.4(+3.2)|
111
+ |MetaFormer-2|ImageNet-1k|N|79.0|82.6|89.8|
112
+ |MetaFormer-2|ImageNet-1k|Y|82.0(+3.0)|86.8(+4.2)|93.2(+3.4)|
113
+ |MetaFormer-2|ImageNet-21k|N|80.4|84.3|90.3|
114
+ |MetaFormer-2|ImageNet-21k|Y|83.4(+3.0)|88.7(+4.4)|93.6(+3.3)|
115
+ ## Citation
116
+ ## Acknowledgement
117
+ Many thanks to [Swin Transformer](https://github.com/microsoft/Swin-Transformer); a part of the code is borrowed from it.
config.py ADDED
@@ -0,0 +1,273 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import yaml
10
+ from yacs.config import CfgNode as CN
11
+
12
+ _C = CN()
13
+
14
+ # Base config files
15
+ _C.BASE = ['']
16
+
17
+ # -----------------------------------------------------------------------------
18
+ # Data settings
19
+ # -----------------------------------------------------------------------------
20
+ _C.DATA = CN()
21
+ # Batch size for a single GPU, could be overwritten by command line argument
22
+ _C.DATA.BATCH_SIZE = 128
23
+ # Path to dataset, could be overwritten by command line argument
24
+ _C.DATA.DATA_PATH = ''
25
+ # Dataset name
26
+ _C.DATA.DATASET = 'imagenet'
27
+ # Input image size
28
+ _C.DATA.IMG_SIZE = 224
29
+ # Interpolation to resize image (random, bilinear, bicubic)
30
+ _C.DATA.INTERPOLATION = 'bicubic'
31
+ _C.DATA.TRAIN_INTERPOLATION = 'bicubic'
32
+ # Use zipped dataset instead of folder dataset
33
+ # could be overwritten by command line argument
34
+ _C.DATA.ZIP_MODE = False
35
+ # Cache Data in Memory, could be overwritten by command line argument
36
+ _C.DATA.CACHE_MODE = 'part'
37
+ # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
38
+ _C.DATA.PIN_MEMORY = True
39
+ # Number of data loading threads
40
+ _C.DATA.NUM_WORKERS = 8
41
+ # hdfs data dir
42
+ _C.DATA.TRAIN_PATH = None
43
+ _C.DATA.VAL_PATH = None
44
+ # arnold dataset parallel
45
+ _C.DATA.NUM_READERS = 4
46
+
47
+
48
+ #meta info
49
+ _C.DATA.ADD_META = False
50
+ _C.DATA.FUSION = 'early'
51
+ _C.DATA.MASK_PROB = 0.0
52
+ _C.DATA.MASK_TYPE = 'constant'
53
+ _C.DATA.LATE_FUSION_LAYER = -1
+ # Defaults for the inaturelist2021 branch of data/build.py, which reads
+ # DATA.CLASS_RATIO and DATA.PER_SAMPLE; without these keys yacs raises AttributeError.
+ _C.DATA.CLASS_RATIO = 1.0
+ _C.DATA.PER_SAMPLE = 1.0
54
+
55
+ # -----------------------------------------------------------------------------
56
+ # Model settings
57
+ # -----------------------------------------------------------------------------
58
+ _C.MODEL = CN()
59
+ # Model type
60
+ _C.MODEL.TYPE = ''
61
+ # Model name
62
+ _C.MODEL.NAME = ''
63
+ # Checkpoint to resume, could be overwritten by command line argument
64
+ _C.MODEL.RESUME = ''
65
+ # Number of classes, overwritten in data preparation
66
+ _C.MODEL.NUM_CLASSES = 1000
67
+ # Dropout rate
68
+ _C.MODEL.DROP_RATE = 0.0
69
+ # Drop path rate
70
+ _C.MODEL.DROP_PATH_RATE = 0.1
71
+ # Label Smoothing
72
+ _C.MODEL.LABEL_SMOOTHING = 0.1
73
+ #pretrain
74
+ _C.MODEL.PRETRAINED = None
75
+ _C.MODEL.DORP_HEAD = True
76
+ _C.MODEL.DORP_META = True
77
+
78
+ _C.MODEL.ONLY_LAST_CLS = False
79
+ _C.MODEL.EXTRA_TOKEN_NUM = 1
80
+ _C.MODEL.META_DIMS = []
81
+
82
+
83
+
84
+ # -----------------------------------------------------------------------------
85
+ # Training settings
86
+ # -----------------------------------------------------------------------------
87
+ _C.TRAIN = CN()
88
+ _C.TRAIN.START_EPOCH = 0
89
+ _C.TRAIN.EPOCHS = 300
90
+ _C.TRAIN.WARMUP_EPOCHS = 20
91
+ _C.TRAIN.WEIGHT_DECAY = 0.05
92
+ _C.TRAIN.BASE_LR = 5e-4
93
+ _C.TRAIN.WARMUP_LR = 5e-7
94
+ _C.TRAIN.MIN_LR = 5e-6
95
+ # Clip gradient norm
96
+ _C.TRAIN.CLIP_GRAD = 5.0
97
+ # Auto resume from latest checkpoint
98
+ _C.TRAIN.AUTO_RESUME = True
99
+ # Gradient accumulation steps
100
+ # could be overwritten by command line argument
101
+ _C.TRAIN.ACCUMULATION_STEPS = 0
102
+ # Whether to use gradient checkpointing to save memory
103
+ # could be overwritten by command line argument
104
+ _C.TRAIN.USE_CHECKPOINT = False
105
+
106
+ # LR scheduler
107
+ _C.TRAIN.LR_SCHEDULER = CN()
108
+ _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
109
+ # Epoch interval to decay LR, used in StepLRScheduler
110
+ _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
111
+ # LR decay rate, used in StepLRScheduler
112
+ _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
113
+
114
+ # Optimizer
115
+ _C.TRAIN.OPTIMIZER = CN()
116
+ _C.TRAIN.OPTIMIZER.NAME = 'adamw'
117
+ # Optimizer Epsilon
118
+ _C.TRAIN.OPTIMIZER.EPS = 1e-8
119
+ # Optimizer Betas
120
+ _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
121
+ # SGD momentum
122
+ _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
123
+
124
+ # -----------------------------------------------------------------------------
125
+ # Augmentation settings
126
+ # -----------------------------------------------------------------------------
127
+ _C.AUG = CN()
128
+ # Color jitter factor
129
+ _C.AUG.COLOR_JITTER = 0.4
130
+ # Use AutoAugment policy. "v0" or "original"
131
+ _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
132
+ # Random erase prob
133
+ _C.AUG.REPROB = 0.25
134
+ # Random erase mode
135
+ _C.AUG.REMODE = 'pixel'
136
+ # Random erase count
137
+ _C.AUG.RECOUNT = 1
138
+ # Mixup alpha, mixup enabled if > 0
139
+ _C.AUG.MIXUP = 0.8
140
+ # Cutmix alpha, cutmix enabled if > 0
141
+ _C.AUG.CUTMIX = 1.0
142
+ # Cutmix min/max ratio, overrides alpha and enables cutmix if set
143
+ _C.AUG.CUTMIX_MINMAX = None
144
+ # Probability of performing mixup or cutmix when either/both is enabled
145
+ _C.AUG.MIXUP_PROB = 1.0
146
+ # Probability of switching to cutmix when both mixup and cutmix enabled
147
+ _C.AUG.MIXUP_SWITCH_PROB = 0.5
148
+ # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
149
+ _C.AUG.MIXUP_MODE = 'batch'
150
+
151
+ # -----------------------------------------------------------------------------
152
+ # Testing settings
153
+ # -----------------------------------------------------------------------------
154
+ _C.TEST = CN()
155
+ # Whether to use center crop when testing
156
+ _C.TEST.CROP = True
157
+
158
+ # -----------------------------------------------------------------------------
159
+ # Misc
160
+ # -----------------------------------------------------------------------------
161
+ # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
162
+ # overwritten by command line argument
163
+ _C.AMP_OPT_LEVEL = ''
164
+ # Path to output folder, overwritten by command line argument
165
+ _C.OUTPUT = ''
166
+ # Tag of experiment, overwritten by command line argument
167
+ _C.TAG = 'default'
168
+ # Frequency to save checkpoint
169
+ _C.SAVE_FREQ = 1
170
+ # Frequency to logging info
171
+ _C.PRINT_FREQ = 10
172
+ # Fixed random seed
173
+ _C.SEED = 0
174
+ # Perform evaluation only, overwritten by command line argument
175
+ _C.EVAL_MODE = False
176
+ # Test throughput only, overwritten by command line argument
177
+ _C.THROUGHPUT_MODE = False
178
+ # local rank for DistributedDataParallel, given by command line argument
179
+ _C.LOCAL_RANK = 0
180
+
181
+
182
+
183
+
184
+ def _update_config_from_file(config, cfg_file):
185
+ config.defrost()
186
+ with open(cfg_file, 'r') as f:
187
+ yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
188
+
189
+ for cfg in yaml_cfg.setdefault('BASE', ['']):
190
+ if cfg:
191
+ _update_config_from_file(
192
+ config, os.path.join(os.path.dirname(cfg_file), cfg)
193
+ )
194
+ print('=> merge config from {}'.format(cfg_file))
195
+ config.merge_from_file(cfg_file)
196
+ config.freeze()
197
+
198
+
199
+ def update_config(config, args):
200
+ _update_config_from_file(config, args.cfg)
201
+
202
+ config.defrost()
203
+ if args.opts:
204
+ config.merge_from_list(args.opts)
205
+
206
+ # merge from specific arguments
207
+ if args.batch_size:
208
+ config.DATA.BATCH_SIZE = args.batch_size
209
+ if args.data_path:
210
+ config.DATA.DATA_PATH = args.data_path
211
+ if args.zip:
212
+ config.DATA.ZIP_MODE = True
213
+ if args.cache_mode:
214
+ config.DATA.CACHE_MODE = args.cache_mode
215
+ if args.resume:
216
+ config.MODEL.RESUME = args.resume
217
+ if args.accumulation_steps:
218
+ config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
219
+ if args.use_checkpoint:
220
+ config.TRAIN.USE_CHECKPOINT = True
221
+ if args.amp_opt_level:
222
+ config.AMP_OPT_LEVEL = args.amp_opt_level
223
+ if args.output:
224
+ config.OUTPUT = args.output
225
+ if args.tag:
226
+ config.TAG = args.tag
227
+ if args.eval:
228
+ config.EVAL_MODE = True
229
+ if args.throughput:
230
+ config.THROUGHPUT_MODE = True
231
+
232
+
233
+ if args.num_workers is not None:
234
+ config.DATA.NUM_WORKERS = args.num_workers
235
+
236
+ #set lr and weight decay
237
+ if args.lr is not None:
238
+ config.TRAIN.BASE_LR = args.lr
239
+ if args.min_lr is not None:
240
+ config.TRAIN.MIN_LR = args.min_lr
241
+ if args.warmup_lr is not None:
242
+ config.TRAIN.WARMUP_LR = args.warmup_lr
243
+ if args.warmup_epochs is not None:
244
+ config.TRAIN.WARMUP_EPOCHS = args.warmup_epochs
245
+ if args.weight_decay is not None:
246
+ config.TRAIN.WEIGHT_DECAY = args.weight_decay
247
+
248
+ if args.epochs is not None:
249
+ config.TRAIN.EPOCHS = args.epochs
250
+ if args.dataset is not None:
251
+ config.DATA.DATASET = args.dataset
252
+ if args.lr_scheduler_name is not None:
253
+ config.TRAIN.LR_SCHEDULER.NAME = args.lr_scheduler_name
254
+ if args.pretrain is not None:
255
+ config.MODEL.PRETRAINED = args.pretrain
256
+
257
+ # set local rank for distributed training
258
+ config.LOCAL_RANK = args.local_rank
259
+
260
+ # output folder
261
+ config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
262
+
263
+ config.freeze()
264
+
265
+
266
+ def get_config(args):
267
+ """Get a yacs CfgNode object with default values."""
268
+ # Return a clone so that the defaults will not be altered
269
+ # This is for the "local variable" use pattern
270
+ config = _C.clone()
271
+ update_config(config, args)
272
+
273
+ return config
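
To show how the configuration pieces fit together, here is a minimal, hypothetical sketch of calling `get_config`. It assumes an argparse-style namespace providing every attribute that `update_config` reads; in the real entry point (`main.py`, not part of this commit) these come from the command-line parser. The file path under `pretrained_model/` is illustrative only.

```
from types import SimpleNamespace
from config import get_config

# Hypothetical args object; every attribute read by update_config must be present.
args = SimpleNamespace(
    cfg='configs/MetaFG_1_224.yaml', opts=None,
    batch_size=32, data_path=None, zip=False, cache_mode=None,
    resume=None, accumulation_steps=2, use_checkpoint=False,
    amp_opt_level=None, output='./output', tag='cub-200_v1',
    eval=False, throughput=False, num_workers=None,
    lr=5e-5, min_lr=5e-7, warmup_lr=5e-8, warmup_epochs=20,
    weight_decay=None, epochs=300, dataset='cub-200',
    lr_scheduler_name=None, pretrain='./pretrained_model/metafg_1_1k_384.pth',
    local_rank=0,
)

config = get_config(args)
print(config.MODEL.NAME)     # MetaFG_1 (from the YAML)
print(config.DATA.IMG_SIZE)  # 224 (from the YAML), unless overridden via --opts
print(config.OUTPUT)         # ./output/MetaFG_1/cub-200_v1
```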
configs/MetaFG_0_224.yaml ADDED
@@ -0,0 +1,5 @@
1
+ DATA:
2
+ IMG_SIZE: 224
3
+ MODEL:
4
+ TYPE: MetaFG
5
+ NAME: MetaFG_0
configs/MetaFG_1_224.yaml ADDED
@@ -0,0 +1,5 @@
1
+ DATA:
2
+ IMG_SIZE: 224
3
+ MODEL:
4
+ TYPE: MetaFG
5
+ NAME: MetaFG_1
configs/MetaFG_2_224.yaml ADDED
@@ -0,0 +1,5 @@
1
+ DATA:
2
+ IMG_SIZE: 224
3
+ MODEL:
4
+ TYPE: MetaFG
5
+ NAME: MetaFG_2
configs/MetaFG_meta_0_224.yaml ADDED
@@ -0,0 +1,8 @@
1
+ DATA:
2
+ IMG_SIZE: 224
3
+ ADD_META: True
4
+ MODEL:
5
+ TYPE: MetaFG
6
+ NAME: MetaFG_meta_0
7
+ EXTRA_TOKEN_NUM: 3
8
+ META_DIMS: [ 4, 3 ]
configs/MetaFG_meta_1_224.yaml ADDED
@@ -0,0 +1,8 @@
1
+ DATA:
2
+ IMG_SIZE: 224
3
+ ADD_META: True
4
+ MODEL:
5
+ TYPE: MetaFG
6
+ NAME: MetaFG_meta_1
7
+ EXTRA_TOKEN_NUM: 3
8
+ META_DIMS: [ 4, 3 ]
configs/MetaFG_meta_2_224.yaml ADDED
@@ -0,0 +1,8 @@
1
+ DATA:
2
+ IMG_SIZE: 224
3
+ ADD_META: True
4
+ MODEL:
5
+ TYPE: MetaFG
6
+ NAME: MetaFG_meta_2
7
+ EXTRA_TOKEN_NUM: 3
8
+ META_DIMS: [ 4, 3 ]
configs/MetaFG_meta_attribute_1_224.yaml ADDED
@@ -0,0 +1,8 @@
1
+ DATA:
2
+ IMG_SIZE: 224
3
+ ADD_META: True
4
+ MODEL:
5
+ TYPE: MetaFG
6
+ NAME: MetaFG_meta_1
7
+ EXTRA_TOKEN_NUM: 2
8
+ META_DIMS: [ 312, ]
configs/MetaFG_meta_bert_0_224.yaml ADDED
@@ -0,0 +1,8 @@
1
+ DATA:
2
+ IMG_SIZE: 224
3
+ ADD_META: True
4
+ MODEL:
5
+ TYPE: MetaFG
6
+ NAME: MetaFG_meta_0
7
+ EXTRA_TOKEN_NUM: 33
8
+ META_DIMS: [ 768, ]
configs/MetaFG_meta_bert_1_224.yaml ADDED
@@ -0,0 +1,8 @@
1
+ DATA:
2
+ IMG_SIZE: 224
3
+ ADD_META: True
4
+ MODEL:
5
+ TYPE: MetaFG
6
+ NAME: MetaFG_meta_1
7
+ EXTRA_TOKEN_NUM: 33
8
+ META_DIMS: [ 768, ]
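
A note on the `MetaFG_meta_*` configs above: `META_DIMS` lists the dimensionality of each group of auxiliary meta information, and the values line up with how `data/dataset_fg.py` (below) builds that information. `[4, 3]` is a 4-dim temporal encoding plus a 3-dim spatial encoding, `[312,]` matches the 312 CUB attribute labels, and `[768,]` matches the BERT text-embedding dimension. The sketch below (illustrative only, with a made-up date and location) reproduces the `[4, 3]` case with the helpers from `data/dataset_fg.py`.

```
from data.dataset_fg import get_temporal_info, get_spatial_info

# Date and location of a hypothetical iNaturalist sample.
date = '2021-06-15 14:30:00'
latitude, longitude = 42.36, -71.06

temporal = get_temporal_info(date)               # sin/cos of month and hour -> 4 dims
spatial = get_spatial_info(latitude, longitude)  # unit vector on the sphere -> 3 dims

meta = temporal + spatial                        # 7 values, split as META_DIMS [4, 3]
print(len(temporal), len(spatial), meta)
```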
data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .build import build_loader
data/build.py ADDED
@@ -0,0 +1,169 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import torch
10
+ import numpy as np
11
+ import torch.distributed as dist
12
+ from torchvision import datasets, transforms
13
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
14
+ from timm.data import Mixup
15
+ from timm.data import create_transform
16
+ from timm.data.transforms import _pil_interp
17
+
18
+ from .cached_image_folder import CachedImageFolder
19
+ from .samplers import SubsetRandomSampler
20
+ from .dataset_fg import DatasetMeta
21
+ def build_loader(config):
22
+ config.defrost()
23
+ dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
24
+ config.freeze()
25
+ print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
26
+ dataset_val, _ = build_dataset(is_train=False, config=config)
27
+ print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
28
+
29
+ num_tasks = dist.get_world_size()
30
+ global_rank = dist.get_rank()
31
+ if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
32
+ indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
33
+ sampler_train = SubsetRandomSampler(indices)
34
+ else:
35
+ sampler_train = torch.utils.data.DistributedSampler(
36
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
37
+ )
38
+
39
+ indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
40
+ sampler_val = SubsetRandomSampler(indices)
41
+
42
+ data_loader_train = torch.utils.data.DataLoader(
43
+ dataset_train, sampler=sampler_train,
44
+ batch_size=config.DATA.BATCH_SIZE,
45
+ num_workers=config.DATA.NUM_WORKERS,
46
+ pin_memory=config.DATA.PIN_MEMORY,
47
+ drop_last=True,
48
+ )
49
+
50
+ data_loader_val = torch.utils.data.DataLoader(
51
+ dataset_val, sampler=sampler_val,
52
+ batch_size=config.DATA.BATCH_SIZE,
53
+ shuffle=False,
54
+ num_workers=config.DATA.NUM_WORKERS,
55
+ pin_memory=config.DATA.PIN_MEMORY,
56
+ drop_last=False
57
+ )
58
+
59
+ # setup mixup / cutmix
60
+ mixup_fn = None
61
+ mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
62
+ if mixup_active:
63
+ mixup_fn = Mixup(
64
+ mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
65
+ prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
66
+ label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
67
+
68
+ return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
69
+
70
+
71
+ def build_dataset(is_train, config):
72
+ transform = build_transform(is_train, config)
73
+ if config.DATA.DATASET == 'imagenet':
74
+ prefix = 'train' if is_train else 'val'
75
+ if config.DATA.ZIP_MODE:
76
+ ann_file = prefix + "_map.txt"
77
+ prefix = prefix + ".zip@/"
78
+ dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
79
+ cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
80
+ else:
81
+ # root = os.path.join(config.DATA.DATA_PATH, prefix)
82
+ root = './datasets/imagenet'
83
+ dataset = datasets.ImageFolder(root, transform=transform)
84
+ nb_classes = 1000
85
+ elif config.DATA.DATASET == 'inaturelist2021':
86
+ root = './datasets/inaturelist2021'
87
+ dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET,
88
+ class_ratio=config.DATA.CLASS_RATIO,per_sample=config.DATA.PER_SAMPLE)
89
+ nb_classes = 10000
90
+ elif config.DATA.DATASET == 'inaturelist2021_mini':
91
+ root = './datasets/inaturelist2021_mini'
92
+ dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
93
+ nb_classes = 10000
94
+ elif config.DATA.DATASET == 'inaturelist2017':
95
+ root = './datasets/inaturelist2017'
96
+ dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
97
+ nb_classes = 5089
98
+ elif config.DATA.DATASET == 'inaturelist2018':
99
+ root = './datasets/inaturelist2018'
100
+ dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
101
+ nb_classes = 8142
102
+ elif config.DATA.DATASET == 'cub-200':
103
+ root = './datasets/cub-200'
104
+ dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
105
+ nb_classes = 200
106
+ elif config.DATA.DATASET == 'stanfordcars':
107
+ root = './datasets/stanfordcars'
108
+ dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
109
+ nb_classes = 196
110
+ elif config.DATA.DATASET == 'oxfordflower':
111
+ root = './datasets/oxfordflower'
112
+ dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
113
+ nb_classes = 102
114
+ elif config.DATA.DATASET == 'stanforddogs':
115
+ root = './datasets/stanforddogs'
116
+ dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
117
+ nb_classes = 120
118
+ elif config.DATA.DATASET == 'nabirds':
119
+ root = './datasets/nabirds'
120
+ dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
121
+ nb_classes = 555
122
+ elif config.DATA.DATASET == 'aircraft':
123
+ root = './datasets/aircraft'
124
+ dataset = DatasetMeta(root=root,transform=transform,train=is_train,aux_info=config.DATA.ADD_META,dataset=config.DATA.DATASET)
125
+ nb_classes = 100
126
+ else:
127
+ raise NotImplementedError("We only support ImageNet and inaturelist.")
128
+
129
+ return dataset, nb_classes
130
+
131
+
132
+ def build_transform(is_train, config):
133
+ resize_im = config.DATA.IMG_SIZE > 32
134
+ if is_train:
135
+ # this should always dispatch to transforms_imagenet_train
136
+ transform = create_transform(
137
+ input_size=config.DATA.IMG_SIZE,
138
+ is_training=True,
139
+ color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
140
+ auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
141
+ re_prob=config.AUG.REPROB,
142
+ re_mode=config.AUG.REMODE,
143
+ re_count=config.AUG.RECOUNT,
144
+ interpolation=config.DATA.TRAIN_INTERPOLATION,
145
+ )
146
+ if not resize_im:
147
+ # replace RandomResizedCropAndInterpolation with
148
+ # RandomCrop
149
+ transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
150
+ return transform
151
+
152
+ t = []
153
+ if resize_im:
154
+ if config.TEST.CROP:
155
+ size = int((256 / 224) * config.DATA.IMG_SIZE)
156
+ t.append(
157
+ transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
158
+ # to maintain same ratio w.r.t. 224 images
159
+ )
160
+ t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
161
+ else:
162
+ t.append(
163
+ transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
164
+ interpolation=_pil_interp(config.DATA.INTERPOLATION))
165
+ )
166
+
167
+ t.append(transforms.ToTensor())
168
+ t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
169
+ return transforms.Compose(t)
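
As a quick illustration of `build_transform`, here is a sketch that builds the evaluation pipeline from the default config in `config.py` (assumptions: importing `_C` directly and using a 384-pixel input). With `TEST.CROP = True` the image is resized to `int(256/224 * IMG_SIZE)` and center-cropped; with `TEST.CROP = False` it is resized directly to `IMG_SIZE × IMG_SIZE`.

```
# Illustrative only: build the evaluation transform from the default config.
from config import _C
from data.build import build_transform

cfg = _C.clone()
cfg.defrost()
cfg.DATA.IMG_SIZE = 384
cfg.TEST.CROP = True           # resize to int(256/224 * 384) = 438, then center-crop 384
cfg.freeze()
print(build_transform(is_train=False, config=cfg))

cfg.defrost()
cfg.TEST.CROP = False          # resize directly to 384x384
cfg.freeze()
print(build_transform(is_train=False, config=cfg))
```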
data/cached_image_folder.py ADDED
@@ -0,0 +1,251 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # --------------------------------------------------------
7
+
8
+ import io
9
+ import os
10
+ import time
11
+ import torch.distributed as dist
12
+ import torch.utils.data as data
13
+ from PIL import Image
14
+
15
+ from .zipreader import is_zip_path, ZipReader
16
+
17
+
18
+ def has_file_allowed_extension(filename, extensions):
19
+ """Checks if a file is an allowed extension.
20
+ Args:
21
+ filename (string): path to a file
22
+ Returns:
23
+ bool: True if the filename ends with a known image extension
24
+ """
25
+ filename_lower = filename.lower()
26
+ return any(filename_lower.endswith(ext) for ext in extensions)
27
+
28
+
29
+ def find_classes(dir):
30
+ classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
31
+ classes.sort()
32
+ class_to_idx = {classes[i]: i for i in range(len(classes))}
33
+ return classes, class_to_idx
34
+
35
+
36
+ def make_dataset(dir, class_to_idx, extensions):
37
+ images = []
38
+ dir = os.path.expanduser(dir)
39
+ for target in sorted(os.listdir(dir)):
40
+ d = os.path.join(dir, target)
41
+ if not os.path.isdir(d):
42
+ continue
43
+
44
+ for root, _, fnames in sorted(os.walk(d)):
45
+ for fname in sorted(fnames):
46
+ if has_file_allowed_extension(fname, extensions):
47
+ path = os.path.join(root, fname)
48
+ item = (path, class_to_idx[target])
49
+ images.append(item)
50
+
51
+ return images
52
+
53
+
54
+ def make_dataset_with_ann(ann_file, img_prefix, extensions):
55
+ images = []
56
+ with open(ann_file, "r") as f:
57
+ contents = f.readlines()
58
+ for line_str in contents:
59
+ path_contents = [c for c in line_str.split('\t')]
60
+ im_file_name = path_contents[0]
61
+ class_index = int(path_contents[1])
62
+
63
+ assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
64
+ item = (os.path.join(img_prefix, im_file_name), class_index)
65
+
66
+ images.append(item)
67
+
68
+ return images
69
+
70
+
71
+ class DatasetFolder(data.Dataset):
72
+ """A generic data loader where the samples are arranged in this way: ::
73
+ root/class_x/xxx.ext
74
+ root/class_x/xxy.ext
75
+ root/class_x/xxz.ext
76
+ root/class_y/123.ext
77
+ root/class_y/nsdf3.ext
78
+ root/class_y/asd932_.ext
79
+ Args:
80
+ root (string): Root directory path.
81
+ loader (callable): A function to load a sample given its path.
82
+ extensions (list[string]): A list of allowed extensions.
83
+ transform (callable, optional): A function/transform that takes in
84
+ a sample and returns a transformed version.
85
+ E.g, ``transforms.RandomCrop`` for images.
86
+ target_transform (callable, optional): A function/transform that takes
87
+ in the target and transforms it.
88
+ Attributes:
89
+ samples (list): List of (sample path, class_index) tuples
90
+ """
91
+
92
+ def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
93
+ cache_mode="no"):
94
+ # image folder mode
95
+ if ann_file == '':
96
+ _, class_to_idx = find_classes(root)
97
+ samples = make_dataset(root, class_to_idx, extensions)
98
+ # zip mode
99
+ else:
100
+ samples = make_dataset_with_ann(os.path.join(root, ann_file),
101
+ os.path.join(root, img_prefix),
102
+ extensions)
103
+
104
+ if len(samples) == 0:
105
+ raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
106
+ "Supported extensions are: " + ",".join(extensions)))
107
+
108
+ self.root = root
109
+ self.loader = loader
110
+ self.extensions = extensions
111
+
112
+ self.samples = samples
113
+ self.labels = [y_1k for _, y_1k in samples]
114
+ self.classes = list(set(self.labels))
115
+
116
+ self.transform = transform
117
+ self.target_transform = target_transform
118
+
119
+ self.cache_mode = cache_mode
120
+ if self.cache_mode != "no":
121
+ self.init_cache()
122
+
123
+ def init_cache(self):
124
+ assert self.cache_mode in ["part", "full"]
125
+ n_sample = len(self.samples)
126
+ global_rank = dist.get_rank()
127
+ world_size = dist.get_world_size()
128
+
129
+ samples_bytes = [None for _ in range(n_sample)]
130
+ start_time = time.time()
131
+ for index in range(n_sample):
132
+ if index % (n_sample // 10) == 0:
133
+ t = time.time() - start_time
134
+ print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
135
+ start_time = time.time()
136
+ path, target = self.samples[index]
137
+ if self.cache_mode == "full":
138
+ samples_bytes[index] = (ZipReader.read(path), target)
139
+ elif self.cache_mode == "part" and index % world_size == global_rank:
140
+ samples_bytes[index] = (ZipReader.read(path), target)
141
+ else:
142
+ samples_bytes[index] = (path, target)
143
+ self.samples = samples_bytes
144
+
145
+ def __getitem__(self, index):
146
+ """
147
+ Args:
148
+ index (int): Index
149
+ Returns:
150
+ tuple: (sample, target) where target is class_index of the target class.
151
+ """
152
+ path, target = self.samples[index]
153
+ sample = self.loader(path)
154
+ if self.transform is not None:
155
+ sample = self.transform(sample)
156
+ if self.target_transform is not None:
157
+ target = self.target_transform(target)
158
+
159
+ return sample, target
160
+
161
+ def __len__(self):
162
+ return len(self.samples)
163
+
164
+ def __repr__(self):
165
+ fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
166
+ fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
167
+ fmt_str += ' Root Location: {}\n'.format(self.root)
168
+ tmp = ' Transforms (if any): '
169
+ fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
170
+ tmp = ' Target Transforms (if any): '
171
+ fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
172
+ return fmt_str
173
+
174
+
175
+ IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
176
+
177
+
178
+ def pil_loader(path):
179
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
180
+ if isinstance(path, bytes):
181
+ img = Image.open(io.BytesIO(path))
182
+ elif is_zip_path(path):
183
+ data = ZipReader.read(path)
184
+ img = Image.open(io.BytesIO(data))
185
+ else:
186
+ with open(path, 'rb') as f:
187
+ img = Image.open(f)
188
+ return img.convert('RGB')
189
+
190
+
191
+ def accimage_loader(path):
192
+ import accimage
193
+ try:
194
+ return accimage.Image(path)
195
+ except IOError:
196
+ # Potentially a decoding problem, fall back to PIL.Image
197
+ return pil_loader(path)
198
+
199
+
200
+ def default_img_loader(path):
201
+ from torchvision import get_image_backend
202
+ if get_image_backend() == 'accimage':
203
+ return accimage_loader(path)
204
+ else:
205
+ return pil_loader(path)
206
+
207
+
208
+ class CachedImageFolder(DatasetFolder):
209
+ """A generic data loader where the images are arranged in this way: ::
210
+ root/dog/xxx.png
211
+ root/dog/xxy.png
212
+ root/dog/xxz.png
213
+ root/cat/123.png
214
+ root/cat/nsdf3.png
215
+ root/cat/asd932_.png
216
+ Args:
217
+ root (string): Root directory path.
218
+ transform (callable, optional): A function/transform that takes in an PIL image
219
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
220
+ target_transform (callable, optional): A function/transform that takes in the
221
+ target and transforms it.
222
+ loader (callable, optional): A function to load an image given its path.
223
+ Attributes:
224
+ imgs (list): List of (image path, class_index) tuples
225
+ """
226
+
227
+ def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
228
+ loader=default_img_loader, cache_mode="no"):
229
+ super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
230
+ ann_file=ann_file, img_prefix=img_prefix,
231
+ transform=transform, target_transform=target_transform,
232
+ cache_mode=cache_mode)
233
+ self.imgs = self.samples
234
+
235
+ def __getitem__(self, index):
236
+ """
237
+ Args:
238
+ index (int): Index
239
+ Returns:
240
+ tuple: (image, target) where target is class_index of the target class.
241
+ """
242
+ path, target = self.samples[index]
243
+ image = self.loader(path)
244
+ if self.transform is not None:
245
+ img = self.transform(image)
246
+ else:
247
+ img = image
248
+ if self.target_transform is not None:
249
+ target = self.target_transform(target)
250
+
251
+ return img, target
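
For reference, `ZIP_MODE` in `build_dataset` (see `data/build.py`) pairs this class with an annotation file: for the train split it expects `train_map.txt` and `train.zip` under `DATA.DATA_PATH`, and `make_dataset_with_ann` reads one tab-separated `relative_path<TAB>class_index` entry per line. A minimal sketch, with hypothetical paths and file contents:

```
from data.cached_image_folder import CachedImageFolder

# train_map.txt (one tab-separated line per image inside train.zip), for example:
#   n01440764/n01440764_10026.JPEG	0
#   n01440764/n01440764_10027.JPEG	0
dataset = CachedImageFolder(
    root='/path/to/imagenet_zipped',   # hypothetical DATA.DATA_PATH
    ann_file='train_map.txt',
    img_prefix='train.zip@/',          # images are read through ZipReader
    transform=None,
    cache_mode='no',                   # 'part'/'full' pre-cache bytes via init_cache
)
img, target = dataset[0]               # PIL image and integer class index
```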
data/dataset_fg.py ADDED
@@ -0,0 +1,457 @@
1
+ import torch.utils.data as data
2
+
3
+ import os
4
+ import re
5
+ import csv
6
+ import json
7
+ import torch
8
+ import tarfile
9
+ import pickle
10
+ import numpy as np
11
+ import pandas as pd
12
+ import random
13
+ random.seed(2021)
14
+ from PIL import Image
15
+ from scipy import io as scio
16
+ from math import radians, cos, sin, asin, sqrt, pi
17
+ IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg']
18
+ def get_spatial_info(latitude,longitude):
19
+ if latitude and longitude:
20
+ latitude = radians(latitude)
21
+ longitude = radians(longitude)
22
+ x = cos(latitude)*cos(longitude)
23
+ y = cos(latitude)*sin(longitude)
24
+ z = sin(latitude)
25
+ return [x,y,z]
26
+ else:
27
+ return [0,0,0]
28
+ def get_temporal_info(date,miss_hour=False):
29
+ try:
30
+ if date:
31
+ if miss_hour:
32
+ pattern = re.compile(r'(\d*)-(\d*)-(\d*)', re.I)
33
+ else:
34
+ pattern = re.compile(r'(\d*)-(\d*)-(\d*) (\d*):(\d*):(\d*)', re.I)
35
+ m = pattern.match(date.strip())
36
+
37
+ if m:
38
+ year = int(m.group(1))
39
+ month = int(m.group(2))
40
+ day = int(m.group(3))
41
+ x_month = sin(2*pi*month/12)
42
+ y_month = cos(2*pi*month/12)
43
+ if miss_hour:
44
+ x_hour = 0
45
+ y_hour = 0
46
+ else:
47
+ hour = int(m.group(4))
48
+ x_hour = sin(2*pi*hour/24)
49
+ y_hour = cos(2*pi*hour/24)
50
+ return [x_month,y_month,x_hour,y_hour]
51
+ else:
52
+ return [0,0,0,0]
53
+ else:
54
+ return [0,0,0,0]
55
+ except:
56
+ return [0,0,0,0]
57
+ def load_file(root,dataset):
58
+ if dataset == 'inaturelist2017':
59
+ year_flag = 7
60
+ elif dataset == 'inaturelist2018':
61
+ year_flag = 8
62
+
63
+ if dataset == 'inaturelist2018':
64
+ with open(os.path.join(root,'categories.json'),'r') as f:
65
+ map_label = json.load(f)
66
+ map_2018 = dict()
67
+ for _map in map_label:
68
+ map_2018[int(_map['id'])] = _map['name'].strip().lower()
69
+ with open(os.path.join(root,f'val201{year_flag}_locations.json'),'r') as f:
70
+ val_location = json.load(f)
71
+ val_id2meta = dict()
72
+ for meta_info in val_location:
73
+ val_id2meta[meta_info['id']] = meta_info
74
+ with open(os.path.join(root,f'train201{year_flag}_locations.json'),'r') as f:
75
+ train_location = json.load(f)
76
+ train_id2meta = dict()
77
+ for meta_info in train_location:
78
+ train_id2meta[meta_info['id']] = meta_info
79
+ with open(os.path.join(root,f'val201{year_flag}.json'),'r') as f:
80
+ val_class_info = json.load(f)
81
+ with open(os.path.join(root,f'train201{year_flag}.json'),'r') as f:
82
+ train_class_info = json.load(f)
83
+
84
+ if dataset == 'inaturelist2017':
85
+ categories_2017 = [x['name'].strip().lower() for x in val_class_info['categories']]
86
+ class_to_idx = {c: idx for idx, c in enumerate(categories_2017)}
87
+ id2label = dict()
88
+ for categorie in val_class_info['categories']:
89
+ id2label[int(categorie['id'])] = categorie['name'].strip().lower()
90
+ elif dataset == 'inaturelist2018':
91
+ categories_2018 = [x['name'].strip().lower() for x in map_label]
92
+ class_to_idx = {c: idx for idx, c in enumerate(categories_2018)}
93
+ id2label = dict()
94
+ for categorie in val_class_info['categories']:
95
+ name = map_2018[int(categorie['name'])]
96
+ id2label[int(categorie['id'])] = name.strip().lower()
97
+
98
+ return train_class_info,train_id2meta,val_class_info,val_id2meta,class_to_idx,id2label
99
+ def find_images_and_targets_cub200(root,dataset,istrain=False,aux_info=False):
100
+ imageid2label = {}
101
+ with open(os.path.join(os.path.join(root,'CUB_200_2011'),'image_class_labels.txt'),'r') as f:
102
+ for line in f:
103
+ image_id,label = line.split()
104
+ imageid2label[int(image_id)] = int(label)-1
105
+ imageid2split = {}
106
+ with open(os.path.join(os.path.join(root,'CUB_200_2011'),'train_test_split.txt'),'r') as f:
107
+ for line in f:
108
+ image_id,split = line.split()
109
+ imageid2split[int(image_id)] = int(split)
110
+ images_and_targets = []
111
+ images_info = []
112
+ images_root = os.path.join(os.path.join(root,'CUB_200_2011'),'images')
113
+ bert_embedding_root = os.path.join(root,'bert_embedding_cub')
114
+ text_root = os.path.join(root,'text_c10')
115
+ with open(os.path.join(os.path.join(root,'CUB_200_2011'),'images.txt'),'r') as f:
116
+ for line in f:
117
+ image_id,file_name = line.split()
118
+ file_path = os.path.join(images_root,file_name)
119
+ target = imageid2label[int(image_id)]
120
+ if aux_info:
121
+ with open(os.path.join(bert_embedding_root,file_name.replace('.jpg','.pickle')),'rb') as f_bert:
122
+ bert_embedding = pickle.load(f_bert)
123
+ bert_embedding = bert_embedding['embedding_words']
124
+ text_list = []
125
+ with open(os.path.join(text_root,file_name.replace('.jpg','.txt')),'r') as f_text:
126
+ for line in f_text:
127
+ line = line.encode(encoding='UTF-8',errors='strict')
128
+ line = line.replace(b'\xef\xbf\xbd\xef\xbf\xbd',b' ')
129
+ line = line.decode('UTF-8','strict')
130
+ text_list.append(line)
131
+ if istrain and imageid2split[int(image_id)]==1:
132
+ if aux_info:
133
+ images_and_targets.append([file_path,target,bert_embedding])
134
+ images_info.append({'text_list':text_list})
135
+ else:
136
+ images_and_targets.append([file_path,target])
137
+ elif not istrain and imageid2split[int(image_id)]==0:
138
+ if aux_info:
139
+ images_and_targets.append([file_path,target,bert_embedding])
140
+ images_info.append({'text_list':text_list})
141
+ else:
142
+ images_and_targets.append([file_path,target])
143
+ return images_and_targets,None,images_info
144
+ def find_images_and_targets_cub200_attribute(root,dataset,istrain=False,aux_info=False):
145
+ imageid2label = {}
146
+ with open(os.path.join(os.path.join(root,'CUB_200_2011'),'image_class_labels.txt'),'r') as f:
147
+ for line in f:
148
+ image_id,label = line.split()
149
+ imageid2label[int(image_id)] = int(label)-1
150
+ imageid2split = {}
151
+ with open(os.path.join(os.path.join(root,'CUB_200_2011'),'train_test_split.txt'),'r') as f:
152
+ for line in f:
153
+ image_id,split = line.split()
154
+ imageid2split[int(image_id)] = int(split)
155
+ images_and_targets = []
156
+ images_info = []
157
+ images_root = os.path.join(os.path.join(root,'CUB_200_2011'),'images')
158
+ attributes_root = os.path.join(os.path.join(root,'CUB_200_2011'),'attributes')
159
+ imageid2attribute = {}
160
+ with open(os.path.join(attributes_root,'image_attribute_labels.txt'),'r') as f:
161
+ for line in f:
162
+ if len(line.split())==6:
163
+ image_id,attribute_id,is_present,_,_,_ = line.split()
164
+ else:
165
+ image_id,attribute_id,is_present,certainty_id,time = line.split()
166
+ if int(image_id) not in imageid2attribute:
167
+ imageid2attribute[int(image_id)] = [0 for i in range(312)]
168
+ imageid2attribute[int(image_id)][int(attribute_id)-1] = int(is_present)
169
+ with open(os.path.join(os.path.join(root,'CUB_200_2011'),'images.txt'),'r') as f:
170
+ for line in f:
171
+ image_id,file_name = line.split()
172
+ file_path = os.path.join(images_root,file_name)
173
+ target = imageid2label[int(image_id)]
174
+ if aux_info:
175
+ pass
176
+ if istrain and imageid2split[int(image_id)]==1:
177
+ if aux_info:
178
+ images_and_targets.append([file_path,target,imageid2attribute[int(image_id)]])
179
+ images_info.append({'attributes':imageid2attribute[int(image_id)]})
180
+ else:
181
+ images_and_targets.append([file_path,target])
182
+ elif not istrain and imageid2split[int(image_id)]==0:
183
+ if aux_info:
184
+ images_and_targets.append([file_path,target,imageid2attribute[int(image_id)]])
185
+ images_info.append({'attributes':imageid2attribute[int(image_id)]})
186
+ else:
187
+ images_and_targets.append([file_path,target])
188
+ return images_and_targets,None,images_info
189
+ def find_images_and_targets_oxfordflower(root,dataset,istrain=False,aux_info=False):
190
+ imagelabels = scio.loadmat(os.path.join(root,'imagelabels.mat'))
191
+ imagelabels = imagelabels['labels'][0]
192
+ train_val_split = scio.loadmat(os.path.join(root,'setid.mat'))
193
+ train_data = train_val_split['trnid'][0].tolist()
194
+ val_data = train_val_split['valid'][0].tolist()
195
+ test_data = train_val_split['tstid'][0].tolist()
196
+ images_and_targets = []
197
+ images_info = []
198
+ images_root = os.path.join(root,'jpg')
199
+ bert_embedding_root = os.path.join(root,'bert_embedding_flower')
200
+ if istrain:
201
+ all_data = train_data+val_data
202
+ else:
203
+ all_data = test_data
204
+ for data in all_data:
205
+ file_path = os.path.join(images_root,f'image_{str(data).zfill(5)}.jpg')
206
+ target = int(imagelabels[int(data)-1])-1
207
+ if aux_info:
208
+ with open(os.path.join(bert_embedding_root,f'image_{str(data).zfill(5)}.pickle'),'rb') as f_bert:
209
+ bert_embedding = pickle.load(f_bert)
210
+ bert_embedding = bert_embedding['embedding_full']
211
+ images_and_targets.append([file_path,target,bert_embedding])
212
+ else:
213
+ images_and_targets.append([file_path,target])
214
+ return images_and_targets,None,images_info
215
+ def find_images_and_targets_stanforddogs(root,dataset,istrain=False,aux_info=False):
216
+ if istrain:
217
+ anno_data = scio.loadmat(os.path.join(root,'train_list.mat'))
218
+ else:
219
+ anno_data = scio.loadmat(os.path.join(root,'test_list.mat'))
220
+ images_and_targets = []
221
+ images_info = []
222
+ for file,label in zip(anno_data['file_list'],anno_data['labels']):
223
+ file_path = os.path.join(os.path.join(root,'Images'),file[0][0])
224
+ target = int(label[0])-1
225
+ images_and_targets.append([file_path,target])
226
+ return images_and_targets,None,images_info
227
+ def find_images_and_targets_nabirds(root,dataset,istrain=False,aux_info=False):
228
+ root = os.path.join(root,'nabirds')
229
+ image_paths = pd.read_csv(os.path.join(root,'images.txt'),sep=' ',names=['img_id','filepath'])
230
+ image_class_labels = pd.read_csv(os.path.join(root,'image_class_labels.txt'),sep=' ',names=['img_id','target'])
231
+ label_list = list(set(image_class_labels['target']))
232
+ label_list = sorted(label_list)
233
+ label_map = {k: i for i, k in enumerate(label_list)}
234
+ train_test_split = pd.read_csv(os.path.join(root, 'train_test_split.txt'), sep=' ', names=['img_id', 'is_training_img'])
235
+ data = image_paths.merge(image_class_labels, on='img_id')
236
+ data = data.merge(train_test_split, on='img_id')
237
+ if istrain:
238
+ data = data[data.is_training_img == 1]
239
+ else:
240
+ data = data[data.is_training_img == 0]
241
+ images_and_targets = []
242
+ images_info = []
243
+ for index,row in data.iterrows():
244
+ file_path = os.path.join(os.path.join(root,'images'),row['filepath'])
245
+ target = int(label_map[row['target']])
246
+ images_and_targets.append([file_path,target])
247
+ return images_and_targets,None,images_info
248
+ def find_images_and_targets_stanfordcars_v1(root,dataset,istrain=False,aux_info=False):
249
+ if istrain:
250
+ flag = 'train'
251
+ else:
252
+ flag = 'test'
253
+ if istrain:
254
+ anno_data = scio.loadmat(os.path.join(os.path.join(root,'devkit'),f'cars_{flag}_annos.mat'))
255
+ else:
256
+ anno_data = scio.loadmat(os.path.join(os.path.join(root,'devkit'),f'cars_{flag}_annos_withlabels.mat'))
257
+ annotation = anno_data['annotations']
258
+ images_and_targets = []
259
+ images_info = []
260
+ for r in annotation[0]:
261
+ _,_,_,_,label,name = r
262
+ file_path = os.path.join(os.path.join(root,f'cars_{flag}'),name[0])
263
+ target = int(label[0][0])-1
264
+ images_and_targets.append([file_path,target])
265
+ return images_and_targets,None,images_info
266
+ def find_images_and_targets_stanfordcars(root,dataset,istrain=False,aux_info=False):
267
+ anno_data = scio.loadmat(os.path.join(root,'cars_annos.mat'))
268
+ annotation = anno_data['annotations']
269
+ images_and_targets = []
270
+ images_info = []
271
+ for r in annotation[0]:
272
+ name,_,_,_,_,label,split = r
273
+ file_path = os.path.join(root,name[0])
274
+ target = int(label[0][0])-1
275
+ if istrain and int(split[0][0])==0:
276
+ images_and_targets.append([file_path,target])
277
+ elif not istrain and int(split[0][0])==1:
278
+ images_and_targets.append([file_path,target])
279
+ return images_and_targets,None,images_info
280
+ def find_images_and_targets_aircraft(root,dataset,istrain=False,aux_info=False):
281
+ file_root = os.path.join(root,'fgvc-aircraft-2013b','data')
282
+ if istrain:
283
+ data_file = os.path.join(file_root,'images_variant_trainval.txt')
284
+ else:
285
+ data_file = os.path.join(file_root,'images_variant_test.txt')
286
+ classes = set()
287
+ with open(data_file,'r') as f:
288
+ for line in f:
289
+ class_name = '_'.join(line.split()[1:])
290
+ classes.add(class_name)
291
+ classes = sorted(list(classes))
292
+ class_to_idx = {name:ind for ind,name in enumerate(classes)}
293
+
294
+ images_and_targets = []
295
+ images_info = []
296
+ with open(data_file,'r') as f:
297
+ images_root = os.path.join(file_root,'images')
298
+ for line in f:
299
+ image_file = line.split()[0]
300
+ class_name = '_'.join(line.split()[1:])
301
+ file_path = os.path.join(images_root,f'{image_file}.jpg')
302
+ target = class_to_idx[class_name]
303
+ images_and_targets.append([file_path,target])
304
+ return images_and_targets,class_to_idx,images_info
305
+
306
+ def find_images_and_targets_2017_2018(root,dataset,istrain=False,aux_info=False):
307
+ train_class_info,train_id2meta,val_class_info,val_id2meta,class_to_idx,id2label = load_file(root,dataset)
308
+ miss_hour = (dataset == 'inaturelist2017')
309
+
310
+ class_info = train_class_info if istrain else val_class_info
311
+ id2meta = train_id2meta if istrain else val_id2meta
312
+ images_and_targets = []
313
+ images_info = []
314
+ if aux_info:
315
+ temporal_info = []
316
+ spatial_info = []
317
+ for image,annotation in zip(class_info['images'],class_info['annotations']):
318
+ file_path = os.path.join(root,image['file_name'])
319
+ id_name = id2label[int(annotation['category_id'])]
320
+ target = class_to_idx[id_name]
321
+ image_id = image['id']
322
+ date = id2meta[image_id]['date']
323
+ latitude = id2meta[image_id]['lat']
324
+ longitude = id2meta[image_id]['lon']
325
+ location_uncertainty = id2meta[image_id]['loc_uncert']
326
+ images_info.append({'date':date,
327
+ 'latitude':latitude,
328
+ 'longitude':longitude,
329
+ 'location_uncertainty':location_uncertainty,
330
+ 'target':target})
331
+ if aux_info:
332
+ temporal_info = get_temporal_info(date,miss_hour=miss_hour)
333
+ spatial_info = get_spatial_info(latitude,longitude)
334
+ images_and_targets.append((file_path,target,temporal_info+spatial_info))
335
+ else:
336
+ images_and_targets.append((file_path,target))
337
+ return images_and_targets,class_to_idx,images_info
338
+ def find_images_and_targets(root,istrain=False,aux_info=False):
339
+ if os.path.exists(os.path.join(root,'train.json')):
340
+ with open(os.path.join(root,'train.json'),'r') as f:
341
+ train_class_info = json.load(f)
342
+ elif os.path.exists(os.path.join(root,'train_mini.json')):
343
+ with open(os.path.join(root,'train_mini.json'),'r') as f:
344
+ train_class_info = json.load(f)
345
+ else:
346
+ raise ValueError(f'no such file: {root}/train.json or {root}/train_mini.json')
347
+ with open(os.path.join(root,'val.json'),'r') as f:
348
+ val_class_info = json.load(f)
349
+ categories_2021 = [x['name'].strip().lower() for x in val_class_info['categories']]
350
+ class_to_idx = {c: idx for idx, c in enumerate(categories_2021)}
351
+ id2label = dict()
352
+ for categorie in train_class_info['categories']:
353
+ id2label[int(categorie['id'])] = categorie['name'].strip().lower()
354
+ class_info = train_class_info if istrain else val_class_info
355
+
356
+ images_and_targets = []
357
+ images_info = []
358
+ if aux_info:
359
+ temporal_info = []
360
+ spatial_info = []
361
+
362
+ for image,annotation in zip(class_info['images'],class_info['annotations']):
363
+ file_path = os.path.join(root,image['file_name'])
364
+ id_name = id2label[int(annotation['category_id'])]
365
+ target = class_to_idx[id_name]
366
+ date = image['date']
367
+ latitude = image['latitude']
368
+ longitude = image['longitude']
369
+ location_uncertainty = image['location_uncertainty']
370
+ images_info.append({'date':date,
371
+ 'latitude':latitude,
372
+ 'longitude':longitude,
373
+ 'location_uncertainty':location_uncertainty,
374
+ 'target':target})
375
+ if aux_info:
376
+ temporal_info = get_temporal_info(date)
377
+ spatial_info = get_spatial_info(latitude,longitude)
378
+ images_and_targets.append((file_path,target,temporal_info+spatial_info))
379
+ else:
380
+ images_and_targets.append((file_path,target))
381
+ return images_and_targets,class_to_idx,images_info
382
+
383
+
384
+ class DatasetMeta(data.Dataset):
385
+ def __init__(
386
+ self,
387
+ root,
388
+ load_bytes=False,
389
+ transform=None,
390
+ train=False,
391
+ aux_info=False,
392
+ dataset='inaturelist2021',
393
+ class_ratio=1.0,
394
+ per_sample=1.0):
395
+ self.aux_info = aux_info
396
+ self.dataset = dataset
397
+ if dataset in ['inaturelist2021','inaturelist2021_mini']:
398
+ images, class_to_idx,images_info = find_images_and_targets(root,train,aux_info)
399
+ elif dataset in ['inaturelist2017','inaturelist2018']:
400
+ images, class_to_idx,images_info = find_images_and_targets_2017_2018(root,dataset,train,aux_info)
401
+ elif dataset == 'cub-200':
402
+ images, class_to_idx,images_info = find_images_and_targets_cub200(root,dataset,train,aux_info)
403
+ elif dataset == 'stanfordcars':
404
+ images, class_to_idx,images_info = find_images_and_targets_stanfordcars(root,dataset,train)
405
+ elif dataset == 'oxfordflower':
406
+ images, class_to_idx,images_info = find_images_and_targets_oxfordflower(root,dataset,train,aux_info)
407
+ elif dataset == 'stanforddogs':
408
+ images,class_to_idx,images_info = find_images_and_targets_stanforddogs(root,dataset,train)
409
+ elif dataset == 'nabirds':
410
+ images,class_to_idx,images_info = find_images_and_targets_nabirds(root,dataset,train)
411
+ elif dataset == 'aircraft':
412
+ images,class_to_idx,images_info = find_images_and_targets_aircraft(root,dataset,train)
413
+ if len(images) == 0:
414
+ raise RuntimeError(f'Found 0 images in subfolders of {root}. '
415
+ f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
416
+ self.root = root
417
+ self.samples = images
418
+ self.imgs = self.samples # torchvision ImageFolder compat
419
+ self.class_to_idx = class_to_idx
420
+ self.images_info = images_info
421
+ self.load_bytes = load_bytes
422
+ self.transform = transform
423
+
424
+
425
+ def __getitem__(self, index):
426
+ if self.aux_info:
427
+ path, target,aux_info = self.samples[index]
428
+ else:
429
+ path, target = self.samples[index]
430
+ img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
431
+ if self.transform is not None:
432
+ img = self.transform(img)
433
+ if self.aux_info:
434
+ if type(aux_info) is np.ndarray:
435
+ select_index = np.random.randint(aux_info.shape[0])
436
+ return img, target, aux_info[select_index,:]
437
+ else:
438
+ return img, target, np.asarray(aux_info).astype(np.float64)
439
+ else:
440
+ return img, target
441
+
442
+ def __len__(self):
443
+ return len(self.samples)
444
+ if __name__ == '__main__':
445
+ # train_dataset = DatasetPre('./fgvc_previous','./fgvc_previous',train=True,aux_info=True)
446
+ # import ipdb;ipdb.set_trace()
447
+ # train_dataset = DatasetMeta('./nabirds',train=True,aux_info=False,dataset='nabirds')
448
+ # find_images_and_targets_stanforddogs('./stanforddogs',None,istrain=True)
449
+ # find_images_and_targets_oxfordflower('./oxfordflower',None,istrain=True)
450
+ find_images_and_targets_ablation('./inaturelist2021',True,True,0.5,1.0)
451
+ # find_images_and_targets_cub200('./cub-200','cub-200',True,True)
452
+ # find_images_and_targets_aircraft('./aircraft','aircraft',True)
453
+ # train_dataset = DatasetMeta('./aircraft',train=False,aux_info=False,dataset='aircraft')
454
+ import ipdb;ipdb.set_trace()
455
+ # find_images_and_targets_2017('')
456
+
457
+
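A minimal usage sketch for the DatasetMeta class defined above (illustrative only; the ./cub-200 path, transform and batch size are assumptions, not part of the commit):

# usage sketch (illustrative, not part of the commit)
from torchvision import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = DatasetMeta('./cub-200', transform=transform, train=True, dataset='cub-200')
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
images, targets = next(iter(loader))  # with aux_info=True each sample also carries a meta vector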
data/samplers.py ADDED
@@ -0,0 +1,29 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+
10
+
11
+ class SubsetRandomSampler(torch.utils.data.Sampler):
12
+ r"""Samples elements randomly from a given list of indices, without replacement.
13
+
14
+ Arguments:
15
+ indices (sequence): a sequence of indices
16
+ """
17
+
18
+ def __init__(self, indices):
19
+ self.epoch = 0
20
+ self.indices = indices
21
+
22
+ def __iter__(self):
23
+ return (self.indices[i] for i in torch.randperm(len(self.indices)))
24
+
25
+ def __len__(self):
26
+ return len(self.indices)
27
+
28
+ def set_epoch(self, epoch):
29
+ self.epoch = epoch
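A brief usage sketch for SubsetRandomSampler (illustrative; in the repo the indices would be the shard of the dataset owned by the current process):

# usage sketch (illustrative, not part of the commit)
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100).float())
indices = list(range(0, 100, 2))       # e.g. the indices assigned to this rank
sampler = SubsetRandomSampler(indices)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
sampler.set_epoch(0)                   # kept for API parity with DistributedSampler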
data/zipreader.py ADDED
@@ -0,0 +1,103 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import zipfile
10
+ import io
11
+ import numpy as np
12
+ from PIL import Image
13
+ from PIL import ImageFile
14
+
15
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
16
+
17
+
18
+ def is_zip_path(img_or_path):
19
+ """judge if this is a zip path"""
20
+ return '.zip@' in img_or_path
21
+
22
+
23
+ class ZipReader(object):
24
+ """A class to read zipped files"""
25
+ zip_bank = dict()
26
+
27
+ def __init__(self):
28
+ super(ZipReader, self).__init__()
29
+
30
+ @staticmethod
31
+ def get_zipfile(path):
32
+ zip_bank = ZipReader.zip_bank
33
+ if path not in zip_bank:
34
+ zfile = zipfile.ZipFile(path, 'r')
35
+ zip_bank[path] = zfile
36
+ return zip_bank[path]
37
+
38
+ @staticmethod
39
+ def split_zip_style_path(path):
40
+ pos_at = path.find('@')
41
+ assert pos_at != -1, "character '@' is not found from the given path '%s'" % path
42
+
43
+ zip_path = path[0: pos_at]
44
+ folder_path = path[pos_at + 1:]
45
+ folder_path = str.strip(folder_path, '/')
46
+ return zip_path, folder_path
47
+
48
+ @staticmethod
49
+ def list_folder(path):
50
+ zip_path, folder_path = ZipReader.split_zip_style_path(path)
51
+
52
+ zfile = ZipReader.get_zipfile(zip_path)
53
+ folder_list = []
54
+ for file_folder_name in zfile.namelist():
55
+ file_folder_name = str.strip(file_folder_name, '/')
56
+ if file_folder_name.startswith(folder_path) and \
57
+ len(os.path.splitext(file_folder_name)[-1]) == 0 and \
58
+ file_folder_name != folder_path:
59
+ if len(folder_path) == 0:
60
+ folder_list.append(file_folder_name)
61
+ else:
62
+ folder_list.append(file_folder_name[len(folder_path) + 1:])
63
+
64
+ return folder_list
65
+
66
+ @staticmethod
67
+ def list_files(path, extension=None):
68
+ if extension is None:
69
+ extension = ['.*']
70
+ zip_path, folder_path = ZipReader.split_zip_style_path(path)
71
+
72
+ zfile = ZipReader.get_zipfile(zip_path)
73
+ file_lists = []
74
+ for file_folder_name in zfile.namelist():
75
+ file_folder_name = str.strip(file_folder_name, '/')
76
+ if file_folder_name.startswith(folder_path) and \
77
+ str.lower(os.path.splitext(file_folder_name)[-1]) in extension:
78
+ if len(folder_path) == 0:
79
+ file_lists.append(file_folder_name)
80
+ else:
81
+ file_lists.append(file_folder_name[len(folder_path) + 1:])
82
+
83
+ return file_lists
84
+
85
+ @staticmethod
86
+ def read(path):
87
+ zip_path, path_img = ZipReader.split_zip_style_path(path)
88
+ zfile = ZipReader.get_zipfile(zip_path)
89
+ data = zfile.read(path_img)
90
+ return data
91
+
92
+ @staticmethod
93
+ def imread(path):
94
+ zip_path, path_img = ZipReader.split_zip_style_path(path)
95
+ zfile = ZipReader.get_zipfile(zip_path)
96
+ data = zfile.read(path_img)
97
+ try:
98
+ im = Image.open(io.BytesIO(data))
99
+ except Exception:
100
+ print("ERROR IMG LOADED: ", path_img)
101
+ random_img = np.random.rand(224, 224, 3) * 255
102
+ im = Image.fromarray(np.uint8(random_img))
103
+ return im
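A sketch of the zip-style path convention the reader expects, i.e. '<archive>.zip@<member path>'; the archive and member names below are placeholders:

# usage sketch (illustrative, not part of the commit)
path = 'imagenet.zip@train/n01440764/img_0001.JPEG'   # hypothetical member of a hypothetical archive
if is_zip_path(path):
    img = ZipReader.imread(path).convert('RGB')
    names = ZipReader.list_files('imagenet.zip@train/n01440764', extension=['.jpeg', '.jpg'])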
figs/overview.png ADDED
get_flops.py ADDED
@@ -0,0 +1,62 @@
1
+ import argparse
2
+ import torch
3
+ from timm.models import create_model
4
+ from models.CoAt import *
5
+
6
+ try:
7
+ from mmcv.cnn import get_model_complexity_info
8
+ from mmcv.cnn.utils.flops_counter import get_model_complexity_info, flops_to_string, params_to_string
9
+ except ImportError:
10
+ raise ImportError('Please upgrade mmcv to >0.6.2')
11
+
12
+
13
+ def parse_args():
14
+ parser = argparse.ArgumentParser(description='Get FLOPS of a classification model')
15
+ parser.add_argument('model', help='train config file path')
16
+ parser.add_argument(
17
+ '--shape',
18
+ type=int,
19
+ nargs='+',
20
+ default=[224,],
21
+ help='input image size')
22
+ args = parser.parse_args()
23
+ return args
24
+
25
+ def get_flops(model, input_shape):
26
+ flops, params = get_model_complexity_info(model, input_shape, as_strings=False)
27
+ return flops_to_string(flops), params_to_string(params)
28
+
29
+
30
+ def main():
31
+ args = parse_args()
32
+
33
+ if len(args.shape) == 1:
34
+ input_shape = (3, args.shape[0], args.shape[0])
35
+ elif len(args.shape) == 2:
36
+ input_shape = (3,) + tuple(args.shape)
37
+ else:
38
+ raise ValueError('invalid input shape')
39
+
40
+ model = create_model(
41
+ args.model,
42
+ pretrained=False,
43
+ num_classes=1000,
44
+ img_size=args.shape[0],
45
+ )
46
+ model.name = args.model
47
+ if torch.cuda.is_available():
48
+ model.cuda()
49
+ model.eval()
50
+
51
+ flops, params = get_flops(model, input_shape)
52
+
53
+ split_line = '=' * 30
54
+ print(f'{split_line}\nInput shape: {input_shape}\n'
55
+ f'Flops: {flops}\nParams: {params}\n{split_line}')
56
+ print('!!!Please be cautious if you use the results in papers. '
57
+ 'You may need to check if all ops are supported and verify that the '
58
+ 'flops computation is correct.')
59
+
60
+
61
+ if __name__ == '__main__':
62
+ main()
logger.py ADDED
@@ -0,0 +1,47 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import sys
10
+ import logging
11
+ import functools
12
+ from termcolor import colored
13
+
14
+
15
+ @functools.lru_cache()
16
+ def create_logger(output_dir, dist_rank=0, name='',local_rank=0):
17
+ # create logger
18
+ logger = logging.getLogger(name)
19
+ logger.setLevel(logging.DEBUG)
20
+ logger.propagate = False
21
+
22
+ # create formatter
23
+ fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
24
+ color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
25
+ colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
26
+
27
+ # create console handlers for master process
28
+ # if dist_rank == 0:
29
+ # console_handler = logging.StreamHandler(sys.stdout)
30
+ # console_handler.setLevel(logging.DEBUG)
31
+ # console_handler.setFormatter(
32
+ # logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
33
+ # logger.addHandler(console_handler)
34
+
35
+ if local_rank == 0:
36
+ console_handler = logging.StreamHandler(sys.stdout)
37
+ console_handler.setLevel(logging.DEBUG)
38
+ console_handler.setFormatter(
39
+ logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
40
+ logger.addHandler(console_handler)
41
+ # create file handlers
42
+ file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
43
+ file_handler.setLevel(logging.DEBUG)
44
+ file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
45
+ logger.addHandler(file_handler)
46
+
47
+ return logger
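A short usage sketch for create_logger (the output directory is illustrative):

# usage sketch (illustrative, not part of the commit)
import os
os.makedirs('./output_demo', exist_ok=True)
logger = create_logger(output_dir='./output_demo', dist_rank=0, name='MetaFG_0', local_rank=0)
logger.info('hello')   # printed to stdout on local_rank 0 and appended to output_demo/log_rank0.txt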
lr_scheduler.py ADDED
@@ -0,0 +1,102 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ from timm.scheduler.cosine_lr import CosineLRScheduler
10
+ from timm.scheduler.step_lr import StepLRScheduler
11
+ from timm.scheduler.scheduler import Scheduler
12
+
13
+
14
+ def build_scheduler(config, optimizer, n_iter_per_epoch):
15
+ num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
16
+ warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
17
+ decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
18
+
19
+ lr_scheduler = None
20
+ if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
21
+ lr_scheduler = CosineLRScheduler(
22
+ optimizer,
23
+ t_initial=num_steps,
24
+ t_mul=1.,
25
+ lr_min=config.TRAIN.MIN_LR,
26
+ warmup_lr_init=config.TRAIN.WARMUP_LR,
27
+ warmup_t=warmup_steps,
28
+ cycle_limit=1,
29
+ t_in_epochs=False,
30
+ )
31
+ elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
32
+ lr_scheduler = LinearLRScheduler(
33
+ optimizer,
34
+ t_initial=num_steps,
35
+ lr_min_rate=0.01,
36
+ warmup_lr_init=config.TRAIN.WARMUP_LR,
37
+ warmup_t=warmup_steps,
38
+ t_in_epochs=False,
39
+ )
40
+ elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
41
+ lr_scheduler = StepLRScheduler(
42
+ optimizer,
43
+ decay_t=decay_steps,
44
+ decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
45
+ warmup_lr_init=config.TRAIN.WARMUP_LR,
46
+ warmup_t=warmup_steps,
47
+ t_in_epochs=False,
48
+ )
49
+
50
+ return lr_scheduler
51
+
52
+
53
+ class LinearLRScheduler(Scheduler):
54
+ def __init__(self,
55
+ optimizer: torch.optim.Optimizer,
56
+ t_initial: int,
57
+ lr_min_rate: float,
58
+ warmup_t=0,
59
+ warmup_lr_init=0.,
60
+ t_in_epochs=True,
61
+ noise_range_t=None,
62
+ noise_pct=0.67,
63
+ noise_std=1.0,
64
+ noise_seed=42,
65
+ initialize=True,
66
+ ) -> None:
67
+ super().__init__(
68
+ optimizer, param_group_field="lr",
69
+ noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
70
+ initialize=initialize)
71
+
72
+ self.t_initial = t_initial
73
+ self.lr_min_rate = lr_min_rate
74
+ self.warmup_t = warmup_t
75
+ self.warmup_lr_init = warmup_lr_init
76
+ self.t_in_epochs = t_in_epochs
77
+ if self.warmup_t:
78
+ self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
79
+ super().update_groups(self.warmup_lr_init)
80
+ else:
81
+ self.warmup_steps = [1 for _ in self.base_values]
82
+
83
+ def _get_lr(self, t):
84
+ if t < self.warmup_t:
85
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
86
+ else:
87
+ t = t - self.warmup_t
88
+ total_t = self.t_initial - self.warmup_t
89
+ lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
90
+ return lrs
91
+
92
+ def get_epoch_values(self, epoch: int):
93
+ if self.t_in_epochs:
94
+ return self._get_lr(epoch)
95
+ else:
96
+ return None
97
+
98
+ def get_update_values(self, num_updates: int):
99
+ if not self.t_in_epochs:
100
+ return self._get_lr(num_updates)
101
+ else:
102
+ return None
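A small sanity-check sketch for LinearLRScheduler (all values illustrative): after the warmup ramp, the learning rate decays linearly from the base value to base * lr_min_rate at t_initial:

# usage sketch (illustrative, not part of the commit)
import torch
net = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
sched = LinearLRScheduler(opt, t_initial=100, lr_min_rate=0.01,
                          warmup_t=10, warmup_lr_init=0.001, t_in_epochs=False)
for step in (0, 10, 55, 100):
    print(step, sched.get_update_values(step))   # ramps 0.001 -> 0.1, then decays linearly toward 0.001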
main.py ADDED
@@ -0,0 +1,403 @@
1
+ import os
2
+ import time
3
+ import argparse
4
+ import datetime
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.backends.cudnn as cudnn
9
+ import torch.distributed as dist
10
+
11
+ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
12
+ from timm.utils import accuracy, AverageMeter
13
+
14
+ from config import get_config
15
+ from models import build_model
16
+ from data import build_loader
17
+ from lr_scheduler import build_scheduler
18
+ from optimizer import build_optimizer
19
+ from logger import create_logger
20
+ from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor,load_pretained
21
+ from torch.utils.tensorboard import SummaryWriter
22
+ try:
23
+ # noinspection PyUnresolvedReferences
24
+ from apex import amp
25
+ except ImportError:
26
+ amp = None
27
+
28
+
29
+ def parse_option():
30
+ parser = argparse.ArgumentParser('MetaFG training and evaluation script', add_help=False)
31
+ parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
32
+ parser.add_argument(
33
+ "--opts",
34
+ help="Modify config options by adding 'KEY VALUE' pairs. ",
35
+ default=None,
36
+ nargs='+',
37
+ )
38
+
39
+ # easy config modification
40
+ parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
41
+ parser.add_argument('--data-path',default='./imagenet', type=str, help='path to dataset')
42
+ parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
43
+ parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
44
+ help='no: no cache, '
45
+ 'full: cache all data, '
46
+ 'part: shard the dataset into non-overlapping pieces and cache only one piece')
47
+ parser.add_argument('--resume', help='resume from checkpoint')
48
+ parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
49
+ parser.add_argument('--use-checkpoint', action='store_true',
50
+ help="whether to use gradient checkpointing to save memory")
51
+ parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
52
+ help='mixed precision opt level, if O0, no amp is used')
53
+ parser.add_argument('--output', default='output', type=str, metavar='PATH',
54
+ help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
55
+ parser.add_argument('--tag', help='tag of experiment')
56
+ parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
57
+ parser.add_argument('--throughput', action='store_true', help='Test throughput only')
58
+
59
+ parser.add_argument('--num-workers', type=int,
60
+ help="num of workers on dataloader ")
61
+
62
+ parser.add_argument('--lr', type=float, metavar='LR',
63
+ help='learning rate')
64
+ parser.add_argument('--weight-decay', type=float,
65
+ help='weight decay (default: 0.05 for adamw)')
66
+
67
+ parser.add_argument('--min-lr', type=float,
68
+ help='minimum learning rate')
69
+ parser.add_argument('--warmup-lr', type=float,
70
+ help='warmup learning rate')
71
+ parser.add_argument('--epochs', type=int,
72
+ help="epochs")
73
+ parser.add_argument('--warmup-epochs', type=int,
74
+ help="epochs")
75
+
76
+ parser.add_argument('--dataset', type=str,
77
+ help='dataset')
78
+ parser.add_argument('--lr-scheduler-name', type=str,
79
+ help='lr scheduler name: cosine, linear, or step')
80
+
81
+ parser.add_argument('--pretrain', type=str,
82
+ help='path to pretrained weights')
83
+
84
+ parser.add_argument('--tensorboard', action='store_true', help='use tensorboard')
85
+
86
+
87
+ # distributed training
88
+ parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
89
+
90
+ args, unparsed = parser.parse_known_args()
91
+
92
+ config = get_config(args)
93
+
94
+ return args, config
95
+
96
+
97
+ def main(config):
98
+ dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
99
+ logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
100
+ model = build_model(config)
101
+ model.cuda()
102
+ logger.info(str(model))
103
+
104
+ optimizer = build_optimizer(config, model)
105
+ if config.AMP_OPT_LEVEL != "O0":
106
+ model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL)
107
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
108
+ model_without_ddp = model.module
109
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
110
+ logger.info(f"number of params: {n_parameters}")
111
+ if hasattr(model_without_ddp, 'flops'):
112
+ flops = model_without_ddp.flops()
113
+ logger.info(f"number of GFLOPs: {flops / 1e9}")
114
+ lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
115
+ if config.AUG.MIXUP > 0.:
116
+ # smoothing is handled with mixup label transform
117
+ criterion = SoftTargetCrossEntropy()
118
+ elif config.MODEL.LABEL_SMOOTHING > 0.:
119
+ criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
120
+ else:
121
+ criterion = torch.nn.CrossEntropyLoss()
122
+
123
+ max_accuracy = 0.0
124
+ if config.MODEL.PRETRAINED:
125
+ load_pretained(config,model_without_ddp,logger)
126
+ if config.EVAL_MODE:
127
+ acc1, acc5, loss = validate(config, data_loader_val, model)
128
+ logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
129
+ return
130
+
131
+ if config.TRAIN.AUTO_RESUME:
132
+ resume_file = auto_resume_helper(config.OUTPUT)
133
+ if resume_file:
134
+ if config.MODEL.RESUME:
135
+ logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
136
+ config.defrost()
137
+ config.MODEL.RESUME = resume_file
138
+ config.freeze()
139
+ logger.info(f'auto resuming from {resume_file}')
140
+ else:
141
+ logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
142
+
143
+ if config.MODEL.RESUME:
144
+ logger.info(f"**********normal test***********")
145
+ max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger)
146
+ acc1, acc5, loss = validate(config, data_loader_val, model)
147
+ logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
148
+ if config.DATA.ADD_META:
149
+ logger.info(f"**********mask meta test***********")
150
+ acc1, acc5, loss = validate(config, data_loader_val, model,mask_meta=True)
151
+ logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
152
+ if config.EVAL_MODE:
153
+ return
154
+
155
+ if config.THROUGHPUT_MODE:
156
+ throughput(data_loader_val, model, logger)
157
+ return
158
+
159
+ logger.info("Start training")
160
+ start_time = time.time()
161
+ for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
162
+ data_loader_train.sampler.set_epoch(epoch)
163
+ train_one_epoch_local_data(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler)
164
+ if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
165
+ save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger)
166
+
167
+ logger.info(f"**********normal test***********")
168
+ acc1, acc5, loss = validate(config, data_loader_val, model)
169
+ logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
170
+ max_accuracy = max(max_accuracy, acc1)
171
+ logger.info(f'Max accuracy: {max_accuracy:.2f}%')
172
+ if config.DATA.ADD_META:
173
+ logger.info(f"**********mask meta test***********")
174
+ acc1, acc5, loss = validate(config, data_loader_val, model,mask_meta=True)
175
+ logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
176
+ # data_loader_train.terminate()
177
+ total_time = time.time() - start_time
178
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
179
+ logger.info('Training time {}'.format(total_time_str))
180
+ def train_one_epoch_local_data(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler,tb_logger=None):
181
+ model.train()
182
+ if hasattr(model.module,'cur_epoch'):
183
+ model.module.cur_epoch = epoch
184
+ model.module.total_epoch = config.TRAIN.EPOCHS
185
+ optimizer.zero_grad()
186
+
187
+ num_steps = len(data_loader)
188
+ batch_time = AverageMeter()
189
+ loss_meter = AverageMeter()
190
+ norm_meter = AverageMeter()
191
+
192
+ start = time.time()
193
+ end = time.time()
194
+ for idx, data in enumerate(data_loader):
195
+ if config.DATA.ADD_META:
196
+ samples, targets,meta = data
197
+ meta = [m.float() for m in meta]
198
+ meta = torch.stack(meta,dim=0)
199
+ meta = meta.cuda(non_blocking=True)
200
+ else:
201
+ samples, targets= data
202
+ meta = None
203
+
204
+ samples = samples.cuda(non_blocking=True)
205
+ targets = targets.cuda(non_blocking=True)
206
+
207
+ if mixup_fn is not None:
208
+ samples, targets = mixup_fn(samples, targets)
209
+ if config.DATA.ADD_META:
210
+ outputs = model(samples,meta)
211
+ else:
212
+ outputs = model(samples)
213
+
214
+ if config.TRAIN.ACCUMULATION_STEPS > 1:
215
+ loss = criterion(outputs, targets)
216
+ loss = loss / config.TRAIN.ACCUMULATION_STEPS
217
+ if config.AMP_OPT_LEVEL != "O0":
218
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
219
+ scaled_loss.backward()
220
+ if config.TRAIN.CLIP_GRAD:
221
+ grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
222
+ else:
223
+ grad_norm = get_grad_norm(amp.master_params(optimizer))
224
+ else:
225
+ loss.backward()
226
+ if config.TRAIN.CLIP_GRAD:
227
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
228
+ else:
229
+ grad_norm = get_grad_norm(model.parameters())
230
+ if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
231
+ optimizer.step()
232
+ optimizer.zero_grad()
233
+ lr_scheduler.step_update(epoch * num_steps + idx)
234
+ else:
235
+ loss = criterion(outputs, targets)
236
+ optimizer.zero_grad()
237
+ if config.AMP_OPT_LEVEL != "O0":
238
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
239
+ scaled_loss.backward()
240
+ if config.TRAIN.CLIP_GRAD:
241
+ grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
242
+ else:
243
+ grad_norm = get_grad_norm(amp.master_params(optimizer))
244
+ else:
245
+ loss.backward()
246
+ if config.TRAIN.CLIP_GRAD:
247
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
248
+ else:
249
+ grad_norm = get_grad_norm(model.parameters())
250
+ optimizer.step()
251
+ lr_scheduler.step_update(epoch * num_steps + idx)
252
+
253
+ torch.cuda.synchronize()
254
+
255
+ loss_meter.update(loss.item(), targets.size(0))
256
+ norm_meter.update(grad_norm)
257
+ batch_time.update(time.time() - end)
258
+ end = time.time()
259
+
260
+ if idx % config.PRINT_FREQ == 0:
261
+ lr = optimizer.param_groups[0]['lr']
262
+ memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
263
+ etas = batch_time.avg * (num_steps - idx)
264
+ logger.info(
265
+ f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
266
+ f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
267
+ f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
268
+ f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
269
+ f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
270
+ f'mem {memory_used:.0f}MB')
271
+ epoch_time = time.time() - start
272
+ logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
273
+ @torch.no_grad()
274
+ def validate(config, data_loader, model, mask_meta=False):
275
+ criterion = torch.nn.CrossEntropyLoss()
276
+ model.eval()
277
+
278
+ batch_time = AverageMeter()
279
+ loss_meter = AverageMeter()
280
+ acc1_meter = AverageMeter()
281
+ acc5_meter = AverageMeter()
282
+
283
+ end = time.time()
284
+ for idx, data in enumerate(data_loader):
285
+ if config.DATA.ADD_META:
286
+ images,target,meta = data
287
+ meta = [m.float() for m in meta]
288
+ meta = torch.stack(meta,dim=0)
289
+ if mask_meta:
290
+ meta = torch.zeros_like(meta)
291
+ meta = meta.cuda(non_blocking=True)
292
+ else:
293
+ images, target = data
294
+ meta = None
295
+
296
+ images = images.cuda(non_blocking=True)
297
+ target = target.cuda(non_blocking=True)
298
+
299
+ # compute output
300
+ if config.DATA.ADD_META:
301
+ output = model(images,meta)
302
+ else:
303
+ output = model(images)
304
+
305
+ # measure accuracy and record loss
306
+ loss = criterion(output, target)
307
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
308
+
309
+ acc1 = reduce_tensor(acc1)
310
+ acc5 = reduce_tensor(acc5)
311
+ loss = reduce_tensor(loss)
312
+
313
+ loss_meter.update(loss.item(), target.size(0))
314
+ acc1_meter.update(acc1.item(), target.size(0))
315
+ acc5_meter.update(acc5.item(), target.size(0))
316
+
317
+ # measure elapsed time
318
+ batch_time.update(time.time() - end)
319
+ end = time.time()
320
+
321
+ if idx % config.PRINT_FREQ == 0:
322
+ memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
323
+ logger.info(
324
+ f'Test: [{idx}/{len(data_loader)}]\t'
325
+ f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
326
+ f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
327
+ f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
328
+ f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
329
+ f'Mem {memory_used:.0f}MB')
330
+ logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
331
+ return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
332
+
333
+
334
+ @torch.no_grad()
335
+ def throughput(data_loader, model, logger):
336
+ model.eval()
337
+
338
+ for idx, (images, _) in enumerate(data_loader):
339
+ images = images.cuda(non_blocking=True)
340
+ batch_size = images.shape[0]
341
+ for i in range(50):
342
+ model(images)
343
+ torch.cuda.synchronize()
344
+ logger.info(f"throughput averaged with 30 times")
345
+ tic1 = time.time()
346
+ for i in range(30):
347
+ model(images)
348
+ torch.cuda.synchronize()
349
+ tic2 = time.time()
350
+ logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
351
+ return
352
+
353
+
354
+ if __name__ == '__main__':
355
+ _, config = parse_option()
356
+
357
+ if config.AMP_OPT_LEVEL != "O0":
358
+ assert amp is not None, "amp not installed!"
359
+
360
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
361
+ rank = int(os.environ["RANK"])
362
+ world_size = int(os.environ['WORLD_SIZE'])
363
+ print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
364
+ else:
365
+ rank = -1
366
+ world_size = -1
367
+ torch.cuda.set_device(config.LOCAL_RANK)
368
+ torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
369
+ torch.distributed.barrier()
370
+
371
+ seed = config.SEED + dist.get_rank()
372
+ torch.manual_seed(seed)
373
+ np.random.seed(seed)
374
+ cudnn.benchmark = True
375
+
376
+ # linear scale the learning rate according to total batch size, may not be optimal
377
+ linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
378
+ linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
379
+ linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
380
+ # gradient accumulation also need to scale the learning rate
381
+ if config.TRAIN.ACCUMULATION_STEPS > 1:
382
+ linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
383
+ linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
384
+ linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
385
+ config.defrost()
386
+ config.TRAIN.BASE_LR = linear_scaled_lr
387
+ config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
388
+ config.TRAIN.MIN_LR = linear_scaled_min_lr
389
+ config.freeze()
390
+
391
+ os.makedirs(config.OUTPUT, exist_ok=True)
392
+ logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}",local_rank=config.LOCAL_RANK)
393
+
394
+ if dist.get_rank() == 0:
395
+ path = os.path.join(config.OUTPUT, "config.json")
396
+ with open(path, "w") as f:
397
+ f.write(config.dump())
398
+ logger.info(f"Full config saved to {path}")
399
+
400
+ # print config
401
+ logger.info(config.dump())
402
+
403
+ main(config)
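A worked example of the linear learning-rate scaling applied in the __main__ block above (the numbers are illustrative, not values taken from the repo's configs):

# usage sketch (illustrative, not part of the commit)
base_lr, batch_size_per_gpu, world_size = 5e-4, 64, 8
linear_scaled_lr = base_lr * batch_size_per_gpu * world_size / 512.0
print(linear_scaled_lr)   # 0.0005: with 64 * 8 = 512 images per step the base LR is kept unchanged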
models/MBConv.py ADDED
@@ -0,0 +1,169 @@
1
+ import math
2
+ from functools import partial
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ class SwishImplementation(torch.autograd.Function):
9
+ @staticmethod
10
+ def forward(ctx, i):
11
+ result = i * torch.sigmoid(i)
12
+ ctx.save_for_backward(i)
13
+ return result
14
+
15
+ @staticmethod
16
+ def backward(ctx, grad_output):
17
+ i = ctx.saved_tensors[0]
18
+ sigmoid_i = torch.sigmoid(i)
19
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
20
+
21
+ class MemoryEfficientSwish(nn.Module):
22
+ def forward(self, x):
23
+ return SwishImplementation.apply(x)
24
+
25
+
26
+ def drop_connect(inputs, p, training):
27
+ """ Drop connect. """
28
+ if not training: return inputs
29
+ batch_size = inputs.shape[0]
30
+ keep_prob = 1 - p
31
+ random_tensor = keep_prob
32
+ random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
33
+ binary_tensor = torch.floor(random_tensor)
34
+ output = inputs / keep_prob * binary_tensor
35
+ return output
36
+
37
+
38
+ def get_same_padding_conv2d(image_size=None):
39
+ return partial(Conv2dStaticSamePadding, image_size=image_size)
40
+
41
+ def get_width_and_height_from_size(x):
42
+ """ Obtains width and height from a int or tuple """
43
+ if isinstance(x, int): return x, x
44
+ if isinstance(x, list) or isinstance(x, tuple): return x
45
+ else: raise TypeError()
46
+
47
+ def calculate_output_image_size(input_image_size, stride):
48
+ """
49
+ Compute the output image size of a Conv2dSamePadding layer with the given stride.
50
+ """
51
+ if input_image_size is None: return None
52
+ image_height, image_width = get_width_and_height_from_size(input_image_size)
53
+ stride = stride if isinstance(stride, int) else stride[0]
54
+ image_height = int(math.ceil(image_height / stride))
55
+ image_width = int(math.ceil(image_width / stride))
56
+ return [image_height, image_width]
57
+
58
+
59
+
60
+ class Conv2dStaticSamePadding(nn.Conv2d):
61
+ """ 2D Convolutions like TensorFlow, for a fixed image size"""
62
+
63
+ def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
64
+ super().__init__(in_channels, out_channels, kernel_size, **kwargs)
65
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
66
+
67
+ # Calculate padding based on image size and save it
68
+ assert image_size is not None
69
+ ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
70
+ kh, kw = self.weight.size()[-2:]
71
+ sh, sw = self.stride
72
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
73
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
74
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
75
+ if pad_h > 0 or pad_w > 0:
76
+ self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
77
+ else:
78
+ self.static_padding = Identity()
79
+
80
+ def forward(self, x):
81
+ x = self.static_padding(x)
82
+ x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
83
+ return x
84
+
85
+ class Identity(nn.Module):
86
+ def __init__(self, ):
87
+ super(Identity, self).__init__()
88
+
89
+ def forward(self, input):
90
+ return input
91
+
92
+ # #MBConvBlock
93
+ class MBConvBlock(nn.Module):
94
+ '''
95
+ MBConv block, e.g. 3x3 kernel, 32 input channels, 16 output channels, stride 1.
96
+ '''
97
+ def __init__(self, ksize, input_filters, output_filters, expand_ratio=1, stride=1,image_size=224,drop_connect_rate=0.):
98
+ super().__init__()
99
+ self._bn_mom = 0.1
100
+ self._bn_eps = 0.01
101
+ self._se_ratio = 0.25
102
+ self._input_filters = input_filters
103
+ self._output_filters = output_filters
104
+ self._expand_ratio = expand_ratio
105
+ self._kernel_size = ksize
106
+ self._stride = stride
107
+ self._drop_connect_rate = drop_connect_rate
108
+ inp = self._input_filters
109
+ oup = self._input_filters * self._expand_ratio
110
+ if self._expand_ratio != 1:
111
+ self._expand_conv = nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1,bias=False)
112
+ self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
113
+
114
+
115
+ # Depthwise convolution
116
+ k = self._kernel_size
117
+ s = self._stride
118
+ self._depthwise_conv = nn.Conv2d(in_channels=oup, out_channels=oup, groups=oup,
119
+ kernel_size=k, stride=s, padding=1,bias=False)
120
+
121
+ self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
122
+ # Squeeze and Excitation layer, if desired
123
+ num_squeezed_channels = max(1, int(self._input_filters * self._se_ratio))
124
+ self._se_reduce = nn.Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
125
+ self._se_expand = nn.Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
126
+
127
+ # Output phase
128
+ final_oup = self._output_filters
129
+ self._project_conv = nn.Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1,bias=False)
130
+ self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
131
+ self._swish = MemoryEfficientSwish()
132
+
133
+ def forward(self, inputs):
134
+ """
135
+ :param inputs: input tensor
136
+ :return: output of block
137
+ """
138
+
139
+ # Expansion and Depthwise Convolution
140
+ x = inputs
141
+ if self._expand_ratio != 1:
142
+ expand = self._expand_conv(inputs)
143
+ bn0 = self._bn0(expand)
144
+ x = self._swish(bn0)
145
+ depthwise = self._depthwise_conv(x)
146
+ bn1 = self._bn1(depthwise)
147
+ x = self._swish(bn1)
148
+ # Squeeze and Excitation
149
+ x_squeezed = F.adaptive_avg_pool2d(x, 1)
150
+ x_squeezed = self._se_reduce(x_squeezed)
151
+ x_squeezed = self._swish(x_squeezed)
152
+ x_squeezed = self._se_expand(x_squeezed)
153
+ x = torch.sigmoid(x_squeezed) * x
154
+
155
+ x = self._bn2(self._project_conv(x))
156
+
157
+ # Skip connection and drop connect
158
+ input_filters, output_filters = self._input_filters, self._output_filters
159
+ if self._stride == 1 and input_filters == output_filters:
160
+ if self._drop_connect_rate!=0:
161
+ x = drop_connect(x, p=self._drop_connect_rate, training=self.training)
162
+ x = x + inputs # skip connection
163
+ return x
164
+ if __name__ == '__main__':
165
+ input=torch.randn(1,3,112,112)
166
+ mbconv=MBConvBlock(ksize=3,input_filters=3,output_filters=3,expand_ratio=4,stride=1)
167
+ print(mbconv)
168
+ out=mbconv(input)
169
+ print(out.shape)
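A small sketch of the static 'same' padding defined above (shapes are illustrative): for a 112x112 input, a 3x3 kernel and stride 2, the output spatial size should be ceil(112 / 2) = 56:

# usage sketch (illustrative, not part of the commit)
import torch
conv = Conv2dStaticSamePadding(3, 8, kernel_size=3, stride=2, image_size=112, bias=False)
out = conv(torch.randn(1, 3, 112, 112))
print(out.shape)   # torch.Size([1, 8, 56, 56])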
models/MHSA.py ADDED
@@ -0,0 +1,161 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ import numpy as np
6
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7
+
8
+ class Mlp(nn.Module):
9
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
10
+ super().__init__()
11
+ out_features = out_features or in_features
12
+ hidden_features = hidden_features or in_features
13
+ self.fc1 = nn.Linear(in_features, hidden_features)
14
+ self.act = act_layer()
15
+ self.fc2 = nn.Linear(hidden_features, out_features)
16
+ self.drop = nn.Dropout(drop)
17
+
18
+ def forward(self, x, H=None, W=None):
19
+ x = self.fc1(x)
20
+ x = self.act(x)
21
+ x = self.drop(x)
22
+ x = self.fc2(x)
23
+ x = self.drop(x)
24
+ return x
25
+ class DWConv(nn.Module):
26
+ def __init__(self, dim=768):
27
+ super(DWConv, self).__init__()
28
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
29
+
30
+ def forward(self, x, H, W):
31
+ B, N, C = x.shape
32
+ x = x.transpose(1, 2).view(B, C, H, W)
33
+ x = self.dwconv(x)
34
+ x = x.flatten(2).transpose(1, 2)
35
+
36
+ return x
37
+ class Relative_Attention(nn.Module):
38
+ def __init__(self,dim,img_size,extra_token_num=1,num_heads=8,qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
39
+ super().__init__()
40
+ self.num_heads = num_heads
41
+ self.extra_token_num = extra_token_num
42
+ head_dim = dim // num_heads
43
+ self.img_size = img_size # h,w
44
+ self.scale = qk_scale or head_dim ** -0.5
45
+ # define a parameter table of relative position bias,add cls_token bias
46
+ self.relative_position_bias_table = nn.Parameter(
47
+ torch.zeros((2 * img_size[0] - 1) * (2 * img_size[1] - 1) + 1, num_heads)) # 2*h-1 * 2*w-1 + 1, nH
48
+
49
+ # get pair-wise relative position index for each token
50
+ coords_h = torch.arange(self.img_size[0])
51
+ coords_w = torch.arange(self.img_size[1])
52
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, h, w
53
+ coords_flatten = torch.flatten(coords, 1) # 2, h*w
54
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, h*w, h*w
55
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # h*w, h*w, 2
56
+ relative_coords[:, :, 0] += self.img_size[0] - 1 # shift to start from 0
57
+ relative_coords[:, :, 1] += self.img_size[1] - 1
58
+ relative_coords[:, :, 0] *= 2 * self.img_size[1] - 1
59
+ relative_position_index = relative_coords.sum(-1) # h*w, h*w
60
+ relative_position_index = F.pad(relative_position_index,(extra_token_num,0,extra_token_num,0))
61
+ relative_position_index = relative_position_index.long()
62
+ self.register_buffer("relative_position_index", relative_position_index)
63
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
64
+ self.attn_drop = nn.Dropout(attn_drop)
65
+ self.proj = nn.Linear(dim, dim)
66
+ self.proj_drop = nn.Dropout(proj_drop)
67
+ trunc_normal_(self.relative_position_bias_table, std=.02)
68
+ self.softmax = nn.Softmax(dim=-1)
69
+ def forward(self, x,):
70
+ """
71
+ Args:
72
+ x: input features with shape of (B, N, C)
73
+ """
74
+ B_, N, C = x.shape
75
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
76
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
77
+
78
+ q = q * self.scale
79
+ attn = (q @ k.transpose(-2, -1))
80
+
81
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
82
+ self.img_size[0] * self.img_size[1] + self.extra_token_num, self.img_size[0] * self.img_size[1] + self.extra_token_num, -1) # h*w+1,h*w+1,nH
83
+
84
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, h*w+1, h*w+1
85
+ attn = attn + relative_position_bias.unsqueeze(0)
86
+
87
+ attn = self.softmax(attn)
88
+
89
+ attn = self.attn_drop(attn)
90
+
91
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
92
+ x = self.proj(x)
93
+ x = self.proj_drop(x)
94
+ return x
95
+ class OverlapPatchEmbed(nn.Module):
96
+ """ Image to Patch Embedding
97
+ """
98
+
99
+ def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
100
+ super().__init__()
101
+ patch_size = to_2tuple(patch_size)
102
+ self.patch_size = patch_size
103
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
104
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
105
+ self.norm = nn.LayerNorm(embed_dim)
106
+
107
+ self.apply(self._init_weights)
108
+
109
+ def _init_weights(self, m):
110
+ if isinstance(m, nn.Linear):
111
+ trunc_normal_(m.weight, std=.02)
112
+ if isinstance(m, nn.Linear) and m.bias is not None:
113
+ nn.init.constant_(m.bias, 0)
114
+ elif isinstance(m, nn.LayerNorm):
115
+ nn.init.constant_(m.bias, 0)
116
+ nn.init.constant_(m.weight, 1.0)
117
+ elif isinstance(m, nn.Conv2d):
118
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
119
+ fan_out //= m.groups
120
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
121
+ if m.bias is not None:
122
+ m.bias.data.zero_()
123
+
124
+ def forward(self, x):
125
+ x = self.proj(x)
126
+ _, _, H, W = x.shape
127
+ x = x.flatten(2).transpose(1, 2)
128
+ x = self.norm(x)
129
+
130
+ return x, H, W
131
+ class MHSABlock(nn.Module):
132
+ def __init__(self, input_dim, output_dim,image_size, stride, num_heads,extra_token_num=1,mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
133
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
134
+ super().__init__()
135
+ if stride != 1:
136
+ self.patch_embed = OverlapPatchEmbed(patch_size=3,stride=stride,in_chans=input_dim,embed_dim=output_dim)
137
+ self.img_size = image_size//2
138
+ else:
139
+ self.patch_embed = None
140
+ self.img_size = image_size
141
+ self.img_size = to_2tuple(self.img_size)
142
+
143
+ self.norm1 = norm_layer(output_dim)
144
+ self.attn = Relative_Attention(
145
+ output_dim,self.img_size, extra_token_num=extra_token_num,num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
146
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
147
+ self.norm2 = norm_layer(output_dim)
148
+ mlp_hidden_dim = int(output_dim * mlp_ratio)
149
+ self.mlp = Mlp(in_features=output_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
150
+
151
+ def forward(self, x, H, W, extra_tokens=None):
152
+ if self.patch_embed is not None:
153
+ x,_,_ = self.patch_embed(x)
154
+
155
+ extra_tokens = [token.expand(x.shape[0],-1,-1) for token in extra_tokens]
156
+ extra_tokens.append(x)
157
+ x = torch.cat(extra_tokens,dim=1)
158
+ x = x + self.drop_path(self.attn(self.norm1(x)))
159
+ x = x + self.drop_path(self.mlp(self.norm2(x),H//2,W//2))
160
+ return x
161
+
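A usage sketch for MHSABlock (shapes are illustrative): with stride 2 the block patch-embeds a 28x28 feature map down to 14x14, prepends the extra tokens, and returns a sequence of length 1 + 14*14 = 197:

# usage sketch (illustrative, not part of the commit)
import torch
import torch.nn as nn
block = MHSABlock(input_dim=192, output_dim=384, image_size=28, stride=2,
                  num_heads=8, extra_token_num=1)
cls_token = nn.Parameter(torch.zeros(1, 1, 384))
x = torch.randn(2, 192, 28, 28)
out = block(x, 28, 28, extra_tokens=[cls_token])
print(out.shape)   # torch.Size([2, 197, 384])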
models/MetaFG.py ADDED
@@ -0,0 +1,213 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from timm.models.helpers import load_pretrained
6
+ from timm.models.registry import register_model
7
+ from timm.models.layers import trunc_normal_
8
+ import numpy as np
9
+ from .MBConv import MBConvBlock
10
+ from .MHSA import MHSABlock,Mlp
11
+ def _cfg(url='', **kwargs):
12
+ return {
13
+ 'url': url,
14
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
15
+ 'crop_pct': .9, 'interpolation': 'bicubic',
16
+ 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
17
+ 'classifier': 'head',
18
+ **kwargs
19
+ }
20
+
21
+ default_cfgs = {
22
+ 'MetaFG_0': _cfg(),
23
+ 'MetaFG_1': _cfg(),
24
+ 'MetaFG_2': _cfg(),
25
+ }
26
+
27
+ def make_blocks(stage_index,depths,embed_dims,img_size,dpr,extra_token_num=1,num_heads=8,mlp_ratio=4.,stage_type='conv'):
28
+ stage_name = f'stage_{stage_index}'
29
+ blocks = []
30
+ for block_idx in range(depths[stage_index]):
31
+ stride = 2 if block_idx == 0 and stage_index != 1 else 1
32
+ in_chans = embed_dims[stage_index] if block_idx != 0 else embed_dims[stage_index-1]
33
+ out_chans = embed_dims[stage_index]
34
+ image_size = img_size if block_idx == 0 or stage_index == 1 else img_size//2
35
+ drop_path_rate = dpr[sum(depths[1:stage_index])+block_idx]
36
+ if stage_type == 'conv':
37
+ blocks.append(MBConvBlock(ksize=3,input_filters=in_chans,output_filters=out_chans,
38
+ image_size=image_size,expand_ratio=int(mlp_ratio),stride=stride,drop_connect_rate=drop_path_rate))
39
+ elif stage_type == 'mhsa':
40
+ blocks.append(MHSABlock(input_dim=in_chans,output_dim=out_chans,
41
+ image_size=image_size,stride=stride,num_heads=num_heads,extra_token_num=extra_token_num,
42
+ mlp_ratio=mlp_ratio,drop_path=drop_path_rate))
43
+ else:
44
+ raise NotImplementedError("We only support conv and mhsa")
45
+ return blocks
46
+
47
+
48
+ class MetaFG(nn.Module):
49
+ def __init__(self,img_size=224,in_chans=3, num_classes=1000,
50
+ conv_embed_dims = [64,96,192],attn_embed_dims=[384,768],
51
+ conv_depths = [2,2,3],attn_depths = [5,2],num_heads=32,extra_token_num=1,mlp_ratio=4.,
52
+ conv_norm_layer=nn.BatchNorm2d,attn_norm_layer=nn.LayerNorm,
53
+ conv_act_layer=nn.ReLU,attn_act_layer=nn.GELU,
54
+ qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,drop_path_rate=0.,
55
+ meta_dims=[],
56
+ only_last_cls=False,
57
+ use_checkpoint=False):
58
+ super().__init__()
59
+ self.only_last_cls = only_last_cls
60
+ self.img_size = img_size
61
+ self.num_classes = num_classes
62
+ stem_chs = (3 * (conv_embed_dims[0] // 4), conv_embed_dims[0])
63
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(conv_depths[1:]+attn_depths))]
64
+ #stage_0
65
+ self.stage_0 = nn.Sequential(*[
66
+ nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
67
+ conv_norm_layer(stem_chs[0]),
68
+ conv_act_layer(inplace=True),
69
+ nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
70
+ conv_norm_layer(stem_chs[1]),
71
+ conv_act_layer(inplace=True),
72
+ nn.Conv2d(stem_chs[1], conv_embed_dims[0], 3, stride=1, padding=1, bias=False)])
73
+ self.bn1 = conv_norm_layer(conv_embed_dims[0])
74
+ self.act1 = conv_act_layer(inplace=True)
75
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
76
+ #stage_1
77
+ self.stage_1 = nn.ModuleList(make_blocks(1,conv_depths+attn_depths,conv_embed_dims+attn_embed_dims,img_size//4,
78
+ dpr=dpr,num_heads=num_heads,extra_token_num=extra_token_num,mlp_ratio=mlp_ratio,stage_type='conv'))
79
+ #stage_2
80
+ self.stage_2 = nn.ModuleList(make_blocks(2,conv_depths+attn_depths,conv_embed_dims+attn_embed_dims,img_size//4,
81
+ dpr=dpr,num_heads=num_heads,extra_token_num=extra_token_num,mlp_ratio=mlp_ratio,stage_type='conv'))
82
+
83
+ #stage_3
84
+ self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, attn_embed_dims[0]))
85
+ self.stage_3 = nn.ModuleList(make_blocks(3,conv_depths+attn_depths,conv_embed_dims+attn_embed_dims,img_size//8,
86
+ dpr=dpr,num_heads=num_heads,extra_token_num=extra_token_num,mlp_ratio=mlp_ratio,stage_type='mhsa'))
87
+
88
+ #stage_4
89
+ self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, attn_embed_dims[1]))
90
+ self.stage_4 = nn.ModuleList(make_blocks(4,conv_depths+attn_depths,conv_embed_dims+attn_embed_dims,img_size//16,
91
+ dpr=dpr,num_heads=num_heads,extra_token_num=extra_token_num,mlp_ratio=mlp_ratio,stage_type='mhsa'))
92
+ self.norm_2 = attn_norm_layer(attn_embed_dims[1])
93
+ #Aggregate
94
+ if not self.only_last_cls:
95
+ self.cl_1_fc = nn.Sequential(*[Mlp(in_features=attn_embed_dims[0], out_features=attn_embed_dims[1]),
96
+ attn_norm_layer(attn_embed_dims[1])])
97
+ self.aggregate = torch.nn.Conv1d(in_channels=2, out_channels=1, kernel_size=1)
98
+ self.norm_1 = attn_norm_layer(attn_embed_dims[0])
99
+ self.norm = attn_norm_layer(attn_embed_dims[1])
100
+
101
+ # Classifier head
102
+ self.head = nn.Linear(attn_embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
103
+
104
+ trunc_normal_(self.cls_token_1, std=.02)
105
+ trunc_normal_(self.cls_token_2, std=.02)
106
+ self.apply(self._init_weights)
107
+ def _init_weights(self, m):
108
+ if isinstance(m, nn.Linear):
109
+ trunc_normal_(m.weight, std=.02)
110
+ if isinstance(m, nn.Linear) and m.bias is not None:
111
+ nn.init.constant_(m.bias, 0)
112
+ elif isinstance(m, nn.LayerNorm):
113
+ nn.init.constant_(m.bias, 0)
114
+ nn.init.constant_(m.weight, 1.0)
115
+ elif isinstance(m, nn.Conv2d):
116
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
117
+ # fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
118
+ # fan_out //= m.groups
119
+ # m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
120
+ # if m.bias is not None:
121
+ # m.bias.data.zero_()
122
+ elif isinstance(m, nn.BatchNorm2d):
123
+ nn.init.ones_(m.weight)
124
+ nn.init.zeros_(m.bias)
125
+
126
+ @torch.jit.ignore
127
+ def no_weight_decay(self):
128
+ return {'cls_token_1','cls_token_2'}
129
+
130
+ def get_classifier(self):
131
+ return self.head
132
+
133
+ def reset_classifier(self, num_classes, global_pool=''):
134
+ self.num_classes = num_classes
135
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
136
+
137
+ def forward_features(self,x,meta=None):
138
+ extra_tokens_1 = [self.cls_token_1]
139
+ extra_tokens_2 = [self.cls_token_2]
140
+ B = x.shape[0]
141
+ x = self.stage_0(x)
142
+ x = self.bn1(x)
143
+ x = self.act1(x)
144
+ x = self.maxpool(x)
145
+ for blk in self.stage_1:
146
+ x = blk(x)
147
+ for blk in self.stage_2:
148
+ x = blk(x)
149
+ H0,W0 = self.img_size//8,self.img_size//8
150
+ for ind,blk in enumerate(self.stage_3):
151
+ if ind==0:
152
+ x = blk(x,H0,W0,extra_tokens_1)
153
+ else:
154
+ x = blk(x,H0,W0)
155
+ if not self.only_last_cls:
156
+ cls_1 = x[:, :1, :]
157
+ cls_1 = self.norm_1(cls_1)
158
+ cls_1 = self.cl_1_fc(cls_1)
159
+ x = x[:, 1:, :]
160
+ H1,W1 = self.img_size//16,self.img_size//16
161
+ x = x.reshape(B,H1,W1,-1).permute(0, 3, 1, 2).contiguous()
162
+ for ind,blk in enumerate(self.stage_4):
163
+ if ind==0:
164
+ x = blk(x,H1,W1,extra_tokens_2)
165
+ else:
166
+ x = blk(x,H1,W1)
167
+ cls_2 = x[:, :1, :]
168
+ cls_2 = self.norm_2(cls_2)
169
+ if not self.only_last_cls:
170
+ cls = torch.cat((cls_1,cls_2), dim=1)#B,2,C
171
+ cls = self.aggregate(cls).squeeze(dim=1)#B,C
172
+ cls = self.norm(cls)
173
+ else:
174
+ cls = cls_2.squeeze(dim=1)
175
+ return cls
176
+
177
+ def forward(self, x,meta=None):
178
+ x = self.forward_features(x,meta)
179
+ x = self.head(x)
180
+ return x
181
+ @register_model
182
+ def MetaFG_0(pretrained=False, **kwargs):
183
+ model = MetaFG(conv_embed_dims = [64,96,192],attn_embed_dims=[384,768],
184
+ conv_depths = [2,2,3],attn_depths = [5,2],num_heads=8,mlp_ratio=4., **kwargs)
185
+ model.default_cfg = default_cfgs['MetaFG_0']
186
+ if pretrained:
187
+ load_pretrained(
188
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
189
+ return model
190
+ @register_model
191
+ def MetaFG_1(pretrained=False, **kwargs):
192
+ model = MetaFG(conv_embed_dims = [64,96,192],attn_embed_dims=[384,768],
193
+ conv_depths = [2,2,6],attn_depths = [14,2],num_heads=8,mlp_ratio=4., **kwargs)
194
+ model.default_cfg = default_cfgs['MetaFG_1']
195
+ if pretrained:
196
+ load_pretrained(
197
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
198
+ return model
199
+ @register_model
200
+ def MetaFG_2(pretrained=False, **kwargs):
201
+ model = MetaFG(conv_embed_dims = [128,128,256],attn_embed_dims=[512,1024],
202
+ conv_depths = [2,2,6],attn_depths = [14,2],num_heads=8,mlp_ratio=4., **kwargs)
203
+ model.default_cfg = default_cfgs['MetaFG_2']
204
+ if pretrained:
205
+ load_pretrained(
206
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
207
+ return model
208
+ if __name__ == "__main__":
209
+ x = torch.randn([2, 3, 224, 224])
210
+ model = MetaFG()
211
+ import ipdb;ipdb.set_trace()
212
+ output = model(x)
213
+ print(output.shape)
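A minimal usage sketch, assuming the repo root is on sys.path so the `models` package from this commit is importable, and that MetaFG accepts the same img_size/num_classes keywords as MetaFG_Meta below; the @register_model entrypoints above expose the variants through timm's registry, so they can be built by name:

import torch
from timm.models import create_model
import models  # importing the package runs models/build.py, which registers MetaFG_0/1/2

model = create_model('MetaFG_0', pretrained=False, num_classes=1000, img_size=224)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([2, 1000])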
models/MetaFG_meta.py ADDED
@@ -0,0 +1,268 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.utils.checkpoint as checkpoint
5
+ from timm.models.helpers import load_pretrained
6
+ from timm.models.registry import register_model
7
+ from timm.models.layers import trunc_normal_
8
+ import numpy as np
9
+ from .MBConv import MBConvBlock
10
+ from .MHSA import MHSABlock,Mlp
11
+ from .meta_encoder import ResNormLayer
12
+ def _cfg(url='', **kwargs):
13
+ return {
14
+ 'url': url,
15
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
16
+ 'crop_pct': .9, 'interpolation': 'bicubic',
17
+ 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
18
+ 'classifier': 'head',
19
+ **kwargs
20
+ }
21
+
22
+ default_cfgs = {
23
+ 'MetaFG_0': _cfg(),
24
+ 'MetaFG_1': _cfg(),
25
+ 'MetaFG_2': _cfg(),
26
+ }
27
+
28
+ def make_blocks(stage_index,depths,embed_dims,img_size,dpr,extra_token_num=1,num_heads=8,mlp_ratio=4.,stage_type='conv'):
29
+ stage_name = f'stage_{stage_index}'
30
+ blocks = []
31
+ for block_idx in range(depths[stage_index]):
32
+ stride = 2 if block_idx == 0 and stage_index != 1 else 1
33
+ in_chans = embed_dims[stage_index] if block_idx != 0 else embed_dims[stage_index-1]
34
+ out_chans = embed_dims[stage_index]
35
+ image_size = img_size if block_idx == 0 or stage_index == 1 else img_size//2
36
+ drop_path_rate = dpr[sum(depths[1:stage_index])+block_idx]
37
+ if stage_type == 'conv':
38
+ blocks.append(MBConvBlock(ksize=3,input_filters=in_chans,output_filters=out_chans,
39
+ image_size=image_size,expand_ratio=int(mlp_ratio),stride=stride,drop_connect_rate=drop_path_rate))
40
+ elif stage_type == 'mhsa':
41
+ blocks.append(MHSABlock(input_dim=in_chans,output_dim=out_chans,
42
+ image_size=image_size,stride=stride,num_heads=num_heads,extra_token_num=extra_token_num,
43
+ mlp_ratio=mlp_ratio,drop_path=drop_path_rate))
44
+ else:
45
+ raise NotImplementedError("We only support conv and mhsa")
46
+ return blocks
47
+
48
+
49
+ class MetaFG_Meta(nn.Module):
50
+ def __init__(self,img_size=224,in_chans=3, num_classes=1000,
51
+ conv_embed_dims = [64,96,192],attn_embed_dims=[384,768],
52
+ conv_depths = [2,2,3],attn_depths = [5,2],num_heads=32,extra_token_num=3,mlp_ratio=4.,
53
+ conv_norm_layer=nn.BatchNorm2d,attn_norm_layer=nn.LayerNorm,
54
+ conv_act_layer=nn.ReLU,attn_act_layer=nn.GELU,
55
+ qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,drop_path_rate=0.,
56
+ add_meta=True,meta_dims=[4,3],mask_prob=1.0,mask_type='linear',
57
+ only_last_cls=False,
58
+ use_checkpoint=False):
59
+ super().__init__()
60
+ self.only_last_cls = only_last_cls
61
+ self.img_size = img_size
62
+ self.num_classes = num_classes
63
+ self.add_meta = add_meta
64
+ self.meta_dims = meta_dims
65
+ self.cur_epoch = -1
66
+ self.total_epoch = -1
67
+ self.mask_prob = mask_prob
68
+ self.mask_type = mask_type
69
+ self.attn_embed_dims = attn_embed_dims
70
+ self.extra_token_num = extra_token_num
71
+ if self.add_meta:
72
+ # assert len(meta_dims)==extra_token_num-1
73
+ for ind,meta_dim in enumerate(meta_dims):
74
+ meta_head_1 = nn.Sequential(
75
+ nn.Linear(meta_dim, attn_embed_dims[0]),
76
+ nn.ReLU(inplace=True),
77
+ nn.LayerNorm(attn_embed_dims[0]),
78
+ ResNormLayer(attn_embed_dims[0]),
79
+ ) if meta_dim > 0 else nn.Identity()
80
+ meta_head_2 = nn.Sequential(
81
+ nn.Linear(meta_dim, attn_embed_dims[1]),
82
+ nn.ReLU(inplace=True),
83
+ nn.LayerNorm(attn_embed_dims[1]),
84
+ ResNormLayer(attn_embed_dims[1]),
85
+ ) if meta_dim > 0 else nn.Identity()
86
+ setattr(self, f"meta_{ind+1}_head_1", meta_head_1)
87
+ setattr(self, f"meta_{ind+1}_head_2", meta_head_2)
88
+
89
+
90
+ stem_chs = (3 * (conv_embed_dims[0] // 4), conv_embed_dims[0])
91
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(conv_depths[1:]+attn_depths))]
92
+ #stage_0
93
+ self.stage_0 = nn.Sequential(*[
94
+ nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
95
+ conv_norm_layer(stem_chs[0]),
96
+ conv_act_layer(inplace=True),
97
+ nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
98
+ conv_norm_layer(stem_chs[1]),
99
+ conv_act_layer(inplace=True),
100
+ nn.Conv2d(stem_chs[1], conv_embed_dims[0], 3, stride=1, padding=1, bias=False)])
101
+ self.bn1 = conv_norm_layer(conv_embed_dims[0])
102
+ self.act1 = conv_act_layer(inplace=True)
103
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
104
+ #stage_1
105
+ self.stage_1 = nn.ModuleList(make_blocks(1,conv_depths+attn_depths,conv_embed_dims+attn_embed_dims,img_size//4,
106
+ dpr=dpr,num_heads=num_heads,extra_token_num=extra_token_num,mlp_ratio=mlp_ratio,stage_type='conv'))
107
+ #stage_2
108
+ self.stage_2 = nn.ModuleList(make_blocks(2,conv_depths+attn_depths,conv_embed_dims+attn_embed_dims,img_size//4,
109
+ dpr=dpr,num_heads=num_heads,extra_token_num=extra_token_num,mlp_ratio=mlp_ratio,stage_type='conv'))
110
+
111
+ #stage_3
112
+ self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, attn_embed_dims[0]))
113
+ self.stage_3 = nn.ModuleList(make_blocks(3,conv_depths+attn_depths,conv_embed_dims+attn_embed_dims,img_size//8,
114
+ dpr=dpr,num_heads=num_heads,extra_token_num=extra_token_num,mlp_ratio=mlp_ratio,stage_type='mhsa'))
115
+ #stage_4
116
+ self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, attn_embed_dims[1]))
117
+ self.stage_4 = nn.ModuleList(make_blocks(4,conv_depths+attn_depths,conv_embed_dims+attn_embed_dims,img_size//16,
118
+ dpr=dpr,num_heads=num_heads,extra_token_num=extra_token_num,mlp_ratio=mlp_ratio,stage_type='mhsa'))
119
+ self.norm_2 = attn_norm_layer(attn_embed_dims[1])
120
+
121
+ #Aggregate
122
+ if not self.only_last_cls:
123
+ self.cl_1_fc = nn.Sequential(*[Mlp(in_features=attn_embed_dims[0], out_features=attn_embed_dims[1]),
124
+ attn_norm_layer(attn_embed_dims[1])])
125
+ self.aggregate = torch.nn.Conv1d(in_channels=2, out_channels=1, kernel_size=1)
126
+ self.norm = attn_norm_layer(attn_embed_dims[1])
127
+ self.norm_1 = attn_norm_layer(attn_embed_dims[0])
128
+ # Classifier head
129
+ self.head = nn.Linear(attn_embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
130
+
131
+ trunc_normal_(self.cls_token_1, std=.02)
132
+ trunc_normal_(self.cls_token_2, std=.02)
133
+ self.apply(self._init_weights)
134
+ def _init_weights(self, m):
135
+ if isinstance(m, nn.Linear):
136
+ trunc_normal_(m.weight, std=.02)
137
+ if isinstance(m, nn.Linear) and m.bias is not None:
138
+ nn.init.constant_(m.bias, 0)
139
+ elif isinstance(m, nn.LayerNorm):
140
+ nn.init.constant_(m.bias, 0)
141
+ nn.init.constant_(m.weight, 1.0)
142
+ elif isinstance(m, nn.Conv2d):
143
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
144
+ # fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
145
+ # fan_out //= m.groups
146
+ # m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
147
+ # if m.bias is not None:
148
+ # m.bias.data.zero_()
149
+ elif isinstance(m, nn.BatchNorm2d):
150
+ nn.init.ones_(m.weight)
151
+ nn.init.zeros_(m.bias)
152
+
153
+ @torch.jit.ignore
154
+ def no_weight_decay(self):
155
+ return {'cls_token_1','cls_token_2'}
156
+
157
+ def get_classifier(self):
158
+ return self.head
159
+
160
+ def reset_classifier(self, num_classes, global_pool=''):
161
+ self.num_classes = num_classes
162
+ self.head = nn.Linear(self.attn_embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
163
+
164
+ def forward_features(self,x,meta=None):
165
+ B = x.shape[0]
166
+ extra_tokens_1 = [self.cls_token_1]
167
+ extra_tokens_2 = [self.cls_token_2]
168
+ if self.add_meta:
169
+ assert meta is not None, 'meta is None'
170
+ if len(self.meta_dims)>1:
171
+ metas = torch.split(meta,self.meta_dims,dim=1)
172
+ else:
173
+ metas = (meta,)
174
+ for ind,cur_meta in enumerate(metas):
175
+ meta_head_1 = getattr(self,f"meta_{ind+1}_head_1")
176
+ meta_head_2 = getattr(self,f"meta_{ind+1}_head_2")
177
+ meta_1 = meta_head_1(cur_meta)
178
+ meta_1 = meta_1.reshape(B, -1, self.attn_embed_dims[0])
179
+ meta_2 = meta_head_2(cur_meta)
180
+ meta_2 = meta_2.reshape(B, -1, self.attn_embed_dims[1])
181
+ extra_tokens_1.append(meta_1)
182
+ extra_tokens_2.append(meta_2)
183
+
184
+ x = self.stage_0(x)
185
+ x = self.bn1(x)
186
+ x = self.act1(x)
187
+ x = self.maxpool(x)
188
+ for blk in self.stage_1:
189
+ x = blk(x)
190
+ for blk in self.stage_2:
191
+ x = blk(x)
192
+ H0,W0 = self.img_size//8,self.img_size//8
193
+ for ind,blk in enumerate(self.stage_3):
194
+ if ind==0:
195
+ x = blk(x,H0,W0,extra_tokens_1)
196
+ else:
197
+ x = blk(x,H0,W0)
198
+ if not self.only_last_cls:
199
+ cls_1 = x[:, :1, :]
200
+ cls_1 = self.norm_1(cls_1)
201
+ cls_1 = self.cl_1_fc(cls_1)
202
+
203
+ x = x[:, self.extra_token_num:, :]
204
+ H1,W1 = self.img_size//16,self.img_size//16
205
+ x = x.reshape(B,H1,W1,-1).permute(0, 3, 1, 2).contiguous()
206
+ for ind,blk in enumerate(self.stage_4):
207
+ if ind==0:
208
+ x = blk(x,H1,W1,extra_tokens_2)
209
+ else:
210
+ x = blk(x,H1,W1)
211
+ cls_2 = x[:, :1, :]
212
+ cls_2 = self.norm_2(cls_2)
213
+ if not self.only_last_cls:
214
+ cls = torch.cat((cls_1,cls_2), dim=1)#B,2,C
215
+ cls = self.aggregate(cls).squeeze(dim=1)#B,C
216
+ cls = self.norm(cls)
217
+ else:
218
+ cls = cls_2.squeeze(dim=1)
219
+ return cls
220
+ def forward(self, x,meta=None):
221
+ if meta is not None:
222
+ if self.mask_type=='linear':
223
+ cur_mask_prob = self.mask_prob - self.cur_epoch/self.total_epoch
224
+ else:
225
+ cur_mask_prob = self.mask_prob
226
+ if cur_mask_prob != 0 and self.training:
227
+ mask = torch.ones_like(meta)
228
+ mask_index = torch.randperm(meta.size(0))[:int(meta.size(0)*cur_mask_prob)]
229
+ mask[mask_index] = 0
230
+ meta = mask * meta
231
+ x = self.forward_features(x,meta)
232
+ x = self.head(x)
233
+ return x
234
+
235
+ @register_model
236
+ def MetaFG_meta_0(pretrained=False, **kwargs):
237
+ model = MetaFG_Meta(conv_embed_dims = [64,96,192],attn_embed_dims=[384,768],
238
+ conv_depths = [2,2,3],attn_depths = [5,2],num_heads=8,mlp_ratio=4., **kwargs)
239
+ model.default_cfg = default_cfgs['MetaFG_0']
240
+ if pretrained:
241
+ load_pretrained(
242
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
243
+ return model
244
+ @register_model
245
+ def MetaFG_meta_1(pretrained=False, **kwargs):
246
+ model = MetaFG_Meta(conv_embed_dims = [64,96,192],attn_embed_dims=[384,768],
247
+ conv_depths = [2,2,6],attn_depths = [14,2],num_heads=8,mlp_ratio=4., **kwargs)
248
+ model.default_cfg = default_cfgs['MetaFG_1']
249
+ if pretrained:
250
+ load_pretrained(
251
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
252
+ return model
253
+ @register_model
254
+ def MetaFG_meta_2(pretrained=False, **kwargs):
255
+ model = MetaFG_Meta(conv_embed_dims = [128,128,256],attn_embed_dims=[512,1024],
256
+ conv_depths = [2,2,6],attn_depths = [14,2],num_heads=8,mlp_ratio=4., **kwargs)
257
+ model.default_cfg = default_cfgs['MetaFG_2']
258
+ if pretrained:
259
+ load_pretrained(
260
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
261
+ return model
262
+ if __name__ == "__main__":
263
+ x = torch.randn([2, 3, 224, 224])
264
+ meta = torch.randn([2,7])
265
+ model = MetaFG_Meta()
266
+ import ipdb;ipdb.set_trace()
267
+ output = model(x,meta)
268
+ print(output.shape)
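A usage sketch for the meta-aware variant, assuming the package imports above resolve: with the default meta_dims=[4, 3] the metadata tensor must carry sum(meta_dims) = 7 features per sample, and cur_epoch/total_epoch (normally set by the training loop) drive the linear mask schedule in forward():

import torch
from models.MetaFG_meta import MetaFG_Meta

model = MetaFG_Meta(num_classes=1000)        # defaults: add_meta=True, meta_dims=[4, 3]
model.cur_epoch, model.total_epoch = 0, 300  # illustrative values for the mask schedule
model.eval()
x = torch.randn(2, 3, 224, 224)
meta = torch.randn(2, 7)                     # any 7-dim metadata vector, split 4 + 3 internally
with torch.no_grad():
    out = model(x, meta)
print(out.shape)  # torch.Size([2, 1000])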
models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .build import build_model
models/build.py ADDED
@@ -0,0 +1,20 @@
1
+ from timm.models import create_model
2
+ from .MetaFG import *
3
+ from .MetaFG_meta import *
4
+ def build_model(config):
5
+ model_type = config.MODEL.TYPE
6
+ if model_type == 'MetaFG':
7
+ model = create_model(
8
+ config.MODEL.NAME,
9
+ pretrained=False,
10
+ num_classes=config.MODEL.NUM_CLASSES,
11
+ drop_path_rate=config.MODEL.DROP_PATH_RATE,
12
+ img_size=config.DATA.IMG_SIZE,
13
+ only_last_cls=config.MODEL.ONLY_LAST_CLS,
14
+ extra_token_num=config.MODEL.EXTRA_TOKEN_NUM,
15
+ meta_dims=config.MODEL.META_DIMS
16
+ )
17
+ else:
18
+ raise NotImplementedError(f"Unknown model: {model_type}")
19
+
20
+ return model
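build_model() only reads a handful of keys; below is a sketch of a compatible configuration, assuming a yacs CfgNode (the project calls config.defrost()/config.freeze() in utils.py, which matches yacs) with illustrative values:

from yacs.config import CfgNode as CN
from models import build_model

config = CN()
config.MODEL = CN()
config.MODEL.TYPE = 'MetaFG'
config.MODEL.NAME = 'MetaFG_meta_1'
config.MODEL.NUM_CLASSES = 10000
config.MODEL.DROP_PATH_RATE = 0.1
config.MODEL.ONLY_LAST_CLS = False
config.MODEL.EXTRA_TOKEN_NUM = 3
config.MODEL.META_DIMS = [4, 3]
config.DATA = CN()
config.DATA.IMG_SIZE = 224

model = build_model(config)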
models/meta_encoder.py ADDED
@@ -0,0 +1,21 @@
1
+ import torch.nn as nn
2
+ class ResNormLayer(nn.Module):
3
+ def __init__(self, linear_size,):
4
+ super(ResNormLayer, self).__init__()
5
+ self.l_size = linear_size
6
+ self.nonlin1 = nn.ReLU(inplace=True)
7
+ self.nonlin2 = nn.ReLU(inplace=True)
8
+ self.norm_fn1 = nn.LayerNorm(self.l_size)
9
+ self.norm_fn2 = nn.LayerNorm(self.l_size)
10
+ self.w1 = nn.Linear(self.l_size, self.l_size)
11
+ self.w2 = nn.Linear(self.l_size, self.l_size)
12
+
13
+ def forward(self, x):
14
+ y = self.w1(x)
15
+ y = self.nonlin1(y)
16
+ y = self.norm_fn1(y)
17
+ y = self.w2(y)
18
+ y = self.nonlin2(y)
19
+ y = self.norm_fn2(y)
20
+ out = x + y
21
+ return out
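ResNormLayer is a residual two-layer MLP (Linear → ReLU → LayerNorm, twice) added on top of its input; it preserves the feature width, as this small check illustrates:

import torch
layer = ResNormLayer(384)
y = layer(torch.randn(2, 1, 384))
assert y.shape == (2, 1, 384)  # same shape in, same shape out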
optimizer.py ADDED
@@ -0,0 +1,70 @@
1
+ from torch import optim as optim
2
+
3
+
4
+ def build_optimizer(config, model):
5
+ """
6
+ Build optimizer, set weight decay of normalization to 0 by default.
7
+ """
8
+ skip = {}
9
+ skip_keywords = {}
10
+ if hasattr(model, 'no_weight_decay'):
11
+ skip = model.no_weight_decay()
12
+ if hasattr(model, 'no_weight_decay_keywords'):
13
+ skip_keywords = model.no_weight_decay_keywords()
14
+ parameters = set_weight_decay(model, skip, skip_keywords,config.TRAIN.BASE_LR)
15
+
16
+ opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
17
+ optimizer = None
18
+ if opt_lower == 'sgd':
19
+ optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
20
+ lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
21
+ elif opt_lower == 'adamw':
22
+ optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
23
+ lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
24
+
25
+ return optimizer
26
+
27
+ # def set_weight_decay(model, skip_list=(), skip_keywords=(),lr=0.0):
28
+ # has_decay = []
29
+ # no_decay = []
30
+ # high_lr = []
31
+ # for name, param in model.named_parameters():
32
+ # if not param.requires_grad:
33
+ # continue # frozen weights
34
+ # if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
35
+ # check_keywords_in_name(name, skip_keywords):
36
+ # if 'meta' in name:
37
+ # high_lr.append(param)
38
+ # else:
39
+ # no_decay.append(param)
40
+ # # print(f"{name} has no weight decay")
41
+ # else:
42
+ # has_decay.append(param)
43
+ # return [{'params': has_decay},
44
+ # # {'params':high_lr,'weight_decay': 0.,'lr':lr*10},
45
+ # {'params':high_lr,'lr':lr*20},
46
+ # {'params': no_decay, 'weight_decay': 0.}]
47
+
48
+ def set_weight_decay(model, skip_list=(), skip_keywords=(),lr=0.0):
49
+ has_decay = []
50
+ no_decay = []
51
+
52
+ for name, param in model.named_parameters():
53
+ if not param.requires_grad:
54
+ continue # frozen weights
55
+ if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
56
+ check_keywords_in_name(name, skip_keywords):
57
+ no_decay.append(param)
58
+ # print(f"{name} has no weight decay")
59
+ else:
60
+ has_decay.append(param)
61
+ return [{'params': has_decay},
62
+ {'params': no_decay, 'weight_decay': 0.}]
63
+
64
+
65
+ def check_keywords_in_name(name, keywords=()):
66
+ isin = False
67
+ for keyword in keywords:
68
+ if keyword in name:
69
+ isin = True
70
+ return isin
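A sketch of how build_optimizer() is driven, assuming a yacs config with illustrative hyper-parameters (the real values come from the training config elsewhere in the repo); 1-D tensors, biases and the names returned by model.no_weight_decay() end up in the zero-weight-decay group:

import torch.nn as nn
from yacs.config import CfgNode as CN

cfg = CN()
cfg.TRAIN = CN()
cfg.TRAIN.BASE_LR = 5e-4
cfg.TRAIN.WEIGHT_DECAY = 0.05
cfg.TRAIN.OPTIMIZER = CN()
cfg.TRAIN.OPTIMIZER.NAME = 'adamw'
cfg.TRAIN.OPTIMIZER.EPS = 1e-8
cfg.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
cfg.TRAIN.OPTIMIZER.MOMENTUM = 0.9   # only used when NAME == 'sgd'

optimizer = build_optimizer(cfg, nn.Linear(10, 2))  # any nn.Module, e.g. one from build_model()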
utils.py ADDED
@@ -0,0 +1,173 @@
1
+ import os
2
+ import torch
3
+ import importlib
4
+ import torch.distributed as dist
5
+
6
+ try:
7
+ # noinspection PyUnresolvedReferences
8
+ from apex import amp
9
+ except ImportError:
10
+ amp = None
11
+
12
+ def relative_bias_interpolate(checkpoint,config):
13
+ for k in list(checkpoint['model']):
14
+ if 'relative_position_index' in k:
15
+ del checkpoint['model'][k]
16
+ if 'relative_position_bias_table' in k:
17
+ relative_position_bias_table = checkpoint['model'][k]
18
+ cls_bias = relative_position_bias_table[:1,:]
19
+ relative_position_bias_table = relative_position_bias_table[1:,:]
20
+ size = int(relative_position_bias_table.shape[0]**0.5)
21
+ img_size = (size+1)//2
22
+ if 'stage_3' in k:
23
+ downsample_ratio = 16
24
+ elif 'stage_4' in k:
25
+ downsample_ratio = 32
26
+ new_img_size = config.DATA.IMG_SIZE//downsample_ratio
27
+ new_size = 2*new_img_size-1
28
+ if new_size == size:
29
+ continue
30
+ relative_position_bias_table = relative_position_bias_table.reshape(size,size,-1)
31
+ relative_position_bias_table = relative_position_bias_table.unsqueeze(0).permute(0,3,1,2)#bs,nhead,h,w
32
+ relative_position_bias_table = torch.nn.functional.interpolate(
33
+ relative_position_bias_table, size=(new_size, new_size), mode='bicubic', align_corners=False)
34
+ relative_position_bias_table = relative_position_bias_table.permute(0,2,3,1)
35
+ relative_position_bias_table = relative_position_bias_table.squeeze(0).reshape(new_size*new_size,-1)
36
+ relative_position_bias_table = torch.cat((cls_bias,relative_position_bias_table),dim=0)
37
+ checkpoint['model'][k] = relative_position_bias_table
38
+ return checkpoint
39
+
40
+
41
+ def load_pretained(config,model,logger=None,strict=False):
42
+ if logger is not None:
43
+ logger.info(f"==============> pretrain form {config.MODEL.PRETRAINED}....................")
44
+ checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
45
+ if 'model' not in checkpoint:
46
+ if 'state_dict_ema' in checkpoint:
47
+ checkpoint['model'] = checkpoint['state_dict_ema']
48
+ else:
49
+ checkpoint['model'] = checkpoint
50
+ if config.MODEL.DORP_HEAD:
51
+ if 'head.weight' in checkpoint['model'] and 'head.bias' in checkpoint['model']:
52
+ if logger is not None:
53
+ logger.info(f"==============> drop head....................")
54
+ del checkpoint['model']['head.weight']
55
+ del checkpoint['model']['head.bias']
56
+ if 'head.fc.weight' in checkpoint['model'] and 'head.fc.bias' in checkpoint['model']:
57
+ if logger is not None:
58
+ logger.info(f"==============> drop head....................")
59
+ del checkpoint['model']['head.fc.weight']
60
+ del checkpoint['model']['head.fc.bias']
61
+ if config.MODEL.DORP_META:
62
+ if logger is not None:
63
+ logger.info(f"==============> drop meta head....................")
64
+ for k in list(checkpoint['model']):
65
+ if 'meta' in k:
66
+ del checkpoint['model'][k]
67
+
68
+ checkpoint = relative_bias_interpolate(checkpoint,config)
69
+ if 'point_coord' in checkpoint['model']:
70
+ if logger is not None:
71
+ logger.info(f"==============> drop point coord....................")
72
+ del checkpoint['model']['point_coord']
73
+ msg = model.load_state_dict(checkpoint['model'], strict=strict)
74
+ del checkpoint
75
+ torch.cuda.empty_cache()
76
+
77
+
78
+ def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
79
+ logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
80
+ if config.MODEL.RESUME.startswith('https'):
81
+ checkpoint = torch.hub.load_state_dict_from_url(
82
+ config.MODEL.RESUME, map_location='cpu', check_hash=True)
83
+ else:
84
+ checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
85
+ if 'model' not in checkpoint:
86
+ if 'state_dict_ema' in checkpoint:
87
+ checkpoint['model'] = checkpoint['state_dict_ema']
88
+ else:
89
+ checkpoint['model'] = checkpoint
90
+ msg = model.load_state_dict(checkpoint['model'], strict=False)
91
+ logger.info(msg)
92
+ max_accuracy = 0.0
93
+ if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
94
+ optimizer.load_state_dict(checkpoint['optimizer'])
95
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
96
+ config.defrost()
97
+ config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
98
+ config.freeze()
99
+ if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
100
+ amp.load_state_dict(checkpoint['amp'])
101
+ logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
102
+ if 'max_accuracy' in checkpoint:
103
+ max_accuracy = checkpoint['max_accuracy']
104
+
105
+ del checkpoint
106
+ torch.cuda.empty_cache()
107
+ return max_accuracy
108
+
109
+
110
+ def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger):
111
+ save_state = {'model': model.state_dict(),
112
+ 'optimizer': optimizer.state_dict(),
113
+ 'lr_scheduler': lr_scheduler.state_dict(),
114
+ 'max_accuracy': max_accuracy,
115
+ 'epoch': epoch,
116
+ 'config': config}
117
+ if config.AMP_OPT_LEVEL != "O0":
118
+ save_state['amp'] = amp.state_dict()
119
+
120
+ save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
121
+ logger.info(f"{save_path} saving......")
122
+ torch.save(save_state, save_path)
123
+ logger.info(f"{save_path} saved !!!")
124
+
125
+
126
+ latest_save_path = os.path.join(config.OUTPUT, 'latest.pth')
127
+ logger.info(f"{latest_save_path} saving......")
128
+ torch.save(save_state, latest_save_path)
129
+ logger.info(f"{latest_save_path} saved !!!")
130
+
131
+
132
+
133
+ def get_grad_norm(parameters, norm_type=2):
134
+ if isinstance(parameters, torch.Tensor):
135
+ parameters = [parameters]
136
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
137
+ norm_type = float(norm_type)
138
+ total_norm = 0
139
+ for p in parameters:
140
+ param_norm = p.grad.data.norm(norm_type)
141
+ total_norm += param_norm.item() ** norm_type
142
+ total_norm = total_norm ** (1. / norm_type)
143
+ return total_norm
144
+
145
+
146
+ def auto_resume_helper(output_dir):
147
+ checkpoints = os.listdir(output_dir)
148
+ checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
149
+ print(f"All checkpoints found in {output_dir}: {checkpoints}")
150
+ if len(checkpoints) > 0:
151
+ latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
152
+ print(f"The latest checkpoint found: {latest_checkpoint}")
153
+ resume_file = latest_checkpoint
154
+ else:
155
+ resume_file = None
156
+ return resume_file
157
+
158
+
159
+ def reduce_tensor(tensor):
160
+ rt = tensor.clone()
161
+ dist.all_reduce(rt, op=dist.ReduceOp.SUM)
162
+ rt /= dist.get_world_size()
163
+ return rt
164
+
165
+
166
+
167
+
168
+ def load_ext(name, funcs):
169
+ ext = importlib.import_module(name)
170
+ for fun in funcs:
171
+ assert hasattr(ext, fun), f'{fun} miss in module {name}'
172
+ return ext
173
+
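A quick, self-contained check of get_grad_norm(), which returns the global L2 norm of all parameter gradients; reduce_tensor() is not exercised here because it needs an initialised torch.distributed process group:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
net(torch.randn(8, 4)).sum().backward()
print(get_grad_norm(net.parameters()))  # scalar: global L2 norm of the gradients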