Commit 1ab1a09
1 Parent(s): 7a15352
Added model *.pdparams

This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +1 -0
- LICENSE +21 -0
- README.md +19 -11
- bg_replace.py +146 -0
- configs/ppmatting/README.md +20 -0
- configs/ppmatting/ppmatting-hrnet_w18-human_1024.yml +29 -0
- configs/ppmatting/ppmatting-hrnet_w18-human_512.yml +44 -0
- configs/ppmatting/ppmatting-hrnet_w48-composition.yml +7 -0
- configs/ppmatting/ppmatting-hrnet_w48-distinctions.yml +55 -0
- image/readme.md +1 -0
- models/ppmatting-hrnet_w18-human_1024.pdparams +3 -0
- models/readme.md +1 -0
- output/readme.md +1 -0
- paddleseg/__init__.py +17 -0
- paddleseg/core/__init__.py +20 -0
- paddleseg/core/infer.py +232 -0
- paddleseg/core/predict.py +147 -0
- paddleseg/core/train.py +334 -0
- paddleseg/core/val.py +237 -0
- paddleseg/cvlibs/__init__.py +17 -0
- paddleseg/cvlibs/callbacks.py +279 -0
- paddleseg/cvlibs/config.py +445 -0
- paddleseg/cvlibs/manager.py +147 -0
- paddleseg/cvlibs/param_init.py +146 -0
- paddleseg/datasets/__init__.py +30 -0
- paddleseg/datasets/ade.py +119 -0
- paddleseg/datasets/chase_db1.py +98 -0
- paddleseg/datasets/cityscapes.py +88 -0
- paddleseg/datasets/cocostuff.py +83 -0
- paddleseg/datasets/dataset.py +163 -0
- paddleseg/datasets/drive.py +96 -0
- paddleseg/datasets/eg1800.py +137 -0
- paddleseg/datasets/hrf.py +95 -0
- paddleseg/datasets/mini_deep_globe_road_extraction.py +95 -0
- paddleseg/datasets/optic_disc_seg.py +97 -0
- paddleseg/datasets/pascal_context.py +86 -0
- paddleseg/datasets/pp_humanseg14k.py +82 -0
- paddleseg/datasets/pssl.py +135 -0
- paddleseg/datasets/stare.py +95 -0
- paddleseg/datasets/supervisely.py +136 -0
- paddleseg/datasets/voc.py +112 -0
- paddleseg/models/ann.py +434 -0
- paddleseg/models/attention_unet.py +189 -0
- paddleseg/models/backbones/__init__.py +26 -0
- paddleseg/models/backbones/ghostnet.py +318 -0
- paddleseg/models/backbones/hrnet.py +837 -0
- paddleseg/models/backbones/lite_hrnet.py +972 -0
- paddleseg/models/backbones/mix_transformer.py +593 -0
- paddleseg/models/backbones/mobilenetv2.py +264 -0
- paddleseg/models/backbones/mobilenetv3.py +496 -0
.gitattributes
CHANGED
@@ -29,3 +29,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pdparams filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Anagha S Menon

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,20 @@
----
-title: Pipeline Paddle
-emoji: 😻
-colorFrom: pink
-colorTo: pink
-sdk: gradio
-sdk_version: 3.1.1
-app_file: app.py
-pinned: false
----
+# pipeline_paddle_viton
 
-
+To run
+
+<br> Step 1: git clone https://github.com/ANAGHA-20/pipeline-paddle
+<br> Step 2: %cd ./pipeline_paddle_viton
+<br> Step 3: pip install -r requirements.txt
+<br> pip install paddlepaddle-gpu
+<br> pip install pymatting
+<br> import os
+<br> Step 4: export CUDA_VISIBLE_DEVICES=0
+<br> Step 5: wget "https://paddleseg.bj.bcebos.com/matting/models/ppmatting-hrnet_w18-human_1024.pdparams" -O "/content/pipeline_paddle_viton/models/ppmatting-hrnet_w18-human_1024.pdparams"
+<br> Step 6: run
+<br>
+!python bg_replace.py \
+    --config configs/ppmatting/ppmatting-hrnet_w18-human_1024.yml \
+    --model_path models/ppmatting-hrnet_w18-human_1024.pdparams \
+    --image_path ./image/person.jpg \
+    --save_dir ./output \
+    --fg_estimate True
bg_replace.py
ADDED
@@ -0,0 +1,146 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import sys

import cv2
import numpy as np
import paddle
from paddleseg.cvlibs import manager, Config
from paddleseg.utils import get_sys_env, logger

LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(LOCAL_PATH, '..'))

manager.BACKBONES._components_dict.clear()
manager.TRANSFORMS._components_dict.clear()

import ppmatting
from ppmatting.core import predict
from ppmatting.utils import get_image_list, estimate_foreground_ml


def parse_args():
    parser = argparse.ArgumentParser(
        description='PP-HumanSeg inference for video')
    parser.add_argument(
        "--config",
        dest="cfg",
        help="The config file.",
        default=None,
        type=str,
        required=True)
    parser.add_argument(
        '--model_path',
        dest='model_path',
        help='The path of the model for prediction',
        type=str,
        default=None)
    parser.add_argument(
        '--image_path',
        dest='image_path',
        help='Image including a human',
        type=str,
        default=None)
    parser.add_argument(
        '--trimap_path',
        dest='trimap_path',
        help='The path of the trimap',
        type=str,
        default=None)
    parser.add_argument(
        '--background',
        dest='background',
        help='Background for replacing. It is a string which specifies the background color (r,g,b,w) or a path to a background image. If not specified, a green background is used.',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the inference results',
        type=str,
        default='./output')
    parser.add_argument(
        '--fg_estimate',
        default=True,
        type=eval,
        choices=[True, False],
        help='Whether to estimate foreground when predicting.')

    return parser.parse_args()


def main(args):
    env_info = get_sys_env()
    place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
        'GPUs used'] else 'cpu'
    paddle.set_device(place)
    if not args.cfg:
        raise RuntimeError('No configuration file specified.')

    cfg = Config(args.cfg)

    msg = '\n---------------Config Information---------------\n'
    msg += str(cfg)
    msg += '------------------------------------------------'
    logger.info(msg)

    model = cfg.model
    transforms = ppmatting.transforms.Compose(cfg.val_transforms)

    alpha, fg = predict(
        model,
        model_path=args.model_path,
        transforms=transforms,
        image_list=[args.image_path],
        trimap_list=[args.trimap_path],
        save_dir=args.save_dir,
        fg_estimate=args.fg_estimate)

    img_ori = cv2.imread(args.image_path)
    bg = get_bg(args.background, img_ori.shape)
    alpha = alpha / 255.0
    alpha = alpha[:, :, np.newaxis]
    com = alpha * fg + (1 - alpha) * bg
    com = com.astype('uint8')
    com_save_path = os.path.join(args.save_dir,
                                 os.path.basename(args.image_path))
    cv2.imwrite(com_save_path, com)


def get_bg(background, img_shape):
    bg = np.zeros(img_shape)
    if background == 'r':
        bg[:, :, 2] = 255
    elif background is None or background == 'g':
        bg[:, :, 1] = 255
    elif background == 'b':
        bg[:, :, 0] = 255
    elif background == 'w':
        bg[:, :, :] = 255
    elif not os.path.exists(background):
        raise Exception('The --background does not exist: {}'.format(
            background))
    else:
        bg = cv2.imread(background)
        bg = cv2.resize(bg, (img_shape[1], img_shape[0]))
    return bg


if __name__ == "__main__":
    args = parse_args()
    main(args)
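For reference, the heart of bg_replace.py is the per-pixel composite com = alpha * fg + (1 - alpha) * bg. The snippet below is a minimal, self-contained sketch of that blend with synthetic arrays; the shapes and the default green background follow main() and get_bg() above, while the random matte and the constant foreground color are stand-ins for the values the script actually computes.

import numpy as np

# Synthetic stand-ins for the script's outputs (shapes mirror main() above).
h, w = 4, 4
alpha = np.random.rand(h, w)[:, :, np.newaxis]  # matte in [0, 1], shape HxWx1
fg = np.full((h, w, 3), 200.0)                  # fake foreground colors, HxWx3
bg = np.zeros((h, w, 3))
bg[:, :, 1] = 255                               # default green background (BGR, as in get_bg)

com = alpha * fg + (1 - alpha) * bg             # the same blend bg_replace.py uses
com = com.astype('uint8')
print(com.shape)                                # (4, 4, 3)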
configs/ppmatting/README.md
ADDED
@@ -0,0 +1,20 @@
# PP-Matting: High-Accuracy Natural Image Matting

## Reference

> Chen G, Liu Y, Wang J, et al. PP-Matting: High-Accuracy Natural Image Matting[J]. arXiv preprint arXiv:2204.09433, 2022.

## Performance

### Composition-1k

| Model | Backbone | Resolution | Training Iters | SAD $\downarrow$ | MSE $\downarrow$ | Grad $\downarrow$ | Conn $\downarrow$ | Links |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|PP-Matting|HRNet_W48|512x512|300000|46.22|0.005|22.69|45.40|[model](https://paddleseg.bj.bcebos.com/matting/models/ppmatting-hrnet_w48-composition.pdparams)|

### Distinctions-646

| Model | Backbone | Resolution | Training Iters | SAD $\downarrow$ | MSE $\downarrow$ | Grad $\downarrow$ | Conn $\downarrow$ | Links |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|PP-Matting|HRNet_W48|512x512|300000|40.69|0.009|43.91|40.56|[model](https://paddleseg.bj.bcebos.com/matting/models/ppmatting-hrnet_w48-distinctions.pdparams)|
configs/ppmatting/ppmatting-hrnet_w18-human_1024.yml
ADDED
@@ -0,0 +1,29 @@
_base_: 'ppmatting-hrnet_w18-human_512.yml'


train_dataset:
  transforms:
    - type: LoadImages
    - type: LimitShort
      max_short: 1024
    - type: RandomCrop
      crop_size: [1024, 1024]
    - type: RandomDistort
    - type: RandomBlur
      prob: 0.1
    - type: RandomNoise
      prob: 0.5
    - type: RandomReJpeg
      prob: 0.2
    - type: RandomHorizontalFlip
    - type: Normalize

val_dataset:
  transforms:
    - type: LoadImages
    - type: LimitShort
      max_short: 1024
    - type: ResizeToIntMult
      mult_int: 32
    - type: Normalize
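These YAML files are consumed through paddleseg.cvlibs.Config, the same way bg_replace.py loads them above. A minimal sketch, assuming the config machinery resolves the _base_ chain (1024 -> 512 -> w48-distinctions) before building objects:

from paddleseg.cvlibs import Config

# Hypothetical use from the repo root; the path matches the README's run command.
cfg = Config('configs/ppmatting/ppmatting-hrnet_w18-human_1024.yml')
model = cfg.model                    # PPMatting with an HRNet_W18 backbone per the merged YAML
val_transforms = cfg.val_transforms  # LoadImages -> LimitShort -> ResizeToIntMult -> Normalize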
configs/ppmatting/ppmatting-hrnet_w18-human_512.yml
ADDED
@@ -0,0 +1,44 @@
_base_: 'ppmatting-hrnet_w48-distinctions.yml'

batch_size: 4
iters: 200000

train_dataset:
  type: MattingDataset
  dataset_root: data/PPM-100
  train_file: train.txt
  transforms:
    - type: LoadImages
    - type: LimitShort
      max_short: 512
    - type: RandomCrop
      crop_size: [512, 512]
    - type: RandomDistort
    - type: RandomBlur
      prob: 0.1
    - type: RandomNoise
      prob: 0.5
    - type: RandomReJpeg
      prob: 0.2
    - type: RandomHorizontalFlip
    - type: Normalize
  mode: train

val_dataset:
  type: MattingDataset
  dataset_root: data/PPM-100
  val_file: val.txt
  transforms:
    - type: LoadImages
    - type: LimitShort
      max_short: 512
    - type: ResizeToIntMult
      mult_int: 32
    - type: Normalize
  mode: val
  get_trimap: False

model:
  backbone:
    type: HRNet_W18
    pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
configs/ppmatting/ppmatting-hrnet_w48-composition.yml
ADDED
@@ -0,0 +1,7 @@
_base_: 'ppmatting-hrnet_w48-distinctions.yml'

train_dataset:
  dataset_root: data/matting/Composition-1k

val_dataset:
  dataset_root: data/matting/Composition-1k
configs/ppmatting/ppmatting-hrnet_w48-distinctions.yml
ADDED
@@ -0,0 +1,55 @@
batch_size: 4
iters: 300000

train_dataset:
  type: MattingDataset
  dataset_root: data/matting/Distinctions-646
  train_file: train.txt
  transforms:
    - type: LoadImages
    - type: Padding
      target_size: [512, 512]
    - type: RandomCrop
      crop_size: [[512, 512], [640, 640], [800, 800]]
    - type: Resize
      target_size: [512, 512]
    - type: RandomDistort
    - type: RandomBlur
      prob: 0.1
    - type: RandomHorizontalFlip
    - type: Normalize
  mode: train
  separator: '|'

val_dataset:
  type: MattingDataset
  dataset_root: data/matting/Distinctions-646
  val_file: val.txt
  transforms:
    - type: LoadImages
    - type: LimitShort
      max_short: 1536
    - type: ResizeToIntMult
      mult_int: 32
    - type: Normalize
  mode: val
  get_trimap: False
  separator: '|'

model:
  type: PPMatting
  backbone:
    type: HRNet_W48
    pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w48_ssld.tar.gz
  pretrained: Null

optimizer:
  type: sgd
  momentum: 0.9
  weight_decay: 4.0e-5

lr_scheduler:
  type: PolynomialDecay
  learning_rate: 0.01
  end_lr: 0
  power: 0.9
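The lr_scheduler block above decays the learning rate from 0.01 toward 0 over training. A pure-Python sketch of that schedule, assuming the standard polynomial-decay formula behind paddle.optimizer.lr.PolynomialDecay and a decay horizon equal to the configured 300000 iters:

def poly_decay_lr(step, base_lr=0.01, end_lr=0.0, power=0.9, decay_steps=300000):
    # lr = (base_lr - end_lr) * (1 - step/decay_steps)^power + end_lr
    frac = min(step, decay_steps) / decay_steps
    return (base_lr - end_lr) * (1 - frac) ** power + end_lr

print(poly_decay_lr(0))       # 0.01
print(poly_decay_lr(150000))  # ~0.0054 at the halfway point
print(poly_decay_lr(300000))  # 0.0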
image/readme.md
ADDED
@@ -0,0 +1 @@
Upload an image of a person as person.jpg
models/ppmatting-hrnet_w18-human_1024.pdparams
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:65315fb68255266cb9adedc6879ba1e28ed7b84e5d02c0dc7ad8caace8370011
size 98439023
models/readme.md
ADDED
@@ -0,0 +1 @@
The required models are placed here.
output/readme.md
ADDED
@@ -0,0 +1 @@
The output of PP-Matting inference will be available here.
paddleseg/__init__.py
ADDED
@@ -0,0 +1,17 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import models, datasets, transforms

__version__ = '2.6.0'
paddleseg/core/__init__.py
ADDED
@@ -0,0 +1,20 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .train import train
from .val import evaluate
from .predict import predict
from . import infer

__all__ = ['train', 'evaluate', 'predict']
paddleseg/core/infer.py
ADDED
@@ -0,0 +1,232 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections.abc
from itertools import combinations

import numpy as np
import cv2
import paddle
import paddle.nn.functional as F


def reverse_transform(pred, trans_info, mode='nearest'):
    """Recover pred to the origin shape"""
    intTypeList = [paddle.int8, paddle.int16, paddle.int32, paddle.int64]
    dtype = pred.dtype
    for item in trans_info[::-1]:
        if isinstance(item[0], list):
            trans_mode = item[0][0]
        else:
            trans_mode = item[0]
        if trans_mode == 'resize':
            h, w = item[1][0], item[1][1]
            if paddle.get_device() == 'cpu' and dtype in intTypeList:
                pred = paddle.cast(pred, 'float32')
                pred = F.interpolate(pred, (h, w), mode=mode)
                pred = paddle.cast(pred, dtype)
            else:
                pred = F.interpolate(pred, (h, w), mode=mode)
        elif trans_mode == 'padding':
            h, w = item[1][0], item[1][1]
            pred = pred[:, :, 0:h, 0:w]
        else:
            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
    return pred


def flip_combination(flip_horizontal=False, flip_vertical=False):
    """
    Get flip combination.

    Args:
        flip_horizontal (bool): Whether to flip horizontally. Default: False.
        flip_vertical (bool): Whether to flip vertically. Default: False.

    Returns:
        list: List of tuples. The first element of each tuple is whether to flip horizontally,
            and the second is whether to flip vertically.
    """

    flip_comb = [(False, False)]
    if flip_horizontal:
        flip_comb.append((True, False))
    if flip_vertical:
        flip_comb.append((False, True))
        if flip_horizontal:
            flip_comb.append((True, True))
    return flip_comb


def tensor_flip(x, flip):
    """Flip tensor according to directions"""
    if flip[0]:
        x = x[:, :, :, ::-1]
    if flip[1]:
        x = x[:, :, ::-1, :]
    return x


def slide_inference(model, im, crop_size, stride):
    """
    Infer by sliding window.

    Args:
        model (paddle.nn.Layer): model to get logits of image.
        im (Tensor): the input image.
        crop_size (tuple|list): The size of the sliding window, (w, h).
        stride (tuple|list): The size of the stride, (w, h).

    Return:
        Tensor: The logit of the input image.
    """
    h_im, w_im = im.shape[-2:]
    w_crop, h_crop = crop_size
    w_stride, h_stride = stride
    # calculate the crop nums
    rows = int(np.ceil(1.0 * (h_im - h_crop) / h_stride)) + 1
    cols = int(np.ceil(1.0 * (w_im - w_crop) / w_stride)) + 1
    # prevent negative sliding rounds when imgs after scaling << crop_size
    rows = 1 if h_im <= h_crop else rows
    cols = 1 if w_im <= w_crop else cols
    # TODO 'Tensor' object does not support item assignment. If supported, use tensors for the calculation.
    final_logit = None
    count = np.zeros([1, 1, h_im, w_im])
    for r in range(rows):
        for c in range(cols):
            h1 = r * h_stride
            w1 = c * w_stride
            h2 = min(h1 + h_crop, h_im)
            w2 = min(w1 + w_crop, w_im)
            h1 = max(h2 - h_crop, 0)
            w1 = max(w2 - w_crop, 0)
            im_crop = im[:, :, h1:h2, w1:w2]
            logits = model(im_crop)
            if not isinstance(logits, collections.abc.Sequence):
                raise TypeError(
                    "The type of logits must be one of collections.abc.Sequence, e.g. list, tuple. But received {}"
                    .format(type(logits)))
            logit = logits[0].numpy()
            if final_logit is None:
                final_logit = np.zeros([1, logit.shape[1], h_im, w_im])
            final_logit[:, :, h1:h2, w1:w2] += logit[:, :, :h2 - h1, :w2 - w1]
            count[:, :, h1:h2, w1:w2] += 1
    if np.sum(count == 0) != 0:
        raise RuntimeError(
            'There are pixels not predicted. It is possible that stride is greater than crop_size'
        )
    final_logit = final_logit / count
    final_logit = paddle.to_tensor(final_logit)
    return final_logit


def inference(model,
              im,
              trans_info=None,
              is_slide=False,
              stride=None,
              crop_size=None):
    """
    Inference for image.

    Args:
        model (paddle.nn.Layer): model to get logits of image.
        im (Tensor): the input image.
        trans_info (list): Records of how the image shape changed during preprocessing. Default: None.
        is_slide (bool): Whether to infer by sliding window. Default: False.
        crop_size (tuple|list): The size of the sliding window, (w, h). It should be provided if is_slide is True.
        stride (tuple|list): The size of the stride, (w, h). It should be provided if is_slide is True.

    Returns:
        Tensor: If ori_shape is not None, a prediction with shape (1, 1, h, w) is returned.
            If ori_shape is None, a logit with shape (1, num_classes, h, w) is returned.
    """
    if hasattr(model, 'data_format') and model.data_format == 'NHWC':
        im = im.transpose((0, 2, 3, 1))
    if not is_slide:
        logits = model(im)
        if not isinstance(logits, collections.abc.Sequence):
            raise TypeError(
                "The type of logits must be one of collections.abc.Sequence, e.g. list, tuple. But received {}"
                .format(type(logits)))
        logit = logits[0]
    else:
        logit = slide_inference(model, im, crop_size=crop_size, stride=stride)
    if hasattr(model, 'data_format') and model.data_format == 'NHWC':
        logit = logit.transpose((0, 3, 1, 2))
    if trans_info is not None:
        logit = reverse_transform(logit, trans_info, mode='bilinear')
        pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
        return pred, logit
    else:
        return logit


def aug_inference(model,
                  im,
                  trans_info,
                  scales=1.0,
                  flip_horizontal=False,
                  flip_vertical=False,
                  is_slide=False,
                  stride=None,
                  crop_size=None):
    """
    Infer with augmentation.

    Args:
        model (paddle.nn.Layer): model to get logits of image.
        im (Tensor): the input image.
        trans_info (list): Transforms for image.
        scales (float|tuple|list): Scales for resize. Default: 1.
        flip_horizontal (bool): Whether to flip horizontally. Default: False.
        flip_vertical (bool): Whether to flip vertically. Default: False.
        is_slide (bool): Whether to infer by sliding window. Default: False.
        crop_size (tuple|list): The size of the sliding window, (w, h). It should be provided if is_slide is True.
        stride (tuple|list): The size of the stride, (w, h). It should be provided if is_slide is True.

    Returns:
        Tensor: Prediction of image with shape (1, 1, h, w) is returned.
    """
    if isinstance(scales, float):
        scales = [scales]
    elif not isinstance(scales, (tuple, list)):
        raise TypeError(
            '`scales` expects float/tuple/list type, but received {}'.format(
                type(scales)))
    final_logit = 0
    h_input, w_input = im.shape[-2], im.shape[-1]
    flip_comb = flip_combination(flip_horizontal, flip_vertical)
    for scale in scales:
        h = int(h_input * scale + 0.5)
        w = int(w_input * scale + 0.5)
        im = F.interpolate(im, (h, w), mode='bilinear')
        for flip in flip_comb:
            im_flip = tensor_flip(im, flip)
            logit = inference(
                model,
                im_flip,
                is_slide=is_slide,
                crop_size=crop_size,
                stride=stride)
            logit = tensor_flip(logit, flip)
            logit = F.interpolate(logit, (h_input, w_input), mode='bilinear')

            logit = F.softmax(logit, axis=1)
            final_logit = final_logit + logit

    final_logit = reverse_transform(final_logit, trans_info, mode='bilinear')
    pred = paddle.argmax(final_logit, axis=1, keepdim=True, dtype='int32')

    return pred, final_logit
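As a quick sanity check on the sliding-window arithmetic in slide_inference, the snippet below reproduces the rows/cols computation for hypothetical sizes (a 1024x2048 input, 512x512 window, 256x256 stride) and notes what flip_combination enumerates; all numbers are illustrative.

import numpy as np

# Grid size used by slide_inference: ceil((dim - crop) / stride) + 1 per axis.
h_im, w_im = 1024, 2048          # hypothetical input size
h_crop, w_crop = 512, 512        # sliding-window size
h_stride, w_stride = 256, 256    # sliding-window stride
rows = int(np.ceil(1.0 * (h_im - h_crop) / h_stride)) + 1
cols = int(np.ceil(1.0 * (w_im - w_crop) / w_stride)) + 1
print(rows, cols)  # 3 7: each window overlaps its neighbor by half

# flip_combination(True, True) would enumerate the four test-time flips:
# [(False, False), (True, False), (False, True), (True, True)]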
paddleseg/core/predict.py
ADDED
@@ -0,0 +1,147 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math

import cv2
import numpy as np
import paddle

from paddleseg import utils
from paddleseg.core import infer
from paddleseg.utils import logger, progbar, visualize


def mkdir(path):
    sub_dir = os.path.dirname(path)
    if not os.path.exists(sub_dir):
        os.makedirs(sub_dir)


def partition_list(arr, m):
    """Split the list 'arr' into m pieces"""
    n = int(math.ceil(len(arr) / float(m)))
    return [arr[i:i + n] for i in range(0, len(arr), n)]


def preprocess(im_path, transforms):
    data = {}
    data['img'] = im_path
    data = transforms(data)
    data['img'] = data['img'][np.newaxis, ...]
    data['img'] = paddle.to_tensor(data['img'])
    return data


def predict(model,
            model_path,
            transforms,
            image_list,
            image_dir=None,
            save_dir='output',
            aug_pred=False,
            scales=1.0,
            flip_horizontal=True,
            flip_vertical=False,
            is_slide=False,
            stride=None,
            crop_size=None,
            custom_color=None):
    """
    Predict and visualize the image_list.

    Args:
        model (nn.Layer): Used to predict for the input image.
        model_path (str): The path of the pretrained model.
        transforms (transform.Compose): Preprocessing for the input image.
        image_list (list): A list of image paths to be predicted.
        image_dir (str, optional): The root directory of the images to predict. Default: None.
        save_dir (str, optional): The directory to save the visualized results. Default: 'output'.
        aug_pred (bool, optional): Whether to use multi-scale and flip augmentation for prediction. Default: False.
        scales (list|float, optional): Scales for augmentation. It is valid when `aug_pred` is True. Default: 1.0.
        flip_horizontal (bool, optional): Whether to use horizontal-flip augmentation. It is valid when `aug_pred` is True. Default: True.
        flip_vertical (bool, optional): Whether to use vertical-flip augmentation. It is valid when `aug_pred` is True. Default: False.
        is_slide (bool, optional): Whether to predict by sliding window. Default: False.
        stride (tuple|list, optional): The stride of the sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        crop_size (tuple|list, optional): The crop size of the sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        custom_color (list, optional): Save images with a custom color map. Default: None, use paddleseg's default color map.

    """
    utils.utils.load_entire_model(model, model_path)
    model.eval()
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    if nranks > 1:
        img_lists = partition_list(image_list, nranks)
    else:
        img_lists = [image_list]

    added_saved_dir = os.path.join(save_dir, 'added_prediction')
    pred_saved_dir = os.path.join(save_dir, 'pseudo_color_prediction')

    logger.info("Start to predict...")
    progbar_pred = progbar.Progbar(target=len(img_lists[0]), verbose=1)
    color_map = visualize.get_color_map_list(256, custom_color=custom_color)
    with paddle.no_grad():
        for i, im_path in enumerate(img_lists[local_rank]):
            data = preprocess(im_path, transforms)

            if aug_pred:
                pred, _ = infer.aug_inference(
                    model,
                    data['img'],
                    trans_info=data['trans_info'],
                    scales=scales,
                    flip_horizontal=flip_horizontal,
                    flip_vertical=flip_vertical,
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)
            else:
                pred, _ = infer.inference(
                    model,
                    data['img'],
                    trans_info=data['trans_info'],
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)
            pred = paddle.squeeze(pred)
            pred = pred.numpy().astype('uint8')

            # get the saved name
            if image_dir is not None:
                im_file = im_path.replace(image_dir, '')
            else:
                im_file = os.path.basename(im_path)
            if im_file[0] == '/' or im_file[0] == '\\':
                im_file = im_file[1:]

            # save added image
            added_image = utils.visualize.visualize(
                im_path, pred, color_map, weight=0.6)
            added_image_path = os.path.join(added_saved_dir, im_file)
            mkdir(added_image_path)
            cv2.imwrite(added_image_path, added_image)

            # save pseudo color prediction
            pred_mask = utils.visualize.get_pseudo_color_map(pred, color_map)
            pred_saved_path = os.path.join(
                pred_saved_dir, os.path.splitext(im_file)[0] + ".png")
            mkdir(pred_saved_path)
            pred_mask.save(pred_saved_path)

            progbar_pred.update(i + 1)
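A minimal sketch of driving paddleseg.core.predict from a config, in the same spirit as bg_replace.py's use of the analogous ppmatting predict above; the config path, weights path, and image path are placeholders, not files this commit ships.

import paddleseg.transforms as T
from paddleseg.cvlibs import Config
from paddleseg.core import predict

cfg = Config('path/to/segmentation_config.yml')   # hypothetical config
transforms = T.Compose(cfg.val_transforms)        # reuse the config's val transforms
predict(
    cfg.model,
    model_path='path/to/model.pdparams',          # hypothetical weights
    transforms=transforms,
    image_list=['./image/person.jpg'],
    save_dir='./output')                          # writes added_prediction/ and pseudo_color_prediction/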
paddleseg/core/train.py
ADDED
@@ -0,0 +1,334 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
from collections import deque
import shutil

import paddle
import paddle.nn.functional as F

from paddleseg.utils import (TimeAverager, calculate_eta, resume, logger,
                             worker_init_fn, train_profiler, op_flops_funs)
from paddleseg.core.val import evaluate


def check_logits_losses(logits_list, losses):
    len_logits = len(logits_list)
    len_losses = len(losses['types'])
    if len_logits != len_losses:
        raise RuntimeError(
            'The length of logits_list should be equal to the types of loss config: {} != {}.'
            .format(len_logits, len_losses))


def loss_computation(logits_list, labels, edges, losses):
    check_logits_losses(logits_list, losses)
    loss_list = []
    for i in range(len(logits_list)):
        logits = logits_list[i]
        loss_i = losses['types'][i]
        coef_i = losses['coef'][i]
        if loss_i.__class__.__name__ in ('BCELoss', ) and loss_i.edge_label:
            # Use edges as labels according to loss type.
            loss_list.append(coef_i * loss_i(logits, edges))
        elif loss_i.__class__.__name__ == 'MixedLoss':
            mixed_loss_list = loss_i(logits, labels)
            for mixed_loss in mixed_loss_list:
                loss_list.append(coef_i * mixed_loss)
        elif loss_i.__class__.__name__ in ("KLLoss", ):
            loss_list.append(coef_i *
                             loss_i(logits_list[0], logits_list[1].detach()))
        else:
            loss_list.append(coef_i * loss_i(logits, labels))
    return loss_list


def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          num_workers=0,
          use_vdl=False,
          losses=None,
          keep_checkpoint_max=5,
          test_config=None,
          precision='fp32',
          amp_level='O1',
          profiler_options=None,
          to_static_training=False):
    """
    Launch training.

    Args:
        model(nn.Layer): A semantic segmentation model.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
        resume_model (str, optional): The path of the model to resume from.
        save_interval (int, optional): How many iters between model snapshots during training. Default: 1000.
        log_iters (int, optional): Display logging information every log_iters. Default: 10.
        num_workers (int, optional): Number of workers for the data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        losses (dict, optional): A dict including 'types' and 'coef'. The length of coef should equal 1 or len(losses['types']).
            The 'types' item is a list of paddleseg.models.losses objects while the 'coef' item is a list of the corresponding coefficients.
        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
        test_config(dict, optional): Evaluation config.
        precision (str, optional): Use AMP if precision='fp16'. If precision='fp32', training is normal.
        amp_level (str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represents mixed precision,
            where the input data type of each operator is cast according to the white_list and black_list; O2 represents pure fp16, where all operator
            parameters and input data are cast to fp16, except operators in the black_list that don't support fp16 kernels, and batchnorm. Default is O1 (amp).
        profiler_options (str, optional): The option of the train profiler.
        to_static_training (bool, optional): Whether to use @to_static for training.
    """
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir, exist_ok=True)

    # use amp
    if precision == 'fp16':
        logger.info('use AMP to train. AMP level = {}'.format(amp_level))
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
        if amp_level == 'O2':
            model, optimizer = paddle.amp.decorate(
                models=model,
                optimizers=optimizer,
                level='O2',
                save_dtype='float32')

    if nranks > 1:
        paddle.distributed.fleet.init(is_collective=True)
        optimizer = paddle.distributed.fleet.distributed_optimizer(
            optimizer)  # The return is a Fleet object
        ddp_model = paddle.distributed.fleet.distributed_model(model)

    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True,
        worker_init_fn=worker_init_fn, )

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    if to_static_training:
        model = paddle.jit.to_static(model)
        logger.info("Successfully applied @to_static")

    avg_loss = 0.0
    avg_loss_list = []
    iters_per_epoch = len(batch_sampler)
    best_mean_iou = -1.0
    best_model_iter = -1
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    save_models = deque()
    batch_start = time.time()

    iter = start_iter
    while iter < iters:
        for data in loader:
            iter += 1
            if iter > iters:
                version = paddle.__version__
                if version == '2.1.2':
                    continue
                else:
                    break
            reader_cost_averager.record(time.time() - batch_start)
            images = data['img']
            labels = data['label'].astype('int64')
            edges = None
            if 'edge' in data.keys():
                edges = data['edge'].astype('int64')
            if hasattr(model, 'data_format') and model.data_format == 'NHWC':
                images = images.transpose((0, 2, 3, 1))

            if precision == 'fp16':
                with paddle.amp.auto_cast(
                        level=amp_level,
                        enable=True,
                        custom_white_list={
                            "elementwise_add", "batch_norm", "sync_batch_norm"
                        },
                        custom_black_list={'bilinear_interp_v2'}):
                    logits_list = ddp_model(images) if nranks > 1 else model(
                        images)
                    loss_list = loss_computation(
                        logits_list=logits_list,
                        labels=labels,
                        edges=edges,
                        losses=losses)
                    loss = sum(loss_list)

                scaled = scaler.scale(loss)  # scale the loss
                scaled.backward()  # do backward
                if isinstance(optimizer, paddle.distributed.fleet.Fleet):
                    scaler.minimize(optimizer.user_defined_optimizer, scaled)
                else:
                    scaler.minimize(optimizer, scaled)  # update parameters
            else:
                logits_list = ddp_model(images) if nranks > 1 else model(images)
                loss_list = loss_computation(
                    logits_list=logits_list,
                    labels=labels,
                    edges=edges,
                    losses=losses)
                loss = sum(loss_list)
                loss.backward()
                # if the optimizer is ReduceOnPlateau, the loss is the one which has been passed into step.
                if isinstance(optimizer, paddle.optimizer.lr.ReduceOnPlateau):
                    optimizer.step(loss)
                else:
                    optimizer.step()

            lr = optimizer.get_lr()

            # update lr
            if isinstance(optimizer, paddle.distributed.fleet.Fleet):
                lr_sche = optimizer.user_defined_optimizer._learning_rate
            else:
                lr_sche = optimizer._learning_rate
            if isinstance(lr_sche, paddle.optimizer.lr.LRScheduler):
                lr_sche.step()

            train_profiler.add_profiler_step(profiler_options)

            model.clear_gradients()
            avg_loss += loss.numpy()[0]
            if not avg_loss_list:
                avg_loss_list = [l.numpy() for l in loss_list]
            else:
                for i in range(len(loss_list)):
                    avg_loss_list[i] += loss_list[i].numpy()
            batch_cost_averager.record(
                time.time() - batch_start, num_samples=batch_size)

            if (iter) % log_iters == 0 and local_rank == 0:
                avg_loss /= log_iters
                avg_loss_list = [l[0] / log_iters for l in avg_loss_list]
                remain_iters = iters - iter
                avg_train_batch_cost = batch_cost_averager.get_average()
                avg_train_reader_cost = reader_cost_averager.get_average()
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                logger.info(
                    "[TRAIN] epoch: {}, iter: {}/{}, loss: {:.4f}, lr: {:.6f}, batch_cost: {:.4f}, reader_cost: {:.5f}, ips: {:.4f} samples/sec | ETA {}"
                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
                            avg_loss, lr, avg_train_batch_cost,
                            avg_train_reader_cost,
                            batch_cost_averager.get_ips_average(), eta))
                if use_vdl:
                    log_writer.add_scalar('Train/loss', avg_loss, iter)
                    # Record all losses if there are more than 2 losses.
                    if len(avg_loss_list) > 1:
                        avg_loss_dict = {}
                        for i, value in enumerate(avg_loss_list):
                            avg_loss_dict['loss_' + str(i)] = value
                        for key, value in avg_loss_dict.items():
                            log_tag = 'Train/' + key
                            log_writer.add_scalar(log_tag, value, iter)

                    log_writer.add_scalar('Train/lr', lr, iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, iter)
                avg_loss = 0.0
                avg_loss_list = []
                reader_cost_averager.reset()
                batch_cost_averager.reset()

            if (iter % save_interval == 0 or
                    iter == iters) and (val_dataset is not None):
                num_workers = 1 if num_workers > 0 else 0

                if test_config is None:
                    test_config = {}

                mean_iou, acc, _, _, _ = evaluate(
                    model,
                    val_dataset,
                    num_workers=num_workers,
                    precision=precision,
                    amp_level=amp_level,
                    **test_config)

                model.train()

            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                save_models.append(current_save_dir)
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)

                if val_dataset is not None:
                    if mean_iou > best_mean_iou:
                        best_mean_iou = mean_iou
                        best_model_iter = iter
                        best_model_dir = os.path.join(save_dir, "best_model")
                        paddle.save(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model.pdparams'))
                    logger.info(
                        '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                        .format(best_mean_iou, best_model_iter))

                    if use_vdl:
                        log_writer.add_scalar('Evaluate/mIoU', mean_iou, iter)
                        log_writer.add_scalar('Evaluate/Acc', acc, iter)
            batch_start = time.time()

    # Calculate flops.
    if local_rank == 0 and not (precision == 'fp16' and amp_level == 'O2'):
        _, c, h, w = images.shape
        _ = paddle.flops(
            model, [1, c, h, w],
            custom_ops={paddle.nn.SyncBatchNorm: op_flops_funs.count_syncbn})

    # Sleep for half a second to let dataloader release resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()
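For context, loss_computation pairs the i-th logit the model returns with losses['types'][i], weighted by losses['coef'][i]. A minimal sketch of that contract, assuming paddleseg's standard CrossEntropyLoss; the coefficients are illustrative:

from paddleseg.models import losses

# One loss object per logit the model emits, plus a matching weight.
loss_cfg = {
    'types': [losses.CrossEntropyLoss(), losses.CrossEntropyLoss()],
    'coef': [1.0, 0.4],
}
# Inside train(), the scalar objective then becomes:
#   loss = sum(coef_i * loss_i(logit_i, labels))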
paddleseg/core/val.py
ADDED
@@ -0,0 +1,237 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
import time
import paddle
import paddle.nn.functional as F

from paddleseg.utils import metrics, TimeAverager, calculate_eta, logger, progbar
from paddleseg.core import infer

np.set_printoptions(suppress=True)


def evaluate(model,
             eval_dataset,
             aug_eval=False,
             scales=1.0,
             flip_horizontal=False,
             flip_vertical=False,
             is_slide=False,
             stride=None,
             crop_size=None,
             precision='fp32',
             amp_level='O1',
             num_workers=0,
             print_detail=True,
             auc_roc=False):
    """
    Launch evaluation.

    Args:
        model(nn.Layer): A semantic segmentation model.
        eval_dataset (paddle.io.Dataset): Used to read and process validation datasets.
        aug_eval (bool, optional): Whether to use multi-scale and flip augmentation for evaluation. Default: False.
        scales (list|float, optional): Scales for augmentation. It is valid when `aug_eval` is True. Default: 1.0.
        flip_horizontal (bool, optional): Whether to use horizontal-flip augmentation. It is valid when `aug_eval` is True. Default: True.
        flip_vertical (bool, optional): Whether to use vertical-flip augmentation. It is valid when `aug_eval` is True. Default: False.
        is_slide (bool, optional): Whether to evaluate by sliding window. Default: False.
        stride (tuple|list, optional): The stride of the sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        crop_size (tuple|list, optional): The crop size of the sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        precision (str, optional): Use AMP if precision='fp16'. If precision='fp32', the evaluation is normal.
        amp_level (str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represents mixed precision, where the input data type of each operator is cast according to the white_list and black_list; O2 represents pure fp16, where all operator parameters and input data are cast to fp16, except operators in the black_list that don't support fp16 kernels, and batchnorm. Default is O1 (amp).
        num_workers (int, optional): Number of workers for the data loader. Default: 0.
        print_detail (bool, optional): Whether to print detailed information about the evaluation process. Default: True.
        auc_roc (bool, optional): Whether to add the auc_roc metric.

    Returns:
        float: The mIoU of validation datasets.
        float: The accuracy of validation datasets.
    """
    model.eval()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank
    if nranks > 1:
        # Initialize parallel environment if not done.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()
    batch_sampler = paddle.io.DistributedBatchSampler(
        eval_dataset, batch_size=1, shuffle=False, drop_last=False)
    loader = paddle.io.DataLoader(
        eval_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True, )

    total_iters = len(loader)
    intersect_area_all = paddle.zeros([1], dtype='int64')
    pred_area_all = paddle.zeros([1], dtype='int64')
    label_area_all = paddle.zeros([1], dtype='int64')
    logits_all = None
    label_all = None

    if print_detail:
        logger.info("Start evaluating (total_samples: {}, total_iters: {})...".
                    format(len(eval_dataset), total_iters))
    #TODO(chenguowei): fix log print error with multi-gpus
    progbar_val = progbar.Progbar(
        target=total_iters, verbose=1 if nranks < 2 else 2)
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    batch_start = time.time()
    with paddle.no_grad():
        for iter, data in enumerate(loader):
            reader_cost_averager.record(time.time() - batch_start)
            label = data['label'].astype('int64')

            if aug_eval:
                if precision == 'fp16':
                    with paddle.amp.auto_cast(
                            level=amp_level,
                            enable=True,
                            custom_white_list={
                                "elementwise_add", "batch_norm",
                                "sync_batch_norm"
                            },
                            custom_black_list={'bilinear_interp_v2'}):
                        pred, logits = infer.aug_inference(
                            model,
                            data['img'],
                            trans_info=data['trans_info'],
                            scales=scales,
                            flip_horizontal=flip_horizontal,
                            flip_vertical=flip_vertical,
                            is_slide=is_slide,
                            stride=stride,
                            crop_size=crop_size)
                else:
                    pred, logits = infer.aug_inference(
                        model,
                        data['img'],
                        trans_info=data['trans_info'],
                        scales=scales,
                        flip_horizontal=flip_horizontal,
                        flip_vertical=flip_vertical,
                        is_slide=is_slide,
                        stride=stride,
                        crop_size=crop_size)
            else:
                if precision == 'fp16':
                    with paddle.amp.auto_cast(
                            level=amp_level,
                            enable=True,
                            custom_white_list={
                                "elementwise_add", "batch_norm",
                                "sync_batch_norm"
                            },
                            custom_black_list={'bilinear_interp_v2'}):
                        pred, logits = infer.inference(
                            model,
                            data['img'],
                            trans_info=data['trans_info'],
                            is_slide=is_slide,
                            stride=stride,
                            crop_size=crop_size)
                else:
                    pred, logits = infer.inference(
                        model,
                        data['img'],
                        trans_info=data['trans_info'],
                        is_slide=is_slide,
                        stride=stride,
                        crop_size=crop_size)

            intersect_area, pred_area, label_area = metrics.calculate_area(
                pred,
                label,
                eval_dataset.num_classes,
                ignore_index=eval_dataset.ignore_index)

            # Gather from all ranks
            if nranks > 1:
                intersect_area_list = []
                pred_area_list = []
                label_area_list = []
                paddle.distributed.all_gather(intersect_area_list,
                                              intersect_area)
                paddle.distributed.all_gather(pred_area_list, pred_area)
                paddle.distributed.all_gather(label_area_list, label_area)

                # Some images have been evaluated and should be eliminated in the last iter
                if (iter + 1) * nranks > len(eval_dataset):
                    valid = len(eval_dataset) - iter * nranks
                    intersect_area_list = intersect_area_list[:valid]
                    pred_area_list = pred_area_list[:valid]
                    label_area_list = label_area_list[:valid]

                for i in range(len(intersect_area_list)):
                    intersect_area_all = intersect_area_all + intersect_area_list[i]
+
i]
|
187 |
+
pred_area_all = pred_area_all + pred_area_list[i]
|
188 |
+
label_area_all = label_area_all + label_area_list[i]
|
189 |
+
else:
|
190 |
+
intersect_area_all = intersect_area_all + intersect_area
|
191 |
+
pred_area_all = pred_area_all + pred_area
|
192 |
+
label_area_all = label_area_all + label_area
|
193 |
+
|
194 |
+
if auc_roc:
|
195 |
+
logits = F.softmax(logits, axis=1)
|
196 |
+
if logits_all is None:
|
197 |
+
logits_all = logits.numpy()
|
198 |
+
label_all = label.numpy()
|
199 |
+
else:
|
200 |
+
logits_all = np.concatenate(
|
201 |
+
[logits_all, logits.numpy()]) # (KN, C, H, W)
|
202 |
+
label_all = np.concatenate([label_all, label.numpy()])
|
203 |
+
|
204 |
+
batch_cost_averager.record(
|
205 |
+
time.time() - batch_start, num_samples=len(label))
|
206 |
+
batch_cost = batch_cost_averager.get_average()
|
207 |
+
reader_cost = reader_cost_averager.get_average()
|
208 |
+
|
209 |
+
if local_rank == 0 and print_detail:
|
210 |
+
progbar_val.update(iter + 1, [('batch_cost', batch_cost),
|
211 |
+
('reader cost', reader_cost)])
|
212 |
+
reader_cost_averager.reset()
|
213 |
+
batch_cost_averager.reset()
|
214 |
+
batch_start = time.time()
|
215 |
+
|
216 |
+
metrics_input = (intersect_area_all, pred_area_all, label_area_all)
|
217 |
+
class_iou, miou = metrics.mean_iou(*metrics_input)
|
218 |
+
acc, class_precision, class_recall = metrics.class_measurement(
|
219 |
+
*metrics_input)
|
220 |
+
kappa = metrics.kappa(*metrics_input)
|
221 |
+
class_dice, mdice = metrics.dice(*metrics_input)
|
222 |
+
|
223 |
+
if auc_roc:
|
224 |
+
auc_roc = metrics.auc_roc(
|
225 |
+
logits_all, label_all, num_classes=eval_dataset.num_classes)
|
226 |
+
auc_infor = ' Auc_roc: {:.4f}'.format(auc_roc)
|
227 |
+
|
228 |
+
if print_detail:
|
229 |
+
infor = "[EVAL] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}".format(
|
230 |
+
len(eval_dataset), miou, acc, kappa, mdice)
|
231 |
+
infor = infor + auc_infor if auc_roc else infor
|
232 |
+
logger.info(infor)
|
233 |
+
logger.info("[EVAL] Class IoU: \n" + str(np.round(class_iou, 4)))
|
234 |
+
logger.info("[EVAL] Class Precision: \n" + str(
|
235 |
+
np.round(class_precision, 4)))
|
236 |
+
logger.info("[EVAL] Class Recall: \n" + str(np.round(class_recall, 4)))
|
237 |
+
return miou, acc, class_iou, class_precision, kappa
|
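The evaluation entry point above is normally driven from a config file; a minimal sketch (the config path is a placeholder, and pretrained weights are assumed to be loaded into the model already):

from paddleseg.cvlibs import Config
from paddleseg.core.val import evaluate

cfg = Config('path/to/config.yml')  # placeholder: any yaml with model/val_dataset sections
model = cfg.model
miou, acc, class_iou, class_precision, kappa = evaluate(model, cfg.val_dataset)
# Pass aug_eval=True with scales/flip_horizontal for multi-scale + flip test-time augmentation.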
paddleseg/cvlibs/__init__.py
ADDED
@@ -0,0 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import manager
+from . import param_init
+from .config import Config
paddleseg/cvlibs/callbacks.py
ADDED
@@ -0,0 +1,279 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+
+import numpy as np
+import paddle
+from paddle.distributed.parallel import ParallelEnv
+from visualdl import LogWriter
+from paddleseg.utils.progbar import Progbar
+import paddleseg.utils.logger as logger
+
+
+class CallbackList(object):
+    """
+    Container abstracting a list of callbacks.
+
+    Args:
+        callbacks (list[Callback]): List of `Callback` instances.
+    """
+
+    def __init__(self, callbacks=None):
+        callbacks = callbacks or []
+        self.callbacks = [c for c in callbacks]
+
+    def append(self, callback):
+        self.callbacks.append(callback)
+
+    def set_params(self, params):
+        for callback in self.callbacks:
+            callback.set_params(params)
+
+    def set_model(self, model):
+        for callback in self.callbacks:
+            callback.set_model(model)
+
+    def set_optimizer(self, optimizer):
+        for callback in self.callbacks:
+            callback.set_optimizer(optimizer)
+
+    def on_iter_begin(self, iter, logs=None):
+        """Called right before processing a batch."""
+        logs = logs or {}
+        for callback in self.callbacks:
+            callback.on_iter_begin(iter, logs)
+        self._t_enter_iter = time.time()
+
+    def on_iter_end(self, iter, logs=None):
+        """Called at the end of a batch."""
+        logs = logs or {}
+        for callback in self.callbacks:
+            callback.on_iter_end(iter, logs)
+        self._t_exit_iter = time.time()
+
+    def on_train_begin(self, logs=None):
+        """Called at the beginning of training."""
+        logs = logs or {}
+        for callback in self.callbacks:
+            callback.on_train_begin(logs)
+
+    def on_train_end(self, logs=None):
+        """Called at the end of training."""
+        logs = logs or {}
+        for callback in self.callbacks:
+            callback.on_train_end(logs)
+
+    def __iter__(self):
+        return iter(self.callbacks)
+
+
+class Callback(object):
+    """Abstract base class used to build new callbacks."""
+
+    def __init__(self):
+        self.validation_data = None
+
+    def set_params(self, params):
+        self.params = params
+
+    def set_model(self, model):
+        self.model = model
+
+    def set_optimizer(self, optimizer):
+        self.optimizer = optimizer
+
+    def on_iter_begin(self, iter, logs=None):
+        pass
+
+    def on_iter_end(self, iter, logs=None):
+        pass
+
+    def on_train_begin(self, logs=None):
+        pass
+
+    def on_train_end(self, logs=None):
+        pass
+
+
+class BaseLogger(Callback):
+    def __init__(self, period=10):
+        super(BaseLogger, self).__init__()
+        self.period = period
+
+    def _reset(self):
+        self.totals = {}
+
+    def on_train_begin(self, logs=None):
+        self.totals = {}
+
+    def on_iter_end(self, iter, logs=None):
+        logs = logs or {}
+        # (iter - 1) // iters_per_epoch + 1
+        for k, v in logs.items():
+            if k in self.totals.keys():
+                self.totals[k] += v
+            else:
+                self.totals[k] = v
+
+        if iter % self.period == 0 and ParallelEnv().local_rank == 0:
+            for k in self.totals:
+                logs[k] = self.totals[k] / self.period
+            self._reset()
+
+
+class TrainLogger(Callback):
+    def __init__(self, log_freq=10):
+        self.log_freq = log_freq
+
+    def _calculate_eta(self, remaining_iters, speed):
+        if remaining_iters < 0:
+            remaining_iters = 0
+        remaining_time = int(remaining_iters * speed)
+        result = "{:0>2}:{:0>2}:{:0>2}"
+        arr = []
+        for i in range(2, -1, -1):
+            arr.append(int(remaining_time / 60**i))
+            remaining_time %= 60**i
+        return result.format(*arr)
+
+    def on_iter_end(self, iter, logs=None):
+        if iter % self.log_freq == 0 and ParallelEnv().local_rank == 0:
+            total_iters = self.params["total_iters"]
+            iters_per_epoch = self.params["iters_per_epoch"]
+            remaining_iters = total_iters - iter
+            eta = self._calculate_eta(remaining_iters, logs["batch_cost"])
+            current_epoch = (iter - 1) // self.params["iters_per_epoch"] + 1
+            loss = logs["loss"]
+            lr = self.optimizer.get_lr()
+            batch_cost = logs["batch_cost"]
+            reader_cost = logs["reader_cost"]
+
+            logger.info(
+                "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
+                .format(current_epoch, iter, total_iters, loss, lr, batch_cost,
+                        reader_cost, eta))
+
+
+class ProgbarLogger(Callback):
+    def __init__(self):
+        super(ProgbarLogger, self).__init__()
+
+    def on_train_begin(self, logs=None):
+        self.verbose = self.params["verbose"]
+        self.total_iters = self.params["total_iters"]
+        self.target = self.params["total_iters"]
+        self.progbar = Progbar(target=self.target, verbose=self.verbose)
+        self.seen = 0
+        self.log_values = []
+
+    def on_iter_begin(self, iter, logs=None):
+        # self.seen = 0
+        if self.seen < self.target:
+            self.log_values = []
+
+    def on_iter_end(self, iter, logs=None):
+        logs = logs or {}
+        self.seen += 1
+        for k in self.params['metrics']:
+            if k in logs:
+                self.log_values.append((k, logs[k]))
+
+        # if self.verbose and self.seen < self.target and ParallelEnv.local_rank == 0:
+        #     print(self.log_values)
+        if self.seen < self.target:
+            self.progbar.update(self.seen, self.log_values)
+
+
+class ModelCheckpoint(Callback):
+    def __init__(self,
+                 save_dir,
+                 monitor="miou",
+                 save_best_only=False,
+                 save_params_only=True,
+                 mode="max",
+                 period=1):
+        super(ModelCheckpoint, self).__init__()
+        self.monitor = monitor
+        self.save_dir = save_dir
+        self.save_best_only = save_best_only
+        self.save_params_only = save_params_only
+        self.period = period
+        self.iters_since_last_save = 0
+
+        if mode == "min":
+            self.monitor_op = np.less
+            self.best = np.Inf
+        elif mode == "max":
+            self.monitor_op = np.greater
+            self.best = -np.Inf
+        else:
+            raise RuntimeError("`mode` is neither \"min\" nor \"max\"!")
+
+    def on_train_begin(self, logs=None):
+        self.verbose = self.params["verbose"]
+        save_dir = self.save_dir
+        if not os.path.isdir(save_dir):
+            if os.path.exists(save_dir):
+                os.remove(save_dir)
+            os.makedirs(save_dir)
+
+    def on_iter_end(self, iter, logs=None):
+        logs = logs or {}
+        self.iters_since_last_save += 1
+        current_save_dir = os.path.join(self.save_dir, "iter_{}".format(iter))
+        current_save_dir = os.path.abspath(current_save_dir)
+        # if self.iters_since_last_save % self.period and ParallelEnv().local_rank == 0:
+        #     self.iters_since_last_save = 0
+        if iter % self.period == 0 and ParallelEnv().local_rank == 0:
+            if self.verbose > 0:
+                print("iter {iter_num}: saving model to {path}".format(
+                    iter_num=iter, path=current_save_dir))
+
+            paddle.save(self.model.state_dict(),
+                        os.path.join(current_save_dir, 'model.pdparams'))
+
+            if not self.save_params_only:
+                paddle.save(self.optimizer.state_dict(),
+                            os.path.join(current_save_dir, 'model.pdopt'))
+
+
+class VisualDL(Callback):
+    def __init__(self, log_dir="./log", freq=1):
+        super(VisualDL, self).__init__()
+        self.log_dir = log_dir
+        self.freq = freq
+
+    def on_train_begin(self, logs=None):
+        self.writer = LogWriter(self.log_dir)
+
+    def on_iter_end(self, iter, logs=None):
+        logs = logs or {}
+        if iter % self.freq == 0 and ParallelEnv().local_rank == 0:
+            for k, v in logs.items():
+                self.writer.add_scalar("Train/{}".format(k), v, iter)
+
+            self.writer.flush()
+
+    def on_train_end(self, logs=None):
+        self.writer.close()
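A quick sketch of the callback protocol defined above, wired by hand with a stand-in model and a synthetic loop (real training populates the logs dict each iteration):

import paddle
from paddleseg.cvlibs.callbacks import CallbackList, TrainLogger, VisualDL

model = paddle.nn.Linear(4, 2)  # stand-in network
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

callbacks = CallbackList([TrainLogger(log_freq=10), VisualDL(log_dir="./log")])
callbacks.set_model(model)
callbacks.set_optimizer(optimizer)
callbacks.set_params({"total_iters": 100, "iters_per_epoch": 10, "verbose": 1})
callbacks.on_train_begin()
for step in range(1, 101):
    callbacks.on_iter_begin(step)
    # ... forward/backward/optimizer.step() would go here ...
    callbacks.on_iter_end(step, logs={"loss": 0.5, "batch_cost": 0.1, "reader_cost": 0.01})
callbacks.on_train_end()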
paddleseg/cvlibs/config.py
ADDED
@@ -0,0 +1,445 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import codecs
+import os
+from typing import Any, Dict, Generic
+import warnings
+
+import paddle
+import yaml
+
+from paddleseg.cvlibs import manager
+from paddleseg.utils import logger
+
+
+class Config(object):
+    '''
+    Training configuration parsing. Only yaml/yml files are supported.
+
+    The following hyper-parameters are available in the config file:
+        batch_size: The number of samples per gpu.
+        iters: The total training steps.
+        train_dataset: A training data config including type/data_root/transforms/mode.
+            For data type, please refer to paddleseg.datasets.
+            For specific transforms, please refer to paddleseg.transforms.transforms.
+        val_dataset: A validation data config including type/data_root/transforms/mode.
+        optimizer: An optimizer config, but currently PaddleSeg only supports sgd with momentum in the config file.
+            In addition, weight_decay could be set as a regularization.
+        learning_rate: A learning rate config. If decay is configured, the learning_rate value is the starting learning rate,
+            where only poly decay is supported using the config file. In addition, decay power and end_lr are tuned experimentally.
+        loss: A loss config. Multi-loss config is available. The loss type order is consistent with the seg model outputs,
+            where the coef term indicates the weight of the corresponding loss. Note that the number of coef must be the same
+            as the number of model outputs, and there could be only one loss type if the same loss type is used among the
+            outputs; otherwise the number of loss types must be consistent with coef.
+        model: A model config including type/backbone and model-dependent arguments.
+            For model type, please refer to paddleseg.models.
+            For backbone, please refer to paddleseg.models.backbones.
+
+    Args:
+        path (str): The path of the config file, supports yaml format only.
+
+    Examples:
+
+        from paddleseg.cvlibs.config import Config
+
+        # Create a cfg object with yaml file path.
+        cfg = Config(yaml_cfg_path)
+
+        # Parsing the argument when its property is used.
+        train_dataset = cfg.train_dataset
+
+        # the argument of model should be parsed after dataset,
+        # since the model builder uses some properties in dataset.
+        model = cfg.model
+        ...
+    '''
+
+    def __init__(self,
+                 path: str,
+                 learning_rate: float=None,
+                 batch_size: int=None,
+                 iters: int=None):
+        if not path:
+            raise ValueError('Please specify the configuration file path.')
+
+        if not os.path.exists(path):
+            raise FileNotFoundError('File {} does not exist'.format(path))
+
+        self._model = None
+        self._losses = None
+        if path.endswith('yml') or path.endswith('yaml'):
+            self.dic = self._parse_from_yaml(path)
+        else:
+            raise RuntimeError('Config file should be in yaml format!')
+
+        self.update(
+            learning_rate=learning_rate, batch_size=batch_size, iters=iters)
+
+    def _update_dic(self, dic, base_dic):
+        """
+        Update config from dic based on base_dic
+        """
+        base_dic = base_dic.copy()
+        dic = dic.copy()
+
+        if dic.get('_inherited_', True) == False:
+            dic.pop('_inherited_')
+            return dic
+
+        for key, val in dic.items():
+            if isinstance(val, dict) and key in base_dic:
+                base_dic[key] = self._update_dic(val, base_dic[key])
+            else:
+                base_dic[key] = val
+        dic = base_dic
+        return dic
+
+    def _parse_from_yaml(self, path: str):
+        '''Parse a yaml file and build config'''
+        with codecs.open(path, 'r', 'utf-8') as file:
+            dic = yaml.load(file, Loader=yaml.FullLoader)
+
+        if '_base_' in dic:
+            cfg_dir = os.path.dirname(path)
+            base_path = dic.pop('_base_')
+            base_path = os.path.join(cfg_dir, base_path)
+            base_dic = self._parse_from_yaml(base_path)
+            dic = self._update_dic(dic, base_dic)
+        return dic
+
+    def update(self,
+               learning_rate: float=None,
+               batch_size: int=None,
+               iters: int=None):
+        '''Update config'''
+        if learning_rate:
+            if 'lr_scheduler' in self.dic:
+                self.dic['lr_scheduler']['learning_rate'] = learning_rate
+            else:
+                self.dic['learning_rate']['value'] = learning_rate
+
+        if batch_size:
+            self.dic['batch_size'] = batch_size
+
+        if iters:
+            self.dic['iters'] = iters
+
+    @property
+    def batch_size(self) -> int:
+        return self.dic.get('batch_size', 1)
+
+    @property
+    def iters(self) -> int:
+        iters = self.dic.get('iters')
+        if not iters:
+            raise RuntimeError('No iters specified in the configuration file.')
+        return iters
+
+    @property
+    def lr_scheduler(self) -> paddle.optimizer.lr.LRScheduler:
+        if 'lr_scheduler' not in self.dic:
+            raise RuntimeError(
+                'No `lr_scheduler` specified in the configuration file.')
+        params = self.dic.get('lr_scheduler')
+
+        use_warmup = False
+        if 'warmup_iters' in params:
+            use_warmup = True
+            warmup_iters = params.pop('warmup_iters')
+            assert 'warmup_start_lr' in params, \
+                "When using warmup, please set warmup_start_lr and warmup_iters in lr_scheduler"
+            warmup_start_lr = params.pop('warmup_start_lr')
+            end_lr = params['learning_rate']
+
+        lr_type = params.pop('type')
+        if lr_type == 'PolynomialDecay':
+            iters = self.iters - warmup_iters if use_warmup else self.iters
+            iters = max(iters, 1)
+            params.setdefault('decay_steps', iters)
+            params.setdefault('end_lr', 0)
+            params.setdefault('power', 0.9)
+        lr_sche = getattr(paddle.optimizer.lr, lr_type)(**params)
+
+        if use_warmup:
+            lr_sche = paddle.optimizer.lr.LinearWarmup(
+                learning_rate=lr_sche,
+                warmup_steps=warmup_iters,
+                start_lr=warmup_start_lr,
+                end_lr=end_lr)
+
+        return lr_sche
+
+    @property
+    def learning_rate(self) -> paddle.optimizer.lr.LRScheduler:
+        logger.warning(
+            '''`learning_rate` in configuration file will be deprecated, please use `lr_scheduler` instead. E.g
+            lr_scheduler:
+                type: PolynomialDecay
+                learning_rate: 0.01''')
+
+        _learning_rate = self.dic.get('learning_rate', {})
+        if isinstance(_learning_rate, float):
+            return _learning_rate
+
+        _learning_rate = self.dic.get('learning_rate', {}).get('value')
+        if not _learning_rate:
+            raise RuntimeError(
+                'No learning rate specified in the configuration file.')
+
+        args = self.decay_args
+        decay_type = args.pop('type')
+
+        if decay_type == 'poly':
+            lr = _learning_rate
+            return paddle.optimizer.lr.PolynomialDecay(lr, **args)
+        elif decay_type == 'piecewise':
+            values = _learning_rate
+            return paddle.optimizer.lr.PiecewiseDecay(values=values, **args)
+        elif decay_type == 'stepdecay':
+            lr = _learning_rate
+            return paddle.optimizer.lr.StepDecay(lr, **args)
+        else:
+            raise RuntimeError(
+                'Only poly, piecewise and stepdecay are supported.')
+
+    @property
+    def optimizer(self) -> paddle.optimizer.Optimizer:
+        if 'lr_scheduler' in self.dic:
+            lr = self.lr_scheduler
+        else:
+            lr = self.learning_rate
+        args = self.optimizer_args
+        optimizer_type = args.pop('type')
+
+        if optimizer_type == 'sgd':
+            return paddle.optimizer.Momentum(
+                lr, parameters=self.model.parameters(), **args)
+        elif optimizer_type == 'adam':
+            return paddle.optimizer.Adam(
+                lr, parameters=self.model.parameters(), **args)
+        elif optimizer_type in paddle.optimizer.__all__:
+            return getattr(paddle.optimizer,
+                           optimizer_type)(lr,
+                                           parameters=self.model.parameters(),
+                                           **args)
+
+        raise RuntimeError('Unknown optimizer type {}.'.format(optimizer_type))
+
+    @property
+    def optimizer_args(self) -> dict:
+        args = self.dic.get('optimizer', {}).copy()
+        if args['type'] == 'sgd':
+            args.setdefault('momentum', 0.9)
+
+        return args
+
+    @property
+    def decay_args(self) -> dict:
+        args = self.dic.get('learning_rate', {}).get(
+            'decay', {'type': 'poly',
+                      'power': 0.9}).copy()
+
+        if args['type'] == 'poly':
+            args.setdefault('decay_steps', self.iters)
+            args.setdefault('end_lr', 0)
+
+        return args
+
+    @property
+    def loss(self) -> dict:
+        if self._losses is None:
+            self._losses = self._prepare_loss('loss')
+        return self._losses
+
+    @property
+    def distill_loss(self) -> dict:
+        if not hasattr(self, '_distill_losses'):
+            self._distill_losses = self._prepare_loss('distill_loss')
+        return self._distill_losses
+
+    def _prepare_loss(self, loss_name):
+        """
+        Parse the loss parameters and load the loss layers.
+
+        Args:
+            loss_name (str): The root name of loss in the yaml file.
+        Returns:
+            dict: A dict including the loss parameters and layers.
+        """
+        args = self.dic.get(loss_name, {}).copy()
+        if 'types' in args and 'coef' in args:
+            len_types = len(args['types'])
+            len_coef = len(args['coef'])
+            if len_types != len_coef:
+                if len_types == 1:
+                    args['types'] = args['types'] * len_coef
+                else:
+                    raise ValueError(
+                        'The length of types should equal to coef or equal to 1 in loss config, but they are {} and {}.'
+                        .format(len_types, len_coef))
+        else:
+            raise ValueError(
+                'Loss config should contain keys of "types" and "coef"')
+
+        losses = dict()
+        for key, val in args.items():
+            if key == 'types':
+                losses['types'] = []
+                for item in args['types']:
+                    if item['type'] != 'MixedLoss':
+                        if 'ignore_index' in item:
+                            assert item['ignore_index'] == self.train_dataset.ignore_index, 'If ignore_index of loss is set, '\
+                                'the ignore_index of loss and train_dataset must be the same. \nCurrently, loss ignore_index = {}, '\
+                                'train_dataset ignore_index = {}. \nIt is recommended not to set loss ignore_index, so it is consistent with '\
+                                'train_dataset by default.'.format(item['ignore_index'], self.train_dataset.ignore_index)
+                        item['ignore_index'] = \
+                            self.train_dataset.ignore_index
+                    losses['types'].append(self._load_object(item))
+            else:
+                losses[key] = val
+        if len(losses['coef']) != len(losses['types']):
+            raise RuntimeError(
+                'The length of coef should equal to types in loss config: {} != {}.'
+                .format(len(losses['coef']), len(losses['types'])))
+        return losses
+
+    @property
+    def model(self) -> paddle.nn.Layer:
+        model_cfg = self.dic.get('model', {}).copy()
+        if not model_cfg:
+            raise RuntimeError('No model specified in the configuration file.')
+
+        if 'num_classes' not in model_cfg:
+            num_classes = None
+            try:
+                if self.train_dataset_config:
+                    if hasattr(self.train_dataset_class, 'NUM_CLASSES'):
+                        num_classes = self.train_dataset_class.NUM_CLASSES
+                    elif 'num_classes' in self.train_dataset_config:
+                        num_classes = self.train_dataset_config['num_classes']
+                    elif hasattr(self.train_dataset, 'num_classes'):
+                        num_classes = self.train_dataset.num_classes
+                elif self.val_dataset_config:
+                    if hasattr(self.val_dataset_class, 'NUM_CLASSES'):
+                        num_classes = self.val_dataset_class.NUM_CLASSES
+                    elif 'num_classes' in self.val_dataset_config:
+                        num_classes = self.val_dataset_config['num_classes']
+                    elif hasattr(self.val_dataset, 'num_classes'):
+                        num_classes = self.val_dataset.num_classes
+            except FileNotFoundError:
+                warnings.warn("`dataset_root` is not found. Is it correct?")
+
+            if num_classes is not None:
+                model_cfg['num_classes'] = num_classes
+
+        if not self._model:
+            self._model = self._load_object(model_cfg)
+        return self._model
+
+    @property
+    def train_dataset_config(self) -> Dict:
+        return self.dic.get('train_dataset', {}).copy()
+
+    @property
+    def val_dataset_config(self) -> Dict:
+        return self.dic.get('val_dataset', {}).copy()
+
+    @property
+    def train_dataset_class(self) -> Generic:
+        dataset_type = self.train_dataset_config['type']
+        return self._load_component(dataset_type)
+
+    @property
+    def val_dataset_class(self) -> Generic:
+        dataset_type = self.val_dataset_config['type']
+        return self._load_component(dataset_type)
+
+    @property
+    def train_dataset(self) -> paddle.io.Dataset:
+        _train_dataset = self.train_dataset_config
+        if not _train_dataset:
+            return None
+        return self._load_object(_train_dataset)
+
+    @property
+    def val_dataset(self) -> paddle.io.Dataset:
+        _val_dataset = self.val_dataset_config
+        if not _val_dataset:
+            return None
+        return self._load_object(_val_dataset)
+
+    def _load_component(self, com_name: str) -> Any:
+        com_list = [
+            manager.MODELS, manager.BACKBONES, manager.DATASETS,
+            manager.TRANSFORMS, manager.LOSSES
+        ]
+
+        for com in com_list:
+            if com_name in com.components_dict:
+                return com[com_name]
+        else:
+            raise RuntimeError(
+                'The specified component was not found {}.'.format(com_name))
+
+    def _load_object(self, cfg: dict) -> Any:
+        cfg = cfg.copy()
+        if 'type' not in cfg:
+            raise RuntimeError('No object information in {}.'.format(cfg))
+
+        component = self._load_component(cfg.pop('type'))
+
+        params = {}
+        for key, val in cfg.items():
+            if self._is_meta_type(val):
+                params[key] = self._load_object(val)
+            elif isinstance(val, list):
+                params[key] = [
+                    self._load_object(item)
+                    if self._is_meta_type(item) else item for item in val
+                ]
+            else:
+                params[key] = val
+
+        return component(**params)
+
+    @property
+    def test_config(self) -> Dict:
+        return self.dic.get('test_config', {})
+
+    @property
+    def export_config(self) -> Dict:
+        return self.dic.get('export', {})
+
+    @property
+    def to_static_training(self) -> bool:
+        '''Whether to use @to_static for training'''
+        return self.dic.get('to_static_training', False)
+
+    def _is_meta_type(self, item: Any) -> bool:
+        return isinstance(item, dict) and 'type' in item
+
+    def __str__(self) -> str:
+        return yaml.dump(self.dic)
+
+    @property
+    def val_transforms(self) -> list:
+        """Get val_transform from val_dataset"""
+        _val_dataset = self.val_dataset_config
+        if not _val_dataset:
+            return []
+        _transforms = _val_dataset.get('transforms', [])
+        transforms = []
+        for i in _transforms:
+            transforms.append(self._load_object(i))
+        return transforms
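To make the `_base_` inheritance and lazy component building concrete, here is a sketch with a hypothetical yaml (not a file shipped in this repo):

# minimal.yml (hypothetical)
#   _base_: base.yml              # optional; keys below override the base file
#   batch_size: 4
#   iters: 1000
#   lr_scheduler:
#     type: PolynomialDecay
#     learning_rate: 0.01

from paddleseg.cvlibs import Config

cfg = Config('minimal.yml', batch_size=2)  # constructor kwargs override yaml values
print(cfg.batch_size)                      # 2
lr_sche = cfg.lr_scheduler                 # a paddle.optimizer.lr.PolynomialDecay with decay_steps defaulted to iters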
paddleseg/cvlibs/manager.py
ADDED
@@ -0,0 +1,147 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from collections.abc import Sequence
+
+import warnings
+
+
+class ComponentManager:
+    """
+    Implement a manager class to add new components properly.
+    The component can be added as either class or function type.
+
+    Args:
+        name (str): The name of component.
+
+    Returns:
+        A callable object of ComponentManager.
+
+    Examples 1:
+
+        from paddleseg.cvlibs.manager import ComponentManager
+
+        model_manager = ComponentManager()
+
+        class AlexNet: ...
+        class ResNet: ...
+
+        model_manager.add_component(AlexNet)
+        model_manager.add_component(ResNet)
+
+        # Or pass a sequence alternatively:
+        model_manager.add_component([AlexNet, ResNet])
+        print(model_manager.components_dict)
+        # {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}
+
+    Examples 2:
+
+        # Or an easier way, using it as a Python decorator, just adding it above the class declaration.
+        from paddleseg.cvlibs.manager import ComponentManager
+
+        model_manager = ComponentManager()
+
+        @model_manager.add_component
+        class AlexNet: ...
+
+        @model_manager.add_component
+        class ResNet: ...
+
+        print(model_manager.components_dict)
+        # {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}
+    """
+
+    def __init__(self, name=None):
+        self._components_dict = dict()
+        self._name = name
+
+    def __len__(self):
+        return len(self._components_dict)
+
+    def __repr__(self):
+        name_str = self._name if self._name else self.__class__.__name__
+        return "{}:{}".format(name_str, list(self._components_dict.keys()))
+
+    def __getitem__(self, item):
+        if item not in self._components_dict.keys():
+            raise KeyError("{} does not exist in available {}".format(item,
+                                                                      self))
+        return self._components_dict[item]
+
+    @property
+    def components_dict(self):
+        return self._components_dict
+
+    @property
+    def name(self):
+        return self._name
+
+    def _add_single_component(self, component):
+        """
+        Add a single component into the corresponding manager.
+
+        Args:
+            component (function|class): A new component.
+
+        Raises:
+            TypeError: When `component` is neither class nor function.
+            KeyError: When `component` was added already.
+        """
+
+        # Currently only support class or function type
+        if not (inspect.isclass(component) or inspect.isfunction(component)):
+            raise TypeError("Expect class/function type, but received {}".
+                            format(type(component)))
+
+        # Obtain the internal name of the component
+        component_name = component.__name__
+
+        # Check whether the component was added already
+        if component_name in self._components_dict.keys():
+            warnings.warn("{} exists already! It is now updated to {} !!!".
+                          format(component_name, component))
+            self._components_dict[component_name] = component
+
+        else:
+            # Take the internal name of the component as its key
+            self._components_dict[component_name] = component
+
+    def add_component(self, components):
+        """
+        Add component(s) into the corresponding manager.
+
+        Args:
+            components (function|class|list|tuple): Support four types of components.
+
+        Returns:
+            components (function|class|list|tuple): Same with input components.
+        """
+
+        # Check whether the type is a sequence
+        if isinstance(components, Sequence):
+            for component in components:
+                self._add_single_component(component)
+        else:
+            component = components
+            self._add_single_component(component)
+
+        return components
+
+
+MODELS = ComponentManager("models")
+BACKBONES = ComponentManager("backbones")
+DATASETS = ComponentManager("datasets")
+TRANSFORMS = ComponentManager("transforms")
+LOSSES = ComponentManager("losses")
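Registration through these global managers is what lets Config resolve a `type:` name from yaml. A sketch with a made-up transform:

from paddleseg.cvlibs import manager

@manager.TRANSFORMS.add_component
class IdentityTransform:  # hypothetical component, for illustration only
    def __call__(self, data):
        return data

print(manager.TRANSFORMS['IdentityTransform'])  # resolves the class by its name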
paddleseg/cvlibs/param_init.py
ADDED
@@ -0,0 +1,146 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.nn as nn
+
+
+def constant_init(param, **kwargs):
+    """
+    Initialize the `param` with constants.
+
+    Args:
+        param (Tensor): Tensor that needs to be initialized.
+
+    Examples:
+
+        from paddleseg.cvlibs import param_init
+        import paddle.nn as nn
+
+        linear = nn.Linear(2, 4)
+        param_init.constant_init(linear.weight, value=2.0)
+        print(linear.weight.numpy())
+        # result is [[2. 2. 2. 2.], [2. 2. 2. 2.]]
+
+    """
+    initializer = nn.initializer.Constant(**kwargs)
+    initializer(param, param.block)
+
+
+def normal_init(param, **kwargs):
+    """
+    Initialize the `param` with a Normal distribution.
+
+    Args:
+        param (Tensor): Tensor that needs to be initialized.
+
+    Examples:
+
+        from paddleseg.cvlibs import param_init
+        import paddle.nn as nn
+
+        linear = nn.Linear(2, 4)
+        param_init.normal_init(linear.weight, loc=0.0, scale=1.0)
+
+    """
+    initializer = nn.initializer.Normal(**kwargs)
+    initializer(param, param.block)
+
+
+def kaiming_normal_init(param, **kwargs):
+    r"""
+    Initialize the input tensor with Kaiming Normal initialization.
+
+    This function implements the `param` initialization from the paper
+    `Delving Deep into Rectifiers: Surpassing Human-Level Performance on
+    ImageNet Classification <https://arxiv.org/abs/1502.01852>`_
+    by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a
+    robust initialization method that particularly considers the rectifier
+    nonlinearities. In case of Uniform distribution, the range is [-x, x], where
+    .. math::
+        x = \sqrt{\frac{6.0}{fan\_in}}
+    In case of Normal distribution, the mean is 0 and the standard deviation
+    is
+    .. math::
+        \sqrt{\frac{2.0}{fan\_in}}
+
+    Args:
+        param (Tensor): Tensor that needs to be initialized.
+
+    Examples:
+
+        from paddleseg.cvlibs import param_init
+        import paddle.nn as nn
+
+        linear = nn.Linear(2, 4)
+        # uniform is used to decide whether to use uniform or normal distribution
+        param_init.kaiming_normal_init(linear.weight)
+
+    """
+    initializer = nn.initializer.KaimingNormal(**kwargs)
+    initializer(param, param.block)
+
+
+def kaiming_uniform(param, **kwargs):
+    r"""Implements the Kaiming Uniform initializer.
+
+    This function implements the weight initialization from the paper
+    `Delving Deep into Rectifiers: Surpassing Human-Level Performance on
+    ImageNet Classification <https://arxiv.org/abs/1502.01852>`_
+    by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a
+    robust initialization method that particularly considers the rectifier
+    nonlinearities.
+
+    In case of Uniform distribution, the range is [-x, x], where
+    .. math::
+        x = \sqrt{\frac{6.0}{fan\_in}}
+
+    Args:
+        param (Tensor): Tensor that needs to be initialized.
+
+    Examples:
+
+        from paddleseg.cvlibs import param_init
+        import paddle.nn as nn
+
+        linear = nn.Linear(2, 4)
+        param_init.kaiming_uniform(linear.weight)
+    """
+
+    initializer = nn.initializer.KaimingUniform(**kwargs)
+    initializer(param, param.block)
+
+
+def xavier_uniform(param, **kwargs):
+    r"""
+    This implements the Xavier weight initializer from the paper
+    `Understanding the difficulty of training deep feedforward neural
+    networks <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
+    by Xavier Glorot and Yoshua Bengio.
+    This initializer is designed to keep the scale of the gradients
+    approximately the same in all the layers. In case of Uniform distribution,
+    the range is [-x, x], where
+    .. math::
+        x = \sqrt{\frac{6.0}{fan\_in + fan\_out}}
+    Args:
+        param (Tensor): Tensor that needs to be initialized.
+
+    Examples:
+
+        from paddleseg.cvlibs import param_init
+        import paddle.nn as nn
+
+        linear = nn.Linear(2, 4)
+        param_init.xavier_uniform(linear.weight)
+    """
+    initializer = nn.initializer.XavierUniform(**kwargs)
+    initializer(param, param.block)
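A common idiom built on these helpers is to sweep a network and initialize each sublayer by type; this sketch is an illustration, not a helper shipped in this module:

import paddle.nn as nn
from paddleseg.cvlibs import param_init

def init_weights(net):
    # Walk every sublayer and apply a type-appropriate initializer.
    for layer in net.sublayers():
        if isinstance(layer, nn.Conv2D):
            param_init.kaiming_normal_init(layer.weight)
        elif isinstance(layer, nn.BatchNorm2D):
            param_init.constant_init(layer.weight, value=1.0)
            param_init.constant_init(layer.bias, value=0.0)

init_weights(nn.Sequential(nn.Conv2D(3, 8, 3), nn.BatchNorm2D(8)))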
paddleseg/datasets/__init__.py
ADDED
@@ -0,0 +1,30 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .dataset import Dataset
+from .cityscapes import Cityscapes
+from .voc import PascalVOC
+from .ade import ADE20K
+from .optic_disc_seg import OpticDiscSeg
+from .pascal_context import PascalContext
+from .mini_deep_globe_road_extraction import MiniDeepGlobeRoadExtraction
+from .eg1800 import EG1800
+from .supervisely import SUPERVISELY
+from .cocostuff import CocoStuff
+from .stare import STARE
+from .drive import DRIVE
+from .hrf import HRF
+from .chase_db1 import CHASEDB1
+from .pp_humanseg14k import PPHumanSeg14K
+from .pssl import PSSLDataset
paddleseg/datasets/ade.py
ADDED
@@ -0,0 +1,119 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+from PIL import Image
+
+from paddleseg.datasets import Dataset
+from paddleseg.utils.download import download_file_and_uncompress
+from paddleseg.utils import seg_env
+from paddleseg.cvlibs import manager
+from paddleseg.transforms import Compose
+import paddleseg.transforms.functional as F
+
+URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
+
+
+@manager.DATASETS.add_component
+class ADE20K(Dataset):
+    """
+    ADE20K dataset `http://sceneparsing.csail.mit.edu/`.
+
+    Args:
+        transforms (list): A list of image transformations.
+        dataset_root (str, optional): The ADE20K dataset directory. Default: None.
+        mode (str, optional): A subset of the entire dataset. It should be one of ('train', 'val'). Default: 'train'.
+        edge (bool, optional): Whether to compute edge while training. Default: False.
+    """
+    NUM_CLASSES = 150
+
+    def __init__(self, transforms, dataset_root=None, mode='train', edge=False):
+        self.dataset_root = dataset_root
+        self.transforms = Compose(transforms)
+        mode = mode.lower()
+        self.mode = mode
+        self.file_list = list()
+        self.num_classes = self.NUM_CLASSES
+        self.ignore_index = 255
+        self.edge = edge
+
+        if mode not in ['train', 'val']:
+            raise ValueError(
+                "`mode` should be one of ('train', 'val') in ADE20K dataset, but got {}."
+                .format(mode))
+
+        if self.transforms is None:
+            raise ValueError("`transforms` is necessary, but it is None.")
+
+        if self.dataset_root is None:
+            self.dataset_root = download_file_and_uncompress(
+                url=URL,
+                savepath=seg_env.DATA_HOME,
+                extrapath=seg_env.DATA_HOME,
+                extraname='ADEChallengeData2016')
+        elif not os.path.exists(self.dataset_root):
+            self.dataset_root = os.path.normpath(self.dataset_root)
+            savepath, extraname = self.dataset_root.rsplit(
+                sep=os.path.sep, maxsplit=1)
+            self.dataset_root = download_file_and_uncompress(
+                url=URL,
+                savepath=savepath,
+                extrapath=savepath,
+                extraname=extraname)
+
+        if mode == 'train':
+            img_dir = os.path.join(self.dataset_root, 'images/training')
+            label_dir = os.path.join(self.dataset_root, 'annotations/training')
+        elif mode == 'val':
+            img_dir = os.path.join(self.dataset_root, 'images/validation')
+            label_dir = os.path.join(self.dataset_root,
+                                     'annotations/validation')
+        img_files = os.listdir(img_dir)
+        label_files = [i.replace('.jpg', '.png') for i in img_files]
+        for i in range(len(img_files)):
+            img_path = os.path.join(img_dir, img_files[i])
+            label_path = os.path.join(label_dir, label_files[i])
+            self.file_list.append([img_path, label_path])
+
+    def __getitem__(self, idx):
+        data = {}
+        data['trans_info'] = []
+        image_path, label_path = self.file_list[idx]
+        data['img'] = image_path
+        data['gt_fields'] = [
+        ]  # If a key is in gt_fields, data[key] is transformed synchronously.
+
+        if self.mode == 'val':
+            data = self.transforms(data)
+            label = np.asarray(Image.open(label_path))
+            # Class 0 is ignored, and it becomes 255 after subtracting 1
+            # because the dtype of the label is uint8.
+            label = label - 1
+            label = label[np.newaxis, :, :]
+            data['label'] = label
+            return data
+        else:
+            data['label'] = label_path
+            data['gt_fields'].append('label')
+            data = self.transforms(data)
+            data['label'] = data['label'] - 1
+            # Recover the ignore pixels added by the transforms
+            data['label'][data['label'] == 254] = 255
+            if self.edge:
+                # Use the transformed label here; `label` is only defined
+                # in the 'val' branch above.
+                edge_mask = F.mask_to_binary_edge(
+                    data['label'], radius=2, num_classes=self.num_classes)
+                data['edge'] = edge_mask
+            return data
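Minimal usage of the dataset class above; the transform list is illustrative, and with dataset_root=None the full ADE20K archive is downloaded automatically, so point dataset_root at an existing copy if you have one:

from paddleseg import transforms as T
from paddleseg.datasets import ADE20K

train_transforms = [T.RandomHorizontalFlip(), T.Resize(target_size=(512, 512)), T.Normalize()]
train_dataset = ADE20K(transforms=train_transforms, mode='train')
sample = train_dataset[0]  # dict with transformed 'img' and 'label' (labels shifted down by 1)
print(sample['img'].shape, sample['label'].shape)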
paddleseg/datasets/chase_db1.py
ADDED
@@ -0,0 +1,98 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.utils import seg_env
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
from paddleseg.datasets import Dataset

URL = 'https://bj.bcebos.com/paddleseg/dataset/chase_db1/chase_db1.zip'


@manager.DATASETS.add_component
class CHASEDB1(Dataset):
    """
    CHASE_DB1 is a retinal vessel segmentation dataset containing 28 color
    retina images of size 999×960 pixels, collected from the left and right
    eyes of 14 school children. Each image is annotated by two independent
    human experts; the labels of the first expert are used.
    (https://blogs.kingston.ac.uk/retinal/chasedb1/)

    Args:
        transforms (list): Transforms for image.
        dataset_root (str): The dataset directory. Default: None
        edge (bool): Whether to extract edge information in the output.
        mode (str, optional): Which part of the dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
    """
    NUM_CLASSES = 2

    def __init__(self,
                 dataset_root=None,
                 transforms=None,
                 edge=False,
                 mode='train'):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        mode = mode.lower()
        self.mode = mode
        self.edge = edge
        self.file_list = list()
        self.num_classes = self.NUM_CLASSES
        self.ignore_index = 255  # labels are only 0/1, so ignore_index is effectively unused

        if mode not in ['train', 'val', 'test']:
            raise ValueError(
                "`mode` should be 'train', 'val' or 'test', but got {}.".format(
                    mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        if self.dataset_root is None:
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=seg_env.DATA_HOME,
                extrapath=seg_env.DATA_HOME)
        elif not os.path.exists(self.dataset_root):
            self.dataset_root = os.path.normpath(self.dataset_root)
            savepath, extraname = self.dataset_root.rsplit(
                sep=os.path.sep, maxsplit=1)
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=savepath,
                extrapath=savepath,
                extraname=extraname)

        if mode == 'train':
            file_path = os.path.join(self.dataset_root, 'train_list.txt')
        elif mode == 'val':
            file_path = os.path.join(self.dataset_root, 'val_list.txt')

        with open(file_path, 'r') as f:
            for line in f:
                items = line.strip().split()
                if len(items) != 2:
                    if mode == 'train' or mode == 'val':
                        raise Exception(
                            "File list format incorrect! It should be"
                            " image_name label_name\\n")
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = None
                else:
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = os.path.join(self.dataset_root, items[1])
                self.file_list.append([image_path, grt_path])

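A minimal usage sketch for the class above; the transform list is an assumption, and `dataset_root=None` triggers the auto-download to seg_env.DATA_HOME described in `__init__`:

import paddleseg.transforms as T
from paddleseg.datasets import CHASEDB1

train_ds = CHASEDB1(
    transforms=[T.Resize(target_size=(512, 512)), T.Normalize()],
    dataset_root=None,
    mode='train')
print(len(train_ds.file_list))  # [image_path, label_path] pairs from train_list.txt
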
paddleseg/datasets/cityscapes.py
ADDED
@@ -0,0 +1,88 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import glob

from paddleseg.datasets import Dataset
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose


@manager.DATASETS.add_component
class Cityscapes(Dataset):
    """
    Cityscapes dataset `https://www.cityscapes-dataset.com/`.
    The folder structure is as follows:

        cityscapes
        |
        |--leftImg8bit
        |  |--train
        |  |--val
        |  |--test
        |
        |--gtFine
        |  |--train
        |  |--val
        |  |--test

    Make sure there are *_labelTrainIds.png files in the gtFine directory. If not, please run convert_cityscapes.py in tools first.

    Args:
        transforms (list): Transforms for image.
        dataset_root (str): Cityscapes dataset directory.
        mode (str, optional): Which part of the dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
        edge (bool, optional): Whether to compute edge while training. Default: False
    """
    NUM_CLASSES = 19

    def __init__(self, transforms, dataset_root, mode='train', edge=False):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        self.file_list = list()
        mode = mode.lower()
        self.mode = mode
        self.num_classes = self.NUM_CLASSES
        self.ignore_index = 255
        self.edge = edge

        if mode not in ['train', 'val', 'test']:
            raise ValueError(
                "mode should be 'train', 'val' or 'test', but got {}.".format(
                    mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        img_dir = os.path.join(self.dataset_root, 'leftImg8bit')
        label_dir = os.path.join(self.dataset_root, 'gtFine')
        if self.dataset_root is None or not os.path.isdir(
                self.dataset_root) or not os.path.isdir(
                    img_dir) or not os.path.isdir(label_dir):
            raise ValueError(
                "The dataset is not found or the folder structure is nonconformant."
            )

        label_files = sorted(
            glob.glob(
                os.path.join(label_dir, mode, '*',
                             '*_gtFine_labelTrainIds.png')))
        img_files = sorted(
            glob.glob(os.path.join(img_dir, mode, '*', '*_leftImg8bit.png')))

        self.file_list = [
            [img_path, label_path]
            for img_path, label_path in zip(img_files, label_files)
        ]

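Images and labels above are paired only by zipping two independently sorted globs, so a quick stem check can surface a partial *_labelTrainIds.png conversion before training; a sketch, assuming a `Cityscapes` instance named `train_ds`:

import os

for img_path, label_path in train_ds.file_list:
    img_stem = os.path.basename(img_path).replace('_leftImg8bit.png', '')
    lbl_stem = os.path.basename(label_path).replace('_gtFine_labelTrainIds.png', '')
    assert img_stem == lbl_stem, (img_path, label_path)
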
paddleseg/datasets/cocostuff.py
ADDED
@@ -0,0 +1,83 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import glob

from paddleseg.datasets import Dataset
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose


@manager.DATASETS.add_component
class CocoStuff(Dataset):
    """
    COCO-Stuff dataset `https://github.com/nightrome/cocostuff`.
    The folder structure is as follows:

        cocostuff
        |
        |--images
        |  |--train2017
        |  |--val2017
        |
        |--annotations
        |  |--train2017
        |  |--val2017

    Args:
        transforms (list): Transforms for image.
        dataset_root (str): COCO-Stuff dataset directory.
        mode (str): Which part of the dataset to use. It is one of ('train', 'val'). Default: 'train'.
        edge (bool, optional): Whether to compute edge while training. Default: False
    """
    NUM_CLASSES = 171

    def __init__(self, transforms, dataset_root, mode='train', edge=False):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        self.file_list = list()
        mode = mode.lower()
        self.mode = mode
        self.num_classes = self.NUM_CLASSES
        self.ignore_index = 255
        self.edge = edge

        if mode not in ['train', 'val']:
            raise ValueError(
                "mode should be 'train' or 'val', but got {}.".format(mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        img_dir = os.path.join(self.dataset_root, 'images')
        label_dir = os.path.join(self.dataset_root, 'annotations')
        if self.dataset_root is None or not os.path.isdir(
                self.dataset_root) or not os.path.isdir(
                    img_dir) or not os.path.isdir(label_dir):
            raise ValueError(
                "The dataset is not found or the folder structure is nonconformant."
            )

        label_files = sorted(
            glob.glob(os.path.join(label_dir, mode + '2017', '*.png')))

        img_files = sorted(
            glob.glob(os.path.join(img_dir, mode + '2017', '*.jpg')))

        self.file_list = [
            [img_path, label_path]
            for img_path, label_path in zip(img_files, label_files)
        ]

paddleseg/datasets/dataset.py
ADDED
@@ -0,0 +1,163 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import numpy as np
from PIL import Image

from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
import paddleseg.transforms.functional as F


@manager.DATASETS.add_component
class Dataset(paddle.io.Dataset):
    """
    Pass in a custom dataset that conforms to the format.

    Args:
        transforms (list): Transforms for image.
        dataset_root (str): The dataset directory.
        num_classes (int): Number of classes.
        mode (str, optional): Which part of the dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
        train_path (str, optional): The train dataset file. When mode is 'train', train_path is necessary.
            The contents of the train_path file are as follows:
                image1.jpg ground_truth1.png
                image2.jpg ground_truth2.png
        val_path (str, optional): The evaluation dataset file. When mode is 'val', val_path is necessary.
            The contents are the same as in train_path.
        test_path (str, optional): The test dataset file. When mode is 'test', test_path is necessary.
            The annotation file is not necessary in the test_path file.
        separator (str, optional): The separator of the dataset list. Default: ' '.
        edge (bool, optional): Whether to compute edge while training. Default: False

    Examples:

        import paddleseg.transforms as T
        from paddleseg.datasets import Dataset

        transforms = [T.RandomPaddingCrop(crop_size=(512, 512)), T.Normalize()]
        dataset_root = 'dataset_root_path'
        train_path = 'train_path'
        num_classes = 2
        dataset = Dataset(transforms=transforms,
                          dataset_root=dataset_root,
                          num_classes=2,
                          train_path=train_path,
                          mode='train')
    """

    def __init__(self,
                 transforms,
                 dataset_root,
                 num_classes,
                 mode='train',
                 train_path=None,
                 val_path=None,
                 test_path=None,
                 separator=' ',
                 ignore_index=255,
                 edge=False):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        self.file_list = list()
        self.mode = mode.lower()
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.edge = edge

        if self.mode not in ['train', 'val', 'test']:
            raise ValueError(
                "mode should be 'train', 'val' or 'test', but got {}.".format(
                    self.mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        if not os.path.exists(self.dataset_root):
            raise FileNotFoundError('`dataset_root` is not found: {}.'.format(
                self.dataset_root))

        if self.mode == 'train':
            if train_path is None:
                raise ValueError(
                    'When `mode` is "train", `train_path` is necessary, but it is None.'
                )
            elif not os.path.exists(train_path):
                raise FileNotFoundError('`train_path` is not found: {}'.format(
                    train_path))
            else:
                file_path = train_path
        elif self.mode == 'val':
            if val_path is None:
                raise ValueError(
                    'When `mode` is "val", `val_path` is necessary, but it is None.'
                )
            elif not os.path.exists(val_path):
                raise FileNotFoundError('`val_path` is not found: {}'.format(
                    val_path))
            else:
                file_path = val_path
        else:
            if test_path is None:
                raise ValueError(
                    'When `mode` is "test", `test_path` is necessary, but it is None.'
                )
            elif not os.path.exists(test_path):
                raise FileNotFoundError('`test_path` is not found: {}'.format(
                    test_path))
            else:
                file_path = test_path

        with open(file_path, 'r') as f:
            for line in f:
                items = line.strip().split(separator)
                if len(items) != 2:
                    if self.mode == 'train' or self.mode == 'val':
                        raise ValueError(
                            "File list format incorrect! In training or evaluation task it should be"
                            " image_name{}label_name\\n".format(separator))
                    image_path = os.path.join(self.dataset_root, items[0])
                    label_path = None
                else:
                    image_path = os.path.join(self.dataset_root, items[0])
                    label_path = os.path.join(self.dataset_root, items[1])
                self.file_list.append([image_path, label_path])

    def __getitem__(self, idx):
        data = {}
        data['trans_info'] = []
        image_path, label_path = self.file_list[idx]
        data['img'] = image_path
        data['label'] = label_path
        # Keys listed in gt_fields are transformed in sync with the image.
        data['gt_fields'] = []
        if self.mode == 'val':
            data = self.transforms(data)
            data['label'] = data['label'][np.newaxis, :, :]
        else:
            data['gt_fields'].append('label')
            data = self.transforms(data)
            if self.edge:
                edge_mask = F.mask_to_binary_edge(
                    data['label'], radius=2, num_classes=self.num_classes)
                data['edge'] = edge_mask
        return data

    def __len__(self):
        return len(self.file_list)

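The list file parsed in `__init__` is plain text with one sample per line: the image path and label path joined by `separator` and resolved relative to `dataset_root`. A hypothetical train_path file for a two-class task could look like:

images/0001.jpg annotations/0001.png
images/0002.jpg annotations/0002.png
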
paddleseg/datasets/drive.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.utils import seg_env
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
from paddleseg.datasets import Dataset

URL = 'https://bj.bcebos.com/paddleseg/dataset/drive/drive.zip'


@manager.DATASETS.add_component
class DRIVE(Dataset):
    """
    The Digital Retinal Images for Vessel Extraction (DRIVE) dataset is a dataset for retinal vessel segmentation.
    It consists of 40 JPEG color fundus images of size (584, 565), including 7 cases showing pathology.
    (http://www.isi.uu.nl/Research/Databases/DRIVE/)

    Args:
        transforms (list): Transforms for image.
        dataset_root (str): The dataset directory. Default: None
        edge (bool): Whether to extract edge information in the output.
        mode (str, optional): Which part of the dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
    """
    NUM_CLASSES = 2

    def __init__(self,
                 dataset_root=None,
                 transforms=None,
                 edge=False,
                 mode='train'):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        mode = mode.lower()
        self.mode = mode
        self.edge = edge
        self.file_list = list()
        self.num_classes = self.NUM_CLASSES
        self.ignore_index = 255  # labels are only 0/1, so ignore_index is effectively unused

        if mode not in ['train', 'val', 'test']:
            raise ValueError(
                "`mode` should be 'train', 'val' or 'test', but got {}.".format(
                    mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        if self.dataset_root is None:
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=seg_env.DATA_HOME,
                extrapath=seg_env.DATA_HOME)
        elif not os.path.exists(self.dataset_root):
            self.dataset_root = os.path.normpath(self.dataset_root)
            savepath, extraname = self.dataset_root.rsplit(
                sep=os.path.sep, maxsplit=1)
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=savepath,
                extrapath=savepath,
                extraname=extraname)

        if mode == 'train':
            file_path = os.path.join(self.dataset_root, 'train_list.txt')
        elif mode == 'val':
            file_path = os.path.join(self.dataset_root, 'val_list.txt')

        with open(file_path, 'r') as f:
            for line in f:
                items = line.strip().split()
                if len(items) != 2:
                    if mode == 'train' or mode == 'val':
                        raise Exception(
                            "File list format incorrect! It should be"
                            " image_name label_name\\n")
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = None
                else:
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = os.path.join(self.dataset_root, items[1])
                self.file_list.append([image_path, grt_path])

paddleseg/datasets/eg1800.py
ADDED
@@ -0,0 +1,137 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import copy

import cv2
import numpy as np

from paddleseg.datasets import Dataset
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.utils import seg_env
import paddleseg.transforms.functional as F

URL = "https://paddleseg.bj.bcebos.com/dataset/EG1800.zip"


@manager.DATASETS.add_component
class EG1800(Dataset):
    """
    EG1800 dataset `http://xiaoyongshen.me/webpage_portrait/index.html`.

    Args:
        common_transforms (list): A list of common image transformations for the two inputs of PortraitNet.
        transforms1 (list): A list of image transformations for the first input of PortraitNet.
        transforms2 (list): A list of image transformations for the second input of PortraitNet.
        dataset_root (str, optional): The EG1800 dataset directory. Default: None.
        mode (str, optional): A subset of the entire dataset. It should be one of ('train', 'val'). Default: 'train'.
        edge (bool, optional): Whether to compute edge while training. Default: False
    """
    NUM_CLASSES = 2

    def __init__(self,
                 common_transforms,
                 transforms1,
                 transforms2,
                 dataset_root=None,
                 mode='train',
                 edge=False):
        self.dataset_root = dataset_root
        self.common_transforms = Compose(common_transforms)
        self.transforms = self.common_transforms
        if transforms1 is not None:
            self.transforms1 = Compose(transforms1, to_rgb=False)
        if transforms2 is not None:
            self.transforms2 = Compose(transforms2, to_rgb=False)
        mode = mode.lower()
        self.ignore_index = 255
        self.mode = mode
        self.num_classes = self.NUM_CLASSES
        self.input_width = 224
        self.input_height = 224

        if self.dataset_root is None:
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=seg_env.DATA_HOME,
                extrapath=seg_env.DATA_HOME)
        elif not os.path.exists(self.dataset_root):
            self.dataset_root = os.path.normpath(self.dataset_root)
            savepath, extraname = self.dataset_root.rsplit(
                sep=os.path.sep, maxsplit=1)
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=savepath,
                extrapath=savepath,
                extraname=extraname)

        # Use self.dataset_root below: the dataset_root argument may be None
        # when the dataset has just been downloaded.
        if mode == 'train':
            path = os.path.join(self.dataset_root, 'eg1800_train.txt')
        else:
            path = os.path.join(self.dataset_root, 'eg1800_test.txt')
        with open(path, 'r') as f:
            files = f.readlines()
        img_files = [
            os.path.join(self.dataset_root, 'Images', file).strip()
            for file in files
        ]
        label_files = [
            os.path.join(self.dataset_root, 'Labels', file).strip()
            for file in files
        ]

        self.file_list = [
            [img_path, label_path]
            for img_path, label_path in zip(img_files, label_files)
        ]

    def __getitem__(self, item):
        image_path, label_path = self.file_list[item]
        im = cv2.imread(image_path)
        label = cv2.imread(label_path, 0)
        label[label > 1] = 0

        if self.mode == "val":
            common_im, label = self.common_transforms(im=im, label=label)
            im = np.float32(common_im[::-1, :, :])  # RGB => BGR
            im_aug = copy.deepcopy(im)
        else:
            common_im, label = self.common_transforms(im=im, label=label)
            common_im = np.transpose(common_im, [1, 2, 0])
            # add augmentation
            im, _ = self.transforms1(common_im)
            im_aug, _ = self.transforms2(common_im)

            im = np.float32(im[::-1, :, :])  # RGB => BGR
            im_aug = np.float32(im_aug[::-1, :, :])  # RGB => BGR

        label = cv2.resize(
            np.uint8(label), (self.input_width, self.input_height),
            interpolation=cv2.INTER_NEAREST)

        # add mask blur
        label = np.uint8(cv2.blur(label, (5, 5)))
        label[label >= 0.5] = 1
        label[label < 0.5] = 0

        edge_mask = F.mask_to_binary_edge(
            label, radius=4, num_classes=self.num_classes)
        edge_mask = np.transpose(edge_mask, [1, 2, 0]).squeeze(axis=-1)
        im = np.concatenate([im_aug, im])
        if self.mode == "train":
            return im, label, edge_mask
        else:
            return im, label

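Note that `__getitem__` returns a six-channel image: the two views are stacked channel-wise by np.concatenate([im_aug, im]), which is what the two-branch PortraitNet training consumes. A sketch of the shapes, assuming an `EG1800` instance named `train_ds` built with 224×224 transforms:

im, label, edge_mask = train_ds[0]
print(im.shape)     # (6, 224, 224): im_aug stacked on top of im
print(label.shape)  # (224, 224), values in {0, 1}
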
paddleseg/datasets/hrf.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.utils import seg_env
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
from paddleseg.datasets import Dataset

URL = 'https://bj.bcebos.com/paddleseg/dataset/hrf/hrf.zip'


@manager.DATASETS.add_component
class HRF(Dataset):
    """
    The HRF dataset is a retinal vessel segmentation dataset comprising 45 images organized as 15 subsets.
    Each subset contains one healthy fundus image, one image of a patient with diabetic retinopathy,
    and one glaucoma image. The image sizes are 3,304 x 2,336, with a training/testing split of 21/24.
    (https://doi.org/10.1155/2013/154860)

    Args:
        transforms (list): Transforms for image.
        dataset_root (str): The dataset directory. Default: None
        edge (bool): Whether to extract edge information in the output.
        mode (str, optional): Which part of the dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
    """
    NUM_CLASSES = 2

    def __init__(self,
                 dataset_root=None,
                 transforms=None,
                 edge=False,
                 mode='train'):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        mode = mode.lower()
        self.mode = mode
        self.edge = edge
        self.file_list = list()
        self.num_classes = self.NUM_CLASSES
        self.ignore_index = 255

        if mode not in ['train', 'val', 'test']:
            raise ValueError(
                "`mode` should be 'train', 'val' or 'test', but got {}.".format(
                    mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        if self.dataset_root is None:
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=seg_env.DATA_HOME,
                extrapath=seg_env.DATA_HOME)
        elif not os.path.exists(self.dataset_root):
            self.dataset_root = os.path.normpath(self.dataset_root)
            savepath, extraname = self.dataset_root.rsplit(
                sep=os.path.sep, maxsplit=1)
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=savepath,
                extrapath=savepath,
                extraname=extraname)

        if mode == 'train':
            file_path = os.path.join(self.dataset_root, 'train_list.txt')
        elif mode == 'val':
            file_path = os.path.join(self.dataset_root, 'val_list.txt')

        with open(file_path, 'r') as f:
            for line in f:
                items = line.strip().split()
                if len(items) != 2:
                    if mode == 'train' or mode == 'val':
                        raise Exception(
                            "File list format incorrect! It should be"
                            " image_name label_name\\n")
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = None
                else:
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = os.path.join(self.dataset_root, items[1])
                self.file_list.append([image_path, grt_path])

paddleseg/datasets/mini_deep_globe_road_extraction.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.utils import seg_env
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose

URL = "https://paddleseg.bj.bcebos.com/dataset/MiniDeepGlobeRoadExtraction.zip"


@manager.DATASETS.add_component
class MiniDeepGlobeRoadExtraction(Dataset):
    """
    MiniDeepGlobeRoadExtraction dataset is extracted from the DeepGlobe CVPR 2018 challenge (http://deepglobe.org/).

    There are 800 images in the training set and 200 images in the validation set.

    Args:
        dataset_root (str, optional): The dataset directory. Default: None.
        transforms (list, optional): Transforms for image. Default: None.
        mode (str, optional): Which part of the dataset to use. It is one of ('train', 'val'). Default: 'train'.
        edge (bool, optional): Whether to compute edge while training. Default: False.
    """
    NUM_CLASSES = 2

    def __init__(self,
                 dataset_root=None,
                 transforms=None,
                 mode='train',
                 edge=False):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        mode = mode.lower()
        self.mode = mode
        self.file_list = list()
        self.num_classes = self.NUM_CLASSES
        self.ignore_index = 255
        self.edge = edge

        if mode not in ['train', 'val']:
            raise ValueError(
                "`mode` should be 'train' or 'val', but got {}.".format(mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        if self.dataset_root is None:
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=seg_env.DATA_HOME,
                extrapath=seg_env.DATA_HOME)
        elif not os.path.exists(self.dataset_root):
            self.dataset_root = os.path.normpath(self.dataset_root)
            savepath, extraname = self.dataset_root.rsplit(
                sep=os.path.sep, maxsplit=1)
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=savepath,
                extrapath=savepath,
                extraname=extraname)

        if mode == 'train':
            file_path = os.path.join(self.dataset_root, 'train.txt')
        else:
            file_path = os.path.join(self.dataset_root, 'val.txt')

        with open(file_path, 'r') as f:
            for line in f:
                items = line.strip().split('|')
                if len(items) != 2:
                    if mode == 'train' or mode == 'val':
                        raise Exception(
                            "File list format incorrect! It should be"
                            " image_name|label_name\\n")
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = None
                else:
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = os.path.join(self.dataset_root, items[1])
                self.file_list.append([image_path, grt_path])

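Unlike the space-separated lists used by the other datasets here, this dataset's train.txt/val.txt use '|' as the separator, so a line pairs image and mask like (hypothetical names):

train/104_sat.jpg|train/104_mask.png
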
paddleseg/datasets/optic_disc_seg.py
ADDED
@@ -0,0 +1,97 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.utils import seg_env
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose

URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"


@manager.DATASETS.add_component
class OpticDiscSeg(Dataset):
    """
    OpticDiscSeg dataset is extracted from iChallenge-AMD
    (https://ai.baidu.com/broad/subordinate?dataset=amd).

    Args:
        transforms (list): Transforms for image.
        dataset_root (str): The dataset directory. Default: None
        mode (str, optional): Which part of the dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
        edge (bool, optional): Whether to compute edge while training. Default: False
    """
    NUM_CLASSES = 2

    def __init__(self,
                 dataset_root=None,
                 transforms=None,
                 mode='train',
                 edge=False):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        mode = mode.lower()
        self.mode = mode
        self.file_list = list()
        self.num_classes = self.NUM_CLASSES
        self.ignore_index = 255
        self.edge = edge

        if mode not in ['train', 'val', 'test']:
            raise ValueError(
                "`mode` should be 'train', 'val' or 'test', but got {}.".format(
                    mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        if self.dataset_root is None:
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=seg_env.DATA_HOME,
                extrapath=seg_env.DATA_HOME)
        elif not os.path.exists(self.dataset_root):
            self.dataset_root = os.path.normpath(self.dataset_root)
            savepath, extraname = self.dataset_root.rsplit(
                sep=os.path.sep, maxsplit=1)
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=savepath,
                extrapath=savepath,
                extraname=extraname)

        if mode == 'train':
            file_path = os.path.join(self.dataset_root, 'train_list.txt')
        elif mode == 'val':
            file_path = os.path.join(self.dataset_root, 'val_list.txt')
        else:
            file_path = os.path.join(self.dataset_root, 'test_list.txt')

        with open(file_path, 'r') as f:
            for line in f:
                items = line.strip().split()
                if len(items) != 2:
                    if mode == 'train' or mode == 'val':
                        raise Exception(
                            "File list format incorrect! It should be"
                            " image_name label_name\\n")
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = None
                else:
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = os.path.join(self.dataset_root, items[1])
                self.file_list.append([image_path, grt_path])

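This is one of the few datasets here that also ships a test_list.txt, and in 'test' mode a line may hold only an image path, leaving grt_path as None. A usage sketch (the transform list is an assumption):

import paddleseg.transforms as T
from paddleseg.datasets import OpticDiscSeg

test_ds = OpticDiscSeg(transforms=[T.Normalize()], mode='test')
image_path, grt_path = test_ds.file_list[0]  # grt_path may be None
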
paddleseg/datasets/pascal_context.py
ADDED
@@ -0,0 +1,86 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from PIL import Image
from paddleseg.datasets import Dataset
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose


@manager.DATASETS.add_component
class PascalContext(Dataset):
    """
    PascalVOC2010 dataset `http://host.robots.ox.ac.uk/pascal/VOC/`.
    If you want to use the Pascal Context dataset, please run convert_voc2010.py in tools first.

    Args:
        transforms (list): Transforms for image.
        dataset_root (str): The dataset directory. Default: None
        mode (str): Which part of the dataset to use. It is one of ('train', 'trainval', 'val'). Default: 'train'.
        edge (bool, optional): Whether to compute edge while training. Default: False
    """
    NUM_CLASSES = 60

    def __init__(self,
                 transforms=None,
                 dataset_root=None,
                 mode='train',
                 edge=False):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        mode = mode.lower()
        self.mode = mode
        self.file_list = list()
        self.num_classes = self.NUM_CLASSES
        self.ignore_index = 255
        self.edge = edge

        if mode not in ['train', 'trainval', 'val']:
            raise ValueError(
                "`mode` should be one of ('train', 'trainval', 'val') in PascalContext dataset, but got {}."
                .format(mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")
        if self.dataset_root is None:
            raise ValueError(
                "The dataset is not found or the folder structure is nonconformant."
            )

        image_set_dir = os.path.join(self.dataset_root, 'ImageSets',
                                     'Segmentation')

        if mode == 'train':
            file_path = os.path.join(image_set_dir, 'train_context.txt')
        elif mode == 'val':
            file_path = os.path.join(image_set_dir, 'val_context.txt')
        elif mode == 'trainval':
            file_path = os.path.join(image_set_dir, 'trainval_context.txt')
        if not os.path.exists(file_path):
            raise RuntimeError(
                "PASCAL-Context annotations are not ready, "
                "please make sure convert_voc2010.py has been properly run.")

        img_dir = os.path.join(self.dataset_root, 'JPEGImages')
        label_dir = os.path.join(self.dataset_root, 'Context')

        with open(file_path, 'r') as f:
            for line in f:
                line = line.strip()
                image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
                label_path = os.path.join(label_dir, ''.join([line, '.png']))
                self.file_list.append([image_path, label_path])

paddleseg/datasets/pp_humanseg14k.py
ADDED
@@ -0,0 +1,82 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from .dataset import Dataset
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose


@manager.DATASETS.add_component
class PPHumanSeg14K(Dataset):
    """
    This is the PP-HumanSeg14K dataset.

    This dataset was introduced in the work:
    Chu, Lutao, et al. "PP-HumanSeg: Connectivity-Aware Portrait Segmentation with a Large-Scale Teleconferencing Video Dataset." Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision. 2022.

    The dataset is divided into a training set of 8770 images, a validation set of 2431 images, and a test set of 2482 images.

    Args:
        dataset_root (str, optional): The dataset directory. Default: None.
        transforms (list, optional): Transforms for image. Default: None.
        mode (str, optional): Which part of the dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
        edge (bool, optional): Whether to compute edge while training. Default: False.
    """
    NUM_CLASSES = 2

    def __init__(self,
                 dataset_root=None,
                 transforms=None,
                 mode='train',
                 edge=False):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        mode = mode.lower()
        self.mode = mode
        self.file_list = list()
        self.num_classes = self.NUM_CLASSES
        self.ignore_index = 255
        self.edge = edge

        if mode not in ['train', 'val', 'test']:
            raise ValueError(
                "`mode` should be 'train', 'val' or 'test', but got {}.".format(
                    mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        if mode == 'train':
            file_path = os.path.join(self.dataset_root, 'train.txt')
        elif mode == 'val':
            file_path = os.path.join(self.dataset_root, 'val.txt')
        else:
            file_path = os.path.join(self.dataset_root, 'test.txt')

        with open(file_path, 'r') as f:
            for line in f:
                items = line.strip().split(' ')
                if len(items) != 2:
                    if mode == 'train' or mode == 'val':
                        raise Exception(
                            "File list format incorrect! It should be"
                            " image_name label_name\\n")
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = None
                else:
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = os.path.join(self.dataset_root, items[1])
                self.file_list.append([image_path, grt_path])

paddleseg/datasets/pssl.py
ADDED
@@ -0,0 +1,135 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np

from paddleseg.datasets import Dataset
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose


@manager.DATASETS.add_component
class PSSLDataset(Dataset):
    """
    The PSSL dataset for segmentation. PSSL is short for Pseudo Semantic Segmentation Labels, where the pseudo labels
    are computed by the Consensus explanation algorithm.

    PSSL refers to "Distilling Ensemble of Explanations for Weakly-Supervised Pre-Training of Image Segmentation
    Models" (https://arxiv.org/abs/2207.03335).

    The Consensus explanation refers to "Cross-Model Consensus of Explanations and Beyond for Image Classification
    Models: An Empirical Study" (https://arxiv.org/abs/2109.00707).

    To use this dataset, we additionally need the original ImageNet dataset, which has the folder structure
    as follows:

        imagenet_root
        |
        |--train
        |  |--n01440764
        |  |  |--n01440764_10026.JPEG
        |  |  |--...
        |  |--nxxxxxxxx
        |  |--...

    where only the "train" set is needed.

    The PSSL dataset has the folder structure as follows:

        pssl_root
        |
        |--train
        |  |--n01440764
        |  |  |--n01440764_10026.JPEG_eiseg.npz
        |  |  |--...
        |  |--nxxxxxxxx
        |  |--...
        |
        |--imagenet_lsvrc_2015_synsets.txt
        |--train.txt

    where "train.txt" and "imagenet_lsvrc_2015_synsets.txt" are included in the PSSL dataset.

    Args:
        transforms (list): Transforms for image.
        imagenet_root (str): The path to the original ImageNet dataset.
        pssl_root (str): The path to the PSSL dataset.
        mode (str, optional): Which part of the dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
        edge (bool, optional): Whether to compute edge while training. Default: False.
    """
    ignore_index = 1001  # 0-999 are target classes, 1000 is background
    NUM_CLASSES = 1001  # target classes plus background

    def __init__(self,
                 transforms,
                 imagenet_root,
                 pssl_root,
                 mode='train',
                 edge=False):
        mode = mode.lower()
        if mode not in ['train']:
            raise ValueError("mode should be 'train', but got {}.".format(mode))
        if transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        self.transforms = Compose(transforms)
        self.mode = mode
        self.edge = edge

        self.num_classes = self.NUM_CLASSES
        self.ignore_index = self.num_classes  # 1001
        self.file_list = []
        self.class_id_dict = {}

        if imagenet_root is None or not os.path.isdir(pssl_root):
            raise ValueError(
                "The dataset is not found or the folder structure is nonconformant."
            )

        train_list_file = os.path.join(pssl_root, "train.txt")
        if not os.path.exists(train_list_file):
            raise ValueError("Train list file does not exist.")
        for idx, line in enumerate(open(train_list_file)):
            # line: train/n04118776/n04118776_45912.JPEG_eiseg.npz
            label_path = line.strip()
            img_path = label_path.split('.JPEG')[0] + '.JPEG'
            label_path = os.path.join(pssl_root, label_path)
            img_path = os.path.join(imagenet_root, img_path)
            self.file_list.append([img_path, label_path])

        # mapping class name to class id.
        class_id_file = os.path.join(pssl_root,
                                     "imagenet_lsvrc_2015_synsets.txt")
        if not os.path.exists(class_id_file):
            raise ValueError("Class id file does not exist.")
        for idx, line in enumerate(open(class_id_file)):
            class_name = line.strip()
            self.class_id_dict[class_name] = idx

    def __getitem__(self, idx):
        image_path, label_path = self.file_list[idx]

        # transform label
        class_name = (image_path.split('/')[-1]).split('_')[0]
        class_id = self.class_id_dict[class_name]

        pssl_seg = np.load(label_path)['arr_0']
        gt_semantic_seg = np.zeros_like(pssl_seg, dtype=np.int64) + 1000
        # [0, 999] for ImageNet classes, 1000 for background; other values (-1) are ignored during training.
        gt_semantic_seg[pssl_seg == 1] = class_id

        im, label = self.transforms(im=image_path, label=gt_semantic_seg)

        return im, label

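The pseudo-label construction in `__getitem__` is easy to verify in isolation: the .npz holds a binary foreground mask, every pixel starts as background id 1000, and foreground pixels take the synset's class id. A sketch with toy values (class_id is hypothetical):

import numpy as np

pssl_seg = np.array([[0, 1], [1, 0]])  # binary mask as stored in the .npz
class_id = 42
gt = np.zeros_like(pssl_seg, dtype=np.int64) + 1000  # background everywhere
gt[pssl_seg == 1] = class_id
print(gt)  # [[1000   42]
           #  [  42 1000]]
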
paddleseg/datasets/stare.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.utils import seg_env
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
from paddleseg.datasets import Dataset

URL = 'https://bj.bcebos.com/paddleseg/dataset/stare/stare.zip'


@manager.DATASETS.add_component
class STARE(Dataset):
    """
    STARE dataset is processed from the STARE (STructured Analysis of the Retina) project.
    (https://cecas.clemson.edu/~ahoover/stare/)

    Args:
        transforms (list): Transforms for image.
        dataset_root (str): The dataset directory. Default: None
        edge (bool, optional): Whether to extract edge information in the output. Default: False.
        mode (str, optional): Which part of dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
    """
    NUM_CLASSES = 2

    def __init__(self,
                 dataset_root=None,
                 transforms=None,
                 edge=False,
                 mode='train'):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        mode = mode.lower()
        self.mode = mode
        self.edge = edge
        self.file_list = list()
        self.num_classes = self.NUM_CLASSES
        self.ignore_index = 255

        if mode not in ['train', 'val', 'test']:
            raise ValueError(
                "`mode` should be 'train', 'val' or 'test', but got {}.".format(
                    mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        if self.dataset_root is None:
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=seg_env.DATA_HOME,
                extrapath=seg_env.DATA_HOME)
        elif not os.path.exists(self.dataset_root):
            self.dataset_root = os.path.normpath(self.dataset_root)
            savepath, extraname = self.dataset_root.rsplit(
                sep=os.path.sep, maxsplit=1)  # e.g. savepath='data', extraname='STARE'
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=savepath,
                extrapath=savepath,
                extraname=extraname)

        if mode == 'train':
            file_path = os.path.join(self.dataset_root, 'train_list.txt')
        elif mode == 'val':
            file_path = os.path.join(self.dataset_root, 'val_list.txt')

        with open(file_path, 'r') as f:
            for line in f:
                items = line.strip().split()
                if len(items) != 2:
                    if mode == 'train' or mode == 'val':
                        raise Exception(
                            "File list format incorrect! It should be"
                            " image_name label_name\\n")
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = None
                else:
                    image_path = os.path.join(self.dataset_root, items[0])
                    grt_path = os.path.join(self.dataset_root, items[1])
                self.file_list.append([image_path, grt_path])
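A minimal usage sketch (not part of this commit; assumes a PaddleSeg install and network access for the automatic download; `Normalize` is an illustrative transform):

# Builds the STARE training split; the archive is downloaded automatically
# when dataset_root is None.
from paddleseg.transforms import Normalize
from paddleseg.datasets import STARE

train_dataset = STARE(transforms=[Normalize()], mode='train')
print(len(train_dataset.file_list))  # image/label pairs listed in train_list.txt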
paddleseg/datasets/supervisely.py
ADDED
@@ -0,0 +1,136 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import copy

import cv2
import numpy as np

from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
from paddleseg.datasets import Dataset
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.utils import seg_env
import paddleseg.transforms.functional as F

URL = "https://paddleseg.bj.bcebos.com/dataset/Supervisely_face.zip"


@manager.DATASETS.add_component
class SUPERVISELY(Dataset):
    """
    Supervise.ly dataset `https://supervise.ly/`.

    Args:
        common_transforms (list): A list of common image transformations for two inputs of portrait net.
        transforms1 (list): A list of image transformations for the first input of portrait net.
        transforms2 (list): A list of image transformations for the second input of portrait net.
        dataset_root (str, optional): The Supervise.ly dataset directory. Default: None.
        mode (str, optional): A subset of the entire dataset. It should be one of ('train', 'val'). Default: 'train'.
        edge (bool, optional): Whether to compute edge while training. Default: False
    """
    NUM_CLASSES = 2

    def __init__(self,
                 common_transforms,
                 transforms1,
                 transforms2,
                 dataset_root=None,
                 mode='train',
                 edge=False):
        self.dataset_root = dataset_root
        self.common_transforms = Compose(common_transforms)
        self.transforms = self.common_transforms
        if transforms1 is not None:
            self.transforms1 = Compose(transforms1, to_rgb=False)
        if transforms2 is not None:
            self.transforms2 = Compose(transforms2, to_rgb=False)
        mode = mode.lower()
        self.ignore_index = 255
        self.mode = mode
        self.num_classes = self.NUM_CLASSES
        self.input_width = 224
        self.input_height = 224

        if self.dataset_root is None:
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=seg_env.DATA_HOME,
                extrapath=seg_env.DATA_HOME)
        elif not os.path.exists(self.dataset_root):
            self.dataset_root = os.path.normpath(self.dataset_root)
            savepath, extraname = self.dataset_root.rsplit(
                sep=os.path.sep, maxsplit=1)
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=savepath,
                extrapath=savepath,
                extraname=extraname)

        # use self.dataset_root so the path is valid after an automatic download
        if mode == 'train':
            path = os.path.join(self.dataset_root,
                                'supervisely_face_train_easy.txt')
        else:
            path = os.path.join(self.dataset_root,
                                'supervisely_face_test_easy.txt')
        with open(path, 'r') as f:
            files = f.readlines()
        files = ["/".join(file.split('/')[1:]) for file in files]
        img_files = [
            os.path.join(self.dataset_root, file).strip() for file in files
        ]
        label_files = [
            os.path.join(self.dataset_root,
                         file.replace('/img/', '/ann/')).strip()
            for file in files
        ]

        self.file_list = [
            [img_path, label_path]
            for img_path, label_path in zip(img_files, label_files)
        ]

    def __getitem__(self, item):
        image_path, label_path = self.file_list[item]
        im = cv2.imread(image_path)
        label = cv2.imread(label_path, 0)
        label[label > 0] = 1

        if self.mode == "val":
            common_im, label = self.common_transforms(im=im, label=label)
            im = np.float32(common_im[::-1, :, :])  # RGB => BGR
            im_aug = copy.deepcopy(im)
        else:
            common_im, label = self.common_transforms(im=im, label=label)
            common_im = np.transpose(common_im, [1, 2, 0])
            # add augmentation
            im, _ = self.transforms1(common_im)
            im_aug, _ = self.transforms2(common_im)

            im = np.float32(im[::-1, :, :])  # RGB => BGR
            im_aug = np.float32(im_aug[::-1, :, :])  # RGB => BGR

        label = cv2.resize(
            np.uint8(label), (self.input_width, self.input_height),
            interpolation=cv2.INTER_NEAREST)

        # add mask blur
        label = np.uint8(cv2.blur(label, (5, 5)))
        label[label >= 0.5] = 1
        label[label < 0.5] = 0

        edge_mask = F.mask_to_binary_edge(
            label, radius=4, num_classes=self.num_classes)
        edge_mask = np.transpose(edge_mask, [1, 2, 0]).squeeze(axis=-1)
        im = np.concatenate([im_aug, im])
        if self.mode == "train":
            return im, label, edge_mask
        else:
            return im, label
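In train mode, `__getitem__` returns the two augmented views concatenated along the channel axis (a 6-channel array) for PortraitNet-style consistency training. A hedged usage sketch (not part of the commit; the transforms are illustrative and the archive must be reachable):

# Two normalized views of each 224x224 face crop; im.shape == (6, 224, 224).
from paddleseg.transforms import Resize, Normalize
from paddleseg.datasets import SUPERVISELY

dataset = SUPERVISELY(
    common_transforms=[Resize([224, 224])],
    transforms1=[Normalize()],
    transforms2=[Normalize()],
    mode='train')
im, label, edge_mask = dataset[0]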
paddleseg/datasets/voc.py
ADDED
@@ -0,0 +1,112 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from paddleseg.datasets import Dataset
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.utils import seg_env
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose

URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"


@manager.DATASETS.add_component
class PascalVOC(Dataset):
    """
    PascalVOC2012 dataset `http://host.robots.ox.ac.uk/pascal/VOC/`.
    If you want to augment the dataset, please run the voc_augment.py in tools.

    Args:
        transforms (list): Transforms for image.
        dataset_root (str): The dataset directory. Default: None
        mode (str, optional): Which part of dataset to use. It is one of ('train', 'trainval', 'trainaug', 'val').
            If you want to set mode to 'trainaug', please make sure the dataset has been augmented. Default: 'train'.
        edge (bool, optional): Whether to compute edge while training. Default: False
    """
    NUM_CLASSES = 21

    def __init__(self, transforms, dataset_root=None, mode='train', edge=False):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        mode = mode.lower()
        self.mode = mode
        self.file_list = list()
        self.num_classes = self.NUM_CLASSES
        self.ignore_index = 255
        self.edge = edge

        if mode not in ['train', 'trainval', 'trainaug', 'val']:
            raise ValueError(
                "`mode` should be one of ('train', 'trainval', 'trainaug', 'val') in PascalVOC dataset, but got {}."
                .format(mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        if self.dataset_root is None:
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=seg_env.DATA_HOME,
                extrapath=seg_env.DATA_HOME,
                extraname='VOCdevkit')
        elif not os.path.exists(self.dataset_root):
            self.dataset_root = os.path.normpath(self.dataset_root)
            savepath, extraname = self.dataset_root.rsplit(
                sep=os.path.sep, maxsplit=1)
            self.dataset_root = download_file_and_uncompress(
                url=URL,
                savepath=savepath,
                extrapath=savepath,
                extraname=extraname)

        image_set_dir = os.path.join(self.dataset_root, 'VOC2012', 'ImageSets',
                                     'Segmentation')
        if mode == 'train':
            file_path = os.path.join(image_set_dir, 'train.txt')
        elif mode == 'val':
            file_path = os.path.join(image_set_dir, 'val.txt')
        elif mode == 'trainval':
            file_path = os.path.join(image_set_dir, 'trainval.txt')
        elif mode == 'trainaug':
            file_path = os.path.join(image_set_dir, 'train.txt')
            file_path_aug = os.path.join(image_set_dir, 'aug.txt')

            if not os.path.exists(file_path_aug):
                raise RuntimeError(
                    "When `mode` is 'trainaug', the Pascal VOC dataset should be augmented. "
                    "Please make sure voc_augment.py has been properly run when using this mode."
                )

        img_dir = os.path.join(self.dataset_root, 'VOC2012', 'JPEGImages')
        label_dir = os.path.join(self.dataset_root, 'VOC2012',
                                 'SegmentationClass')
        label_dir_aug = os.path.join(self.dataset_root, 'VOC2012',
                                     'SegmentationClassAug')

        with open(file_path, 'r') as f:
            for line in f:
                line = line.strip()
                image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
                label_path = os.path.join(label_dir, ''.join([line, '.png']))
                self.file_list.append([image_path, label_path])
        if mode == 'trainaug':
            with open(file_path_aug, 'r') as f:
                for line in f:
                    line = line.strip()
                    image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
                    label_path = os.path.join(label_dir_aug,
                                              ''.join([line, '.png']))
                    self.file_list.append([image_path, label_path])
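A minimal usage sketch (not part of the commit; the VOCtrainval tar is downloaded automatically, which requires network access):

from paddleseg.transforms import Normalize
from paddleseg.datasets import PascalVOC

val_dataset = PascalVOC(transforms=[Normalize()], mode='val')
print(len(val_dataset.file_list))  # VOC2012 segmentation val lists 1449 pairs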
paddleseg/models/ann.py
ADDED
@@ -0,0 +1,434 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.cvlibs import manager
from paddleseg.models import layers
from paddleseg.utils import utils


@manager.MODELS.add_component
class ANN(nn.Layer):
    """
    The ANN implementation based on PaddlePaddle.

    The original article refers to
    Zhen Zhu, et al. "Asymmetric Non-local Neural Networks for Semantic Segmentation"
    (https://arxiv.org/pdf/1908.07678.pdf).

    Args:
        num_classes (int): The unique number of target classes.
        backbone (Paddle.nn.Layer): Backbone network, currently supports ResNet50/101.
        backbone_indices (tuple, optional): Two values in the tuple indicate the indices of output of backbone.
        key_value_channels (int, optional): The key and value channels of self-attention map in both AFNB and APNB modules.
            Default: 256.
        inter_channels (int, optional): Both input and output channels of APNB modules. Default: 512.
        psp_size (tuple, optional): The out size of pooled feature maps. Default: (1, 3, 6, 8).
        enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
        align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
            e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
        pretrained (str, optional): The path or url of pretrained model. Default: None.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 backbone_indices=(2, 3),
                 key_value_channels=256,
                 inter_channels=512,
                 psp_size=(1, 3, 6, 8),
                 enable_auxiliary_loss=True,
                 align_corners=False,
                 pretrained=None):
        super().__init__()

        self.backbone = backbone
        backbone_channels = [
            backbone.feat_channels[i] for i in backbone_indices
        ]

        self.head = ANNHead(num_classes, backbone_indices, backbone_channels,
                            key_value_channels, inter_channels, psp_size,
                            enable_auxiliary_loss)
        self.align_corners = align_corners
        self.pretrained = pretrained
        self.init_weight()

    def forward(self, x):
        feat_list = self.backbone(x)
        logit_list = self.head(feat_list)
        return [
            F.interpolate(
                logit,
                paddle.shape(x)[2:],
                mode='bilinear',
                align_corners=self.align_corners) for logit in logit_list
        ]

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)


class ANNHead(nn.Layer):
    """
    The ANNHead implementation.

    It mainly consists of AFNB and APNB modules.

    Args:
        num_classes (int): The unique number of target classes.
        backbone_indices (tuple): Two values in the tuple indicate the indices of output of backbone.
            The first index will be taken as low-level features; the second one will be
            taken as high-level features in AFNB module. Usually a backbone consists of four
            downsampling stages, such as ResNet, and returns an output of each stage. If it is (2, 3),
            it means taking feature maps of the third stage and the fourth stage in backbone.
        backbone_channels (tuple): The same length as "backbone_indices". It indicates the channels of corresponding index.
        key_value_channels (int): The key and value channels of self-attention map in both AFNB and APNB modules.
        inter_channels (int): Both input and output channels of APNB modules.
        psp_size (tuple): The out size of pooled feature maps.
        enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
    """

    def __init__(self,
                 num_classes,
                 backbone_indices,
                 backbone_channels,
                 key_value_channels,
                 inter_channels,
                 psp_size,
                 enable_auxiliary_loss=True):
        super().__init__()

        low_in_channels = backbone_channels[0]
        high_in_channels = backbone_channels[1]

        self.fusion = AFNB(
            low_in_channels=low_in_channels,
            high_in_channels=high_in_channels,
            out_channels=high_in_channels,
            key_channels=key_value_channels,
            value_channels=key_value_channels,
            dropout_prob=0.05,
            repeat_sizes=([1]),
            psp_size=psp_size)

        self.context = nn.Sequential(
            layers.ConvBNReLU(
                in_channels=high_in_channels,
                out_channels=inter_channels,
                kernel_size=3,
                padding=1),
            APNB(
                in_channels=inter_channels,
                out_channels=inter_channels,
                key_channels=key_value_channels,
                value_channels=key_value_channels,
                dropout_prob=0.05,
                repeat_sizes=([1]),
                psp_size=psp_size))

        self.cls = nn.Conv2D(
            in_channels=inter_channels, out_channels=num_classes, kernel_size=1)
        self.auxlayer = layers.AuxLayer(
            in_channels=low_in_channels,
            inter_channels=low_in_channels // 2,
            out_channels=num_classes,
            dropout_prob=0.05)

        self.backbone_indices = backbone_indices
        self.enable_auxiliary_loss = enable_auxiliary_loss

    def forward(self, feat_list):
        logit_list = []
        low_level_x = feat_list[self.backbone_indices[0]]
        high_level_x = feat_list[self.backbone_indices[1]]
        x = self.fusion(low_level_x, high_level_x)
        x = self.context(x)
        logit = self.cls(x)
        logit_list.append(logit)

        if self.enable_auxiliary_loss:
            auxiliary_logit = self.auxlayer(low_level_x)
            logit_list.append(auxiliary_logit)

        return logit_list


class AFNB(nn.Layer):
    """
    Asymmetric Fusion Non-local Block.

    Args:
        low_in_channels (int): Low-level-feature channels.
        high_in_channels (int): High-level-feature channels.
        out_channels (int): Out channels of AFNB module.
        key_channels (int): The key channels in self-attention block.
        value_channels (int): The value channels in self-attention block.
        dropout_prob (float): The dropout rate of output.
        repeat_sizes (tuple, optional): The number of AFNB modules. Default: ([1]).
        psp_size (tuple, optional): The out size of pooled feature maps. Default: (1, 3, 6, 8).
    """

    def __init__(self,
                 low_in_channels,
                 high_in_channels,
                 out_channels,
                 key_channels,
                 value_channels,
                 dropout_prob,
                 repeat_sizes=([1]),
                 psp_size=(1, 3, 6, 8)):
        super().__init__()

        self.psp_size = psp_size
        self.stages = nn.LayerList([
            SelfAttentionBlock_AFNB(low_in_channels, high_in_channels,
                                    key_channels, value_channels, out_channels,
                                    size) for size in repeat_sizes
        ])
        self.conv_bn = layers.ConvBN(
            in_channels=out_channels + high_in_channels,
            out_channels=out_channels,
            kernel_size=1)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, low_feats, high_feats):
        priors = [stage(low_feats, high_feats) for stage in self.stages]
        context = priors[0]
        for i in range(1, len(priors)):
            context += priors[i]

        output = self.conv_bn(paddle.concat([context, high_feats], axis=1))
        output = self.dropout(output)

        return output


class APNB(nn.Layer):
    """
    Asymmetric Pyramid Non-local Block.

    Args:
        in_channels (int): The input channels of APNB module.
        out_channels (int): Out channels of APNB module.
        key_channels (int): The key channels in self-attention block.
        value_channels (int): The value channels in self-attention block.
        dropout_prob (float): The dropout rate of output.
        repeat_sizes (tuple, optional): The number of APNB modules. Default: ([1]).
        psp_size (tuple, optional): The out size of pooled feature maps. Default: (1, 3, 6, 8).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 key_channels,
                 value_channels,
                 dropout_prob,
                 repeat_sizes=([1]),
                 psp_size=(1, 3, 6, 8)):
        super().__init__()

        self.psp_size = psp_size
        self.stages = nn.LayerList([
            SelfAttentionBlock_APNB(in_channels, out_channels,
                                    key_channels, value_channels, size)
            for size in repeat_sizes
        ])
        self.conv_bn = layers.ConvBNReLU(
            in_channels=in_channels * 2,
            out_channels=out_channels,
            kernel_size=1)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        priors = [stage(x) for stage in self.stages]
        context = priors[0]
        for i in range(1, len(priors)):
            context += priors[i]

        output = self.conv_bn(paddle.concat([context, x], axis=1))
        output = self.dropout(output)

        return output


def _pp_module(x, psp_size):
    n, c, h, w = x.shape
    priors = []
    for size in psp_size:
        feat = F.adaptive_avg_pool2d(x, size)
        feat = paddle.reshape(feat, shape=(0, c, -1))
        priors.append(feat)
    center = paddle.concat(priors, axis=-1)
    return center


class SelfAttentionBlock_AFNB(nn.Layer):
    """
    Self-Attention Block for AFNB module.

    Args:
        low_in_channels (int): Low-level-feature channels.
        high_in_channels (int): High-level-feature channels.
        key_channels (int): The key channels in self-attention block.
        value_channels (int): The value channels in self-attention block.
        out_channels (int, optional): Out channels of AFNB module. Default: None.
        scale (int, optional): Pooling size. Default: 1.
        psp_size (tuple, optional): The out size of pooled feature maps. Default: (1, 3, 6, 8).
    """

    def __init__(self,
                 low_in_channels,
                 high_in_channels,
                 key_channels,
                 value_channels,
                 out_channels=None,
                 scale=1,
                 psp_size=(1, 3, 6, 8)):
        super().__init__()

        self.scale = scale
        self.in_channels = low_in_channels
        self.out_channels = out_channels
        self.key_channels = key_channels
        self.value_channels = value_channels
        if out_channels is None:
            self.out_channels = high_in_channels
        self.pool = nn.MaxPool2D(scale)
        self.f_key = layers.ConvBNReLU(
            in_channels=low_in_channels,
            out_channels=key_channels,
            kernel_size=1)
        self.f_query = layers.ConvBNReLU(
            in_channels=high_in_channels,
            out_channels=key_channels,
            kernel_size=1)
        self.f_value = nn.Conv2D(
            in_channels=low_in_channels,
            out_channels=value_channels,
            kernel_size=1)

        # use self.out_channels so the fallback to high_in_channels takes effect
        self.W = nn.Conv2D(
            in_channels=value_channels,
            out_channels=self.out_channels,
            kernel_size=1)

        self.psp_size = psp_size

    def forward(self, low_feats, high_feats):
        batch_size, _, h, w = high_feats.shape

        value = self.f_value(low_feats)
        value = _pp_module(value, self.psp_size)
        value = paddle.transpose(value, (0, 2, 1))

        query = self.f_query(high_feats)
        query = paddle.reshape(query, shape=(0, self.key_channels, -1))
        query = paddle.transpose(query, perm=(0, 2, 1))

        key = self.f_key(low_feats)
        key = _pp_module(key, self.psp_size)

        sim_map = paddle.matmul(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, axis=-1)

        context = paddle.matmul(sim_map, value)
        context = paddle.transpose(context, perm=(0, 2, 1))
        hf_shape = paddle.shape(high_feats)
        context = paddle.reshape(
            context, shape=[0, self.value_channels, hf_shape[2], hf_shape[3]])

        context = self.W(context)

        return context


class SelfAttentionBlock_APNB(nn.Layer):
    """
    Self-Attention Block for APNB module.

    Args:
        in_channels (int): The input channels of APNB module.
        out_channels (int): The out channels of APNB module.
        key_channels (int): The key channels in self-attention block.
        value_channels (int): The value channels in self-attention block.
        scale (int, optional): Pooling size. Default: 1.
        psp_size (tuple, optional): The out size of pooled feature maps. Default: (1, 3, 6, 8).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 key_channels,
                 value_channels,
                 scale=1,
                 psp_size=(1, 3, 6, 8)):
        super().__init__()

        self.scale = scale
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.key_channels = key_channels
        self.value_channels = value_channels
        self.pool = nn.MaxPool2D(scale)
        self.f_key = layers.ConvBNReLU(
            in_channels=self.in_channels,
            out_channels=self.key_channels,
            kernel_size=1)
        self.f_query = self.f_key
        self.f_value = nn.Conv2D(
            in_channels=self.in_channels,
            out_channels=self.value_channels,
            kernel_size=1)
        self.W = nn.Conv2D(
            in_channels=self.value_channels,
            out_channels=self.out_channels,
            kernel_size=1)

        self.psp_size = psp_size

    def forward(self, x):
        batch_size, _, h, w = x.shape
        if self.scale > 1:
            x = self.pool(x)

        value = self.f_value(x)
        value = _pp_module(value, self.psp_size)
        value = paddle.transpose(value, perm=(0, 2, 1))

        query = self.f_query(x)
        query = paddle.reshape(query, shape=(0, self.key_channels, -1))
        query = paddle.transpose(query, perm=(0, 2, 1))

        key = self.f_key(x)
        key = _pp_module(key, self.psp_size)

        sim_map = paddle.matmul(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, axis=-1)

        context = paddle.matmul(sim_map, value)
        context = paddle.transpose(context, perm=(0, 2, 1))

        x_shape = paddle.shape(x)
        context = paddle.reshape(
            context, shape=[0, self.value_channels, x_shape[2], x_shape[3]])
        context = self.W(context)

        return context
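A usage sketch (not part of the commit; assumes the ResNet50_vd backbone registered elsewhere in this package, with its default (2, 3) backbone indices):

import paddle
from paddleseg.models.backbones import ResNet50_vd
from paddleseg.models import ANN

model = ANN(num_classes=19, backbone=ResNet50_vd(output_stride=8))
logits = model(paddle.rand([1, 3, 512, 512]))
print(logits[0].shape)  # [1, 19, 512, 512]; logits[1] is the auxiliary head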
paddleseg/models/attention_unet.py
ADDED
@@ -0,0 +1,189 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
from paddleseg.cvlibs import manager
from paddleseg.models import layers
from paddleseg import utils
import numpy as np


@manager.MODELS.add_component
class AttentionUNet(nn.Layer):
    """
    The Attention-UNet implementation based on PaddlePaddle.
    As mentioned in the original paper, the authors propose a novel attention gate (AG)
    that automatically learns to focus on target structures of varying shapes and sizes.
    Models trained with AGs implicitly learn to suppress irrelevant regions in an input image while
    highlighting salient features useful for a specific task.

    The original article refers to
    Oktay, O., et al. "Attention U-Net: Learning where to look for the pancreas."
    (https://arxiv.org/pdf/1804.03999.pdf).

    Args:
        num_classes (int): The unique number of target classes.
        pretrained (str, optional): The path or url of pretrained model. Default: None.
    """

    def __init__(self, num_classes, pretrained=None):
        super().__init__()
        n_channels = 3
        self.encoder = Encoder(n_channels, [64, 128, 256, 512])
        filters = np.array([64, 128, 256, 512, 1024])
        self.up5 = UpConv(ch_in=filters[4], ch_out=filters[3])
        self.att5 = AttentionBlock(
            F_g=filters[3], F_l=filters[3], F_out=filters[2])
        self.up_conv5 = ConvBlock(ch_in=filters[4], ch_out=filters[3])

        self.up4 = UpConv(ch_in=filters[3], ch_out=filters[2])
        self.att4 = AttentionBlock(
            F_g=filters[2], F_l=filters[2], F_out=filters[1])
        self.up_conv4 = ConvBlock(ch_in=filters[3], ch_out=filters[2])

        self.up3 = UpConv(ch_in=filters[2], ch_out=filters[1])
        self.att3 = AttentionBlock(
            F_g=filters[1], F_l=filters[1], F_out=filters[0])
        self.up_conv3 = ConvBlock(ch_in=filters[2], ch_out=filters[1])

        self.up2 = UpConv(ch_in=filters[1], ch_out=filters[0])
        self.att2 = AttentionBlock(
            F_g=filters[0], F_l=filters[0], F_out=filters[0] // 2)
        self.up_conv2 = ConvBlock(ch_in=filters[1], ch_out=filters[0])

        self.conv_1x1 = nn.Conv2D(
            filters[0], num_classes, kernel_size=1, stride=1, padding=0)
        self.pretrained = pretrained
        self.init_weight()

    def forward(self, x):
        x5, (x1, x2, x3, x4) = self.encoder(x)
        d5 = self.up5(x5)
        x4 = self.att5(g=d5, x=x4)
        d5 = paddle.concat([x4, d5], axis=1)
        d5 = self.up_conv5(d5)

        d4 = self.up4(d5)
        x3 = self.att4(g=d4, x=x3)
        d4 = paddle.concat((x3, d4), axis=1)
        d4 = self.up_conv4(d4)

        d3 = self.up3(d4)
        x2 = self.att3(g=d3, x=x2)
        d3 = paddle.concat((x2, d3), axis=1)
        d3 = self.up_conv3(d3)

        d2 = self.up2(d3)
        x1 = self.att2(g=d2, x=x1)
        d2 = paddle.concat((x1, d2), axis=1)
        d2 = self.up_conv2(d2)

        logit = self.conv_1x1(d2)
        logit_list = [logit]
        return logit_list

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)


class AttentionBlock(nn.Layer):
    def __init__(self, F_g, F_l, F_out):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2D(
                F_g, F_out, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2D(F_out))

        self.W_x = nn.Sequential(
            nn.Conv2D(
                F_l, F_out, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2D(F_out))

        self.psi = nn.Sequential(
            nn.Conv2D(
                F_out, 1, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2D(1),
            nn.Sigmoid())

        self.relu = nn.ReLU()

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        res = x * psi
        return res


class UpConv(nn.Layer):
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(
                scale_factor=2, mode="bilinear"),
            nn.Conv2D(
                ch_in, ch_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2D(ch_out),
            nn.ReLU())

    def forward(self, x):
        return self.up(x)


class Encoder(nn.Layer):
    def __init__(self, input_channels, filters):
        super().__init__()
        self.double_conv = nn.Sequential(
            layers.ConvBNReLU(input_channels, 64, 3),
            layers.ConvBNReLU(64, 64, 3))
        down_channels = filters
        self.down_sample_list = nn.LayerList([
            self.down_sampling(channel, channel * 2)
            for channel in down_channels
        ])

    def down_sampling(self, in_channels, out_channels):
        modules = []
        modules.append(nn.MaxPool2D(kernel_size=2, stride=2))
        modules.append(layers.ConvBNReLU(in_channels, out_channels, 3))
        modules.append(layers.ConvBNReLU(out_channels, out_channels, 3))
        return nn.Sequential(*modules)

    def forward(self, x):
        short_cuts = []
        x = self.double_conv(x)
        for down_sample in self.down_sample_list:
            short_cuts.append(x)
            x = down_sample(x)
        return x, short_cuts


class ConvBlock(nn.Layer):
    def __init__(self, ch_in, ch_out):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2D(
                ch_in, ch_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2D(ch_out),
            nn.ReLU(),
            nn.Conv2D(
                ch_out, ch_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2D(ch_out),
            nn.ReLU())

    def forward(self, x):
        return self.conv(x)
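A usage sketch (not part of the commit): AttentionUNet is self-contained with a fixed 3-channel input, so a random tensor exercises the whole encoder/decoder path. With an input divisible by 16 (four 2x downsamples), the output matches the input resolution:

import paddle
from paddleseg.models import AttentionUNet

model = AttentionUNet(num_classes=2)
logits = model(paddle.rand([1, 3, 256, 256]))[0]  # forward returns a one-element list
print(logits.shape)  # [1, 2, 256, 256]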
paddleseg/models/backbones/__init__.py
ADDED
@@ -0,0 +1,26 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .hrnet import *
from .resnet_vd import *
from .xception_deeplab import *
from .mobilenetv3 import *
from .vision_transformer import *
from .swin_transformer import *
from .mobilenetv2 import *
from .mix_transformer import *
from .stdcnet import *
from .lite_hrnet import *
from .shufflenetv2 import *
from .ghostnet import *
paddleseg/models/backbones/ghostnet.py
ADDED
@@ -0,0 +1,318 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Code was based on https://github.com/huawei-noah/CV-Backbones/tree/master/ghostnet_pytorch

import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, AdaptiveAvgPool2D, Linear
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Uniform, KaimingNormal

from paddleseg.cvlibs import manager
from paddleseg.utils import utils, logger

__all__ = ["GhostNet_x0_5", "GhostNet_x1_0", "GhostNet_x1_3"]


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 groups=1,
                 act="relu",
                 name=None):
        super(ConvBNLayer, self).__init__()
        self._conv = Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(
                initializer=KaimingNormal(), name=name + "_weights"),
            bias_attr=False)
        bn_name = name + "_bn"

        self._batch_norm = BatchNorm(
            num_channels=out_channels,
            act=act,
            param_attr=ParamAttr(
                name=bn_name + "_scale", regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(
                name=bn_name + "_offset", regularizer=L2Decay(0.0)),
            moving_mean_name=bn_name + "_mean",
            moving_variance_name=bn_name + "_variance")

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y


class SEBlock(nn.Layer):
    def __init__(self, num_channels, reduction_ratio=4, name=None):
        super(SEBlock, self).__init__()
        self.pool2d_gap = AdaptiveAvgPool2D(1)
        self._num_channels = num_channels
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        med_ch = num_channels // reduction_ratio
        self.squeeze = Linear(
            num_channels,
            med_ch,
            weight_attr=ParamAttr(
                initializer=Uniform(-stdv, stdv), name=name + "_1_weights"),
            bias_attr=ParamAttr(name=name + "_1_offset"))
        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = Linear(
            med_ch,
            num_channels,
            weight_attr=ParamAttr(
                initializer=Uniform(-stdv, stdv), name=name + "_2_weights"),
            bias_attr=ParamAttr(name=name + "_2_offset"))

    def forward(self, inputs):
        pool = self.pool2d_gap(inputs)
        pool = paddle.squeeze(pool, axis=[2, 3])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        excitation = self.excitation(squeeze)
        excitation = paddle.clip(x=excitation, min=0, max=1)
        excitation = paddle.unsqueeze(excitation, axis=[2, 3])
        out = paddle.multiply(inputs, excitation)
        return out


class GhostModule(nn.Layer):
    def __init__(self,
                 in_channels,
                 output_channels,
                 kernel_size=1,
                 ratio=2,
                 dw_size=3,
                 stride=1,
                 relu=True,
                 name=None):
        super(GhostModule, self).__init__()
        init_channels = int(math.ceil(output_channels / ratio))
        new_channels = int(init_channels * (ratio - 1))
        self.primary_conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=init_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=1,
            act="relu" if relu else None,
            name=name + "_primary_conv")
        self.cheap_operation = ConvBNLayer(
            in_channels=init_channels,
            out_channels=new_channels,
            kernel_size=dw_size,
            stride=1,
            groups=init_channels,
            act="relu" if relu else None,
            name=name + "_cheap_operation")

    def forward(self, inputs):
        x = self.primary_conv(inputs)
        y = self.cheap_operation(x)
        out = paddle.concat([x, y], axis=1)
        return out


class GhostBottleneck(nn.Layer):
    def __init__(self,
                 in_channels,
                 hidden_dim,
                 output_channels,
                 kernel_size,
                 stride,
                 use_se,
                 name=None):
        super(GhostBottleneck, self).__init__()
        self._stride = stride
        self._use_se = use_se
        self._num_channels = in_channels
        self._output_channels = output_channels
        self.ghost_module_1 = GhostModule(
            in_channels=in_channels,
            output_channels=hidden_dim,
            kernel_size=1,
            stride=1,
            relu=True,
            name=name + "_ghost_module_1")
        if stride == 2:
            self.depthwise_conv = ConvBNLayer(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=kernel_size,
                stride=stride,
                groups=hidden_dim,
                act=None,
                name=name +
                "_depthwise_depthwise"  # looks strange due to an old typo, will be fixed later.
            )
        if use_se:
            self.se_block = SEBlock(num_channels=hidden_dim, name=name + "_se")
        self.ghost_module_2 = GhostModule(
            in_channels=hidden_dim,
            output_channels=output_channels,
            kernel_size=1,
            relu=False,
            name=name + "_ghost_module_2")
        if stride != 1 or in_channels != output_channels:
            self.shortcut_depthwise = ConvBNLayer(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=kernel_size,
                stride=stride,
                groups=in_channels,
                act=None,
                name=name +
                "_shortcut_depthwise_depthwise"  # looks strange due to an old typo, will be fixed later.
            )
            self.shortcut_conv = ConvBNLayer(
                in_channels=in_channels,
                out_channels=output_channels,
                kernel_size=1,
                stride=1,
                groups=1,
                act=None,
                name=name + "_shortcut_conv")

    def forward(self, inputs):
        x = self.ghost_module_1(inputs)
        if self._stride == 2:
            x = self.depthwise_conv(x)
        if self._use_se:
            x = self.se_block(x)
        x = self.ghost_module_2(x)
        if self._stride == 1 and self._num_channels == self._output_channels:
            shortcut = inputs
        else:
            shortcut = self.shortcut_depthwise(inputs)
            shortcut = self.shortcut_conv(shortcut)
        return paddle.add(x=x, y=shortcut)


class GhostNet(nn.Layer):
    def __init__(self, scale, pretrained=None):
        super(GhostNet, self).__init__()
        self.cfgs = [
            # k, t, c, SE, s
            [3, 16, 16, 0, 1],
            [3, 48, 24, 0, 2],
            [3, 72, 24, 0, 1],  # x4
            [5, 72, 40, 1, 2],
            [5, 120, 40, 1, 1],  # x8
            [3, 240, 80, 0, 2],
            [3, 200, 80, 0, 1],
            [3, 184, 80, 0, 1],
            [3, 184, 80, 0, 1],
            [3, 480, 112, 1, 1],
            [3, 672, 112, 1, 1],  # x16
            [5, 672, 160, 1, 2],
            [5, 960, 160, 0, 1],
            [5, 960, 160, 1, 1],
            [5, 960, 160, 0, 1],
            [5, 960, 160, 1, 1]  # x32
        ]
        self.scale = scale
        self.pretrained = pretrained

        output_channels = int(self._make_divisible(16 * self.scale, 4))
        self.conv1 = ConvBNLayer(
            in_channels=3,
            out_channels=output_channels,
            kernel_size=3,
            stride=2,
            groups=1,
            act="relu",
            name="conv1")

        # build inverted residual blocks
        self.out_index = [2, 4, 10, 15]
        self.feat_channels = []
        self.ghost_bottleneck_list = []
        for idx, (k, exp_size, c, use_se, s) in enumerate(self.cfgs):
            in_channels = output_channels
            output_channels = int(self._make_divisible(c * self.scale, 4))
            hidden_dim = int(self._make_divisible(exp_size * self.scale, 4))
            ghost_bottleneck = self.add_sublayer(
                name="_ghostbottleneck_" + str(idx),
                sublayer=GhostBottleneck(
                    in_channels=in_channels,
                    hidden_dim=hidden_dim,
                    output_channels=output_channels,
                    kernel_size=k,
                    stride=s,
                    use_se=use_se,
                    name="_ghostbottleneck_" + str(idx)))
            self.ghost_bottleneck_list.append(ghost_bottleneck)
            if idx in self.out_index:
                self.feat_channels.append(output_channels)

        self.init_weight()

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)

    def forward(self, inputs):
        feat_list = []
        x = self.conv1(inputs)
        for idx, ghost_bottleneck in enumerate(self.ghost_bottleneck_list):
            x = ghost_bottleneck(x)
            if idx in self.out_index:
                feat_list.append(x)
        return feat_list

    def _make_divisible(self, v, divisor, min_value=None):
        """
        This function is taken from the original tf repo.
        It ensures that all layers have a channel number that is divisible by 8.
        It can be seen here:
        https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
        """
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v


@manager.BACKBONES.add_component
def GhostNet_x0_5(**kwargs):
    model = GhostNet(scale=0.5, **kwargs)
    return model


@manager.BACKBONES.add_component
def GhostNet_x1_0(**kwargs):
    model = GhostNet(scale=1.0, **kwargs)
    return model


@manager.BACKBONES.add_component
def GhostNet_x1_3(**kwargs):
    model = GhostNet(scale=1.3, **kwargs)
    return model
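A usage sketch (not part of the commit): the GhostNet variants are registered as segmentation backbones and, per `out_index` above, return four feature maps at strides 4, 8, 16, and 32:

import paddle
from paddleseg.models.backbones import GhostNet_x1_0

backbone = GhostNet_x1_0()
feats = backbone(paddle.rand([1, 3, 224, 224]))
print(backbone.feat_channels)      # channels of the four returned stages
print([f.shape for f in feats])    # spatial sizes 56, 28, 14, 7 for a 224 input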
paddleseg/models/backbones/hrnet.py
ADDED
@@ -0,0 +1,837 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.cvlibs import manager, param_init
from paddleseg.models import layers
from paddleseg.utils import utils

__all__ = [
    "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
    "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", "HRNet_W60", "HRNet_W64"
]

class HRNet(nn.Layer):
    """
    The HRNet implementation based on PaddlePaddle.

    The original article refers to
    Jingdong Wang, et al. "HRNet: Deep High-Resolution Representation Learning for Visual Recognition"
    (https://arxiv.org/pdf/1908.07919.pdf).

    Args:
        pretrained (str, optional): The path of pretrained model.
        stage1_num_modules (int, optional): Number of modules for stage1. Default 1.
        stage1_num_blocks (list, optional): Number of blocks per module for stage1. Default (4,).
        stage1_num_channels (list, optional): Number of channels per branch for stage1. Default (64,).
        stage2_num_modules (int, optional): Number of modules for stage2. Default 1.
        stage2_num_blocks (list, optional): Number of blocks per module for stage2. Default (4, 4).
        stage2_num_channels (list, optional): Number of channels per branch for stage2. Default (18, 36).
        stage3_num_modules (int, optional): Number of modules for stage3. Default 4.
        stage3_num_blocks (list, optional): Number of blocks per module for stage3. Default (4, 4, 4).
        stage3_num_channels (list, optional): Number of channels per branch for stage3. Default (18, 36, 72).
        stage4_num_modules (int, optional): Number of modules for stage4. Default 3.
        stage4_num_blocks (list, optional): Number of blocks per module for stage4. Default (4, 4, 4, 4).
        stage4_num_channels (list, optional): Number of channels per branch for stage4. Default (18, 36, 72, 144).
        has_se (bool, optional): Whether to use Squeeze-and-Excitation module. Default False.
        align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
            e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
    """

    def __init__(self,
                 pretrained=None,
                 stage1_num_modules=1,
                 stage1_num_blocks=(4, ),
                 stage1_num_channels=(64, ),
                 stage2_num_modules=1,
                 stage2_num_blocks=(4, 4),
                 stage2_num_channels=(18, 36),
                 stage3_num_modules=4,
                 stage3_num_blocks=(4, 4, 4),
                 stage3_num_channels=(18, 36, 72),
                 stage4_num_modules=3,
                 stage4_num_blocks=(4, 4, 4, 4),
                 stage4_num_channels=(18, 36, 72, 144),
                 has_se=False,
                 align_corners=False,
                 padding_same=True):
        super(HRNet, self).__init__()
        self.pretrained = pretrained
        self.stage1_num_modules = stage1_num_modules
        self.stage1_num_blocks = stage1_num_blocks
        self.stage1_num_channels = stage1_num_channels
        self.stage2_num_modules = stage2_num_modules
        self.stage2_num_blocks = stage2_num_blocks
        self.stage2_num_channels = stage2_num_channels
        self.stage3_num_modules = stage3_num_modules
        self.stage3_num_blocks = stage3_num_blocks
        self.stage3_num_channels = stage3_num_channels
        self.stage4_num_modules = stage4_num_modules
        self.stage4_num_blocks = stage4_num_blocks
        self.stage4_num_channels = stage4_num_channels
        self.has_se = has_se
        self.align_corners = align_corners
        self.feat_channels = [sum(stage4_num_channels)]

        self.conv_layer1_1 = layers.ConvBNReLU(
            in_channels=3,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding=1 if not padding_same else 'same',
            bias_attr=False)

        self.conv_layer1_2 = layers.ConvBNReLU(
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding=1 if not padding_same else 'same',
            bias_attr=False)

        self.la1 = Layer1(
            num_channels=64,
            num_blocks=self.stage1_num_blocks[0],
            num_filters=self.stage1_num_channels[0],
            has_se=has_se,
            name="layer2",
            padding_same=padding_same)

        self.tr1 = TransitionLayer(
            in_channels=[self.stage1_num_channels[0] * 4],
            out_channels=self.stage2_num_channels,
            name="tr1",
            padding_same=padding_same)

        self.st2 = Stage(
            num_channels=self.stage2_num_channels,
            num_modules=self.stage2_num_modules,
            num_blocks=self.stage2_num_blocks,
            num_filters=self.stage2_num_channels,
            has_se=self.has_se,
            name="st2",
            align_corners=align_corners,
            padding_same=padding_same)

        self.tr2 = TransitionLayer(
            in_channels=self.stage2_num_channels,
            out_channels=self.stage3_num_channels,
            name="tr2",
            padding_same=padding_same)
        self.st3 = Stage(
            num_channels=self.stage3_num_channels,
            num_modules=self.stage3_num_modules,
            num_blocks=self.stage3_num_blocks,
            num_filters=self.stage3_num_channels,
            has_se=self.has_se,
            name="st3",
            align_corners=align_corners,
            padding_same=padding_same)

        self.tr3 = TransitionLayer(
            in_channels=self.stage3_num_channels,
            out_channels=self.stage4_num_channels,
            name="tr3",
            padding_same=padding_same)
        self.st4 = Stage(
            num_channels=self.stage4_num_channels,
            num_modules=self.stage4_num_modules,
            num_blocks=self.stage4_num_blocks,
            num_filters=self.stage4_num_channels,
            has_se=self.has_se,
            name="st4",
            align_corners=align_corners,
            padding_same=padding_same)

        self.init_weight()

    def forward(self, x):
        conv1 = self.conv_layer1_1(x)
        conv2 = self.conv_layer1_2(conv1)

        la1 = self.la1(conv2)

        tr1 = self.tr1([la1])
        st2 = self.st2(tr1)

        tr2 = self.tr2(st2)
        st3 = self.st3(tr2)

        tr3 = self.tr3(st3)
        st4 = self.st4(tr3)

        size = paddle.shape(st4[0])[2:]
        x1 = F.interpolate(
            st4[1], size, mode='bilinear', align_corners=self.align_corners)
        x2 = F.interpolate(
            st4[2], size, mode='bilinear', align_corners=self.align_corners)
        x3 = F.interpolate(
            st4[3], size, mode='bilinear', align_corners=self.align_corners)
        x = paddle.concat([st4[0], x1, x2, x3], axis=1)

        return [x]

    def init_weight(self):
        for layer in self.sublayers():
            if isinstance(layer, nn.Conv2D):
                param_init.normal_init(layer.weight, std=0.001)
            elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
                param_init.constant_init(layer.weight, value=1.0)
                param_init.constant_init(layer.bias, value=0.0)
        if self.pretrained is not None:
            utils.load_pretrained_model(self, self.pretrained)


class Layer1(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 num_blocks,
                 has_se=False,
                 name=None,
                 padding_same=True):
        super(Layer1, self).__init__()

        self.bottleneck_block_list = []

        for i in range(num_blocks):
            bottleneck_block = self.add_sublayer(
                "bb_{}_{}".format(name, i + 1),
                BottleneckBlock(
                    num_channels=num_channels if i == 0 else num_filters * 4,
                    num_filters=num_filters,
                    has_se=has_se,
                    stride=1,
                    downsample=True if i == 0 else False,
                    name=name + '_' + str(i + 1),
                    padding_same=padding_same))
            self.bottleneck_block_list.append(bottleneck_block)

    def forward(self, x):
        conv = x
        for block_func in self.bottleneck_block_list:
            conv = block_func(conv)
        return conv


class TransitionLayer(nn.Layer):
    def __init__(self, in_channels, out_channels, name=None, padding_same=True):
        super(TransitionLayer, self).__init__()

        num_in = len(in_channels)
        num_out = len(out_channels)
        self.conv_bn_func_list = []
        for i in range(num_out):
            residual = None
            if i < num_in:
                if in_channels[i] != out_channels[i]:
                    residual = self.add_sublayer(
                        "transition_{}_layer_{}".format(name, i + 1),
                        layers.ConvBNReLU(
                            in_channels=in_channels[i],
                            out_channels=out_channels[i],
                            kernel_size=3,
                            padding=1 if not padding_same else 'same',
                            bias_attr=False))
            else:
                residual = self.add_sublayer(
                    "transition_{}_layer_{}".format(name, i + 1),
                    layers.ConvBNReLU(
                        in_channels=in_channels[-1],
                        out_channels=out_channels[i],
                        kernel_size=3,
                        stride=2,
                        padding=1 if not padding_same else 'same',
                        bias_attr=False))
            self.conv_bn_func_list.append(residual)

    def forward(self, x):
        outs = []
        for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
            if conv_bn_func is None:
                outs.append(x[idx])
            else:
                if idx < len(x):
                    outs.append(conv_bn_func(x[idx]))
                else:
                    outs.append(conv_bn_func(x[-1]))
        return outs


class Branches(nn.Layer):
    def __init__(self,
                 num_blocks,
                 in_channels,
                 out_channels,
                 has_se=False,
                 name=None,
                 padding_same=True):
        super(Branches, self).__init__()

        self.basic_block_list = []

        for i in range(len(out_channels)):
            self.basic_block_list.append([])
            for j in range(num_blocks[i]):
                in_ch = in_channels[i] if j == 0 else out_channels[i]
                basic_block_func = self.add_sublayer(
                    "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
                    BasicBlock(
                        num_channels=in_ch,
                        num_filters=out_channels[i],
                        has_se=has_se,
                        name=name + '_branch_layer_' + str(i + 1) + '_' +
                        str(j + 1),
                        padding_same=padding_same))
                self.basic_block_list[i].append(basic_block_func)

    def forward(self, x):
        outs = []
        for idx, input in enumerate(x):
            conv = input
            for basic_block_func in self.basic_block_list[idx]:
                conv = basic_block_func(conv)
            outs.append(conv)
        return outs


class BottleneckBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se,
                 stride=1,
                 downsample=False,
                 name=None,
                 padding_same=True):
        super(BottleneckBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = layers.ConvBNReLU(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=1,
            bias_attr=False)

        self.conv2 = layers.ConvBNReLU(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=3,
            stride=stride,
            padding=1 if not padding_same else 'same',
            bias_attr=False)

        self.conv3 = layers.ConvBN(
            in_channels=num_filters,
            out_channels=num_filters * 4,
            kernel_size=1,
            bias_attr=False)

        if self.downsample:
            self.conv_down = layers.ConvBN(
                in_channels=num_channels,
                out_channels=num_filters * 4,
                kernel_size=1,
                bias_attr=False)

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters * 4,
                num_filters=num_filters * 4,
                reduction_ratio=16,
                name=name + '_fc')

        self.add = layers.Add()
        self.relu = layers.Activation("relu")

    def forward(self, x):
        residual = x
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        if self.downsample:
            residual = self.conv_down(x)

        if self.has_se:
            conv3 = self.se(conv3)

        y = self.add(conv3, residual)
        y = self.relu(y)
        return y


class BasicBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride=1,
                 has_se=False,
                 downsample=False,
                 name=None,
                 padding_same=True):
        super(BasicBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = layers.ConvBNReLU(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=3,
            stride=stride,
            padding=1 if not padding_same else 'same',
            bias_attr=False)
        self.conv2 = layers.ConvBN(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=3,
            padding=1 if not padding_same else 'same',
            bias_attr=False)

        if self.downsample:
            self.conv_down = layers.ConvBNReLU(
                in_channels=num_channels,
                out_channels=num_filters,
                kernel_size=1,
                bias_attr=False)

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters,
                num_filters=num_filters,
                reduction_ratio=16,
                name=name + '_fc')

        self.add = layers.Add()
        self.relu = layers.Activation("relu")

    def forward(self, x):
        residual = x
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)

        if self.downsample:
            residual = self.conv_down(x)

        if self.has_se:
            conv2 = self.se(conv2)

        y = self.add(conv2, residual)
        y = self.relu(y)
        return y


class SELayer(nn.Layer):
    def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
        super(SELayer, self).__init__()

        self.pool2d_gap = nn.AdaptiveAvgPool2D(1)

        self._num_channels = num_channels

        med_ch = int(num_channels / reduction_ratio)
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        self.squeeze = nn.Linear(
            num_channels,
            med_ch,
            weight_attr=paddle.ParamAttr(
                initializer=nn.initializer.Uniform(-stdv, stdv)))

        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = nn.Linear(
            med_ch,
            num_filters,
            weight_attr=paddle.ParamAttr(
                initializer=nn.initializer.Uniform(-stdv, stdv)))

    def forward(self, x):
        pool = self.pool2d_gap(x)
        pool = paddle.reshape(pool, shape=[-1, self._num_channels])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        excitation = self.excitation(squeeze)
        excitation = F.sigmoid(excitation)
        excitation = paddle.reshape(
            excitation, shape=[-1, self._num_channels, 1, 1])
        out = x * excitation
        return out


class Stage(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_modules,
                 num_blocks,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 name=None,
                 align_corners=False,
                 padding_same=True):
        super(Stage, self).__init__()

        self._num_modules = num_modules

        self.stage_func_list = []
        for i in range(num_modules):
            if i == num_modules - 1 and not multi_scale_output:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_blocks=num_blocks,
                        num_filters=num_filters,
                        has_se=has_se,
                        multi_scale_output=False,
                        name=name + '_' + str(i + 1),
                        align_corners=align_corners,
                        padding_same=padding_same))
            else:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_blocks=num_blocks,
                        num_filters=num_filters,
                        has_se=has_se,
                        name=name + '_' + str(i + 1),
                        align_corners=align_corners,
                        padding_same=padding_same))

            self.stage_func_list.append(stage_func)

    def forward(self, x):
        out = x
        for idx in range(self._num_modules):
            out = self.stage_func_list[idx](out)
        return out


class HighResolutionModule(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_blocks,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 name=None,
                 align_corners=False,
                 padding_same=True):
        super(HighResolutionModule, self).__init__()

        self.branches_func = Branches(
            num_blocks=num_blocks,
            in_channels=num_channels,
            out_channels=num_filters,
            has_se=has_se,
            name=name,
            padding_same=padding_same)

        self.fuse_func = FuseLayers(
            in_channels=num_filters,
            out_channels=num_filters,
            multi_scale_output=multi_scale_output,
            name=name,
            align_corners=align_corners,
            padding_same=padding_same)

    def forward(self, x):
        out = self.branches_func(x)
        out = self.fuse_func(out)
        return out


class FuseLayers(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 multi_scale_output=True,
                 name=None,
                 align_corners=False,
                 padding_same=True):
        super(FuseLayers, self).__init__()

        self._actual_ch = len(in_channels) if multi_scale_output else 1
        self._in_channels = in_channels
        self.align_corners = align_corners

        self.residual_func_list = []
        for i in range(self._actual_ch):
            for j in range(len(in_channels)):
                if j > i:
                    residual_func = self.add_sublayer(
                        "residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
                        layers.ConvBN(
                            in_channels=in_channels[j],
                            out_channels=out_channels[i],
                            kernel_size=1,
                            bias_attr=False))
                    self.residual_func_list.append(residual_func)
                elif j < i:
                    pre_num_filters = in_channels[j]
                    for k in range(i - j):
                        if k == i - j - 1:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                layers.ConvBN(
                                    in_channels=pre_num_filters,
                                    out_channels=out_channels[i],
                                    kernel_size=3,
                                    stride=2,
                                    padding=1 if not padding_same else 'same',
                                    bias_attr=False))
                            pre_num_filters = out_channels[i]
                        else:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                layers.ConvBNReLU(
                                    in_channels=pre_num_filters,
                                    out_channels=out_channels[j],
                                    kernel_size=3,
                                    stride=2,
                                    padding=1 if not padding_same else 'same',
                                    bias_attr=False))
                            pre_num_filters = out_channels[j]
                        self.residual_func_list.append(residual_func)

    def forward(self, x):
        outs = []
        residual_func_idx = 0
        for i in range(self._actual_ch):
            residual = x[i]
            residual_shape = paddle.shape(residual)[-2:]
            for j in range(len(self._in_channels)):
                if j > i:
                    y = self.residual_func_list[residual_func_idx](x[j])
                    residual_func_idx += 1

                    y = F.interpolate(
                        y,
                        residual_shape,
                        mode='bilinear',
                        align_corners=self.align_corners)
                    residual = residual + y
                elif j < i:
                    y = x[j]
                    for k in range(i - j):
                        y = self.residual_func_list[residual_func_idx](y)
                        residual_func_idx += 1

                    residual = residual + y

            residual = F.relu(residual)
            outs.append(residual)

        return outs

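# Note (annotation, not part of the commit): in FuseLayers, a branch with
# j > i is brought up to branch i's resolution by a 1x1 ConvBN followed by
# bilinear upsampling, while a branch with j < i is downsampled by a chain of
# i - j stride-2 3x3 convs (ConvBNReLU for intermediate steps, ConvBN for the
# last). residual_func_idx in forward walks the flat residual_func_list in
# exactly the order the constructor appended the sublayers.
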
@manager.BACKBONES.add_component
def HRNet_W18_Small_V1(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[1],
        stage1_num_channels=[32],
        stage2_num_modules=1,
        stage2_num_blocks=[2, 2],
        stage2_num_channels=[16, 32],
        stage3_num_modules=1,
        stage3_num_blocks=[2, 2, 2],
        stage3_num_channels=[16, 32, 64],
        stage4_num_modules=1,
        stage4_num_blocks=[2, 2, 2, 2],
        stage4_num_channels=[16, 32, 64, 128],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W18_Small_V2(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[2],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[2, 2],
        stage2_num_channels=[18, 36],
        stage3_num_modules=3,
        stage3_num_blocks=[2, 2, 2],
        stage3_num_channels=[18, 36, 72],
        stage4_num_modules=2,
        stage4_num_blocks=[2, 2, 2, 2],
        stage4_num_channels=[18, 36, 72, 144],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W18(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[18, 36],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[18, 36, 72],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[18, 36, 72, 144],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W30(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[30, 60],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[30, 60, 120],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[30, 60, 120, 240],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W32(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[32, 64],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[32, 64, 128],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[32, 64, 128, 256],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W40(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[40, 80],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[40, 80, 160],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[40, 80, 160, 320],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W44(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[44, 88],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[44, 88, 176],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[44, 88, 176, 352],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W48(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[48, 96],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[48, 96, 192],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[48, 96, 192, 384],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W60(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[60, 120],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[60, 120, 240],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[60, 120, 240, 480],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def HRNet_W64(**kwargs):
    model = HRNet(
        stage1_num_modules=1,
        stage1_num_blocks=[4],
        stage1_num_channels=[64],
        stage2_num_modules=1,
        stage2_num_blocks=[4, 4],
        stage2_num_channels=[64, 128],
        stage3_num_modules=4,
        stage3_num_blocks=[4, 4, 4],
        stage3_num_channels=[64, 128, 256],
        stage4_num_modules=3,
        stage4_num_blocks=[4, 4, 4, 4],
        stage4_num_channels=[64, 128, 256, 512],
        **kwargs)
    return model
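A minimal usage sketch of the backbone above (an annotation, not part of the commit; it assumes the bundled paddleseg package is importable and uses the W18 defaults). The backbone returns a single concatenated feature map whose channel count equals sum(stage4_num_channels), which is what feat_channels advertises:

    import paddle
    from paddleseg.models.backbones.hrnet import HRNet_W18

    net = HRNet_W18()  # stage4 channels: 18 + 36 + 72 + 144 = 270
    net.eval()
    x = paddle.randn([1, 3, 512, 512])
    feat, = net(x)  # forward returns a one-element list
    print(net.feat_channels, feat.shape)  # expected: [270] and [1, 270, 128, 128]
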
paddleseg/models/backbones/lite_hrnet.py
ADDED
@@ -0,0 +1,972 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on
https://github.com/HRNet/Lite-HRNet/blob/hrnet/models/backbones/litehrnet.py
"""

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from numbers import Integral
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant

from paddleseg.cvlibs import manager
from paddleseg import utils

__all__ = [
    "Lite_HRNet_18", "Lite_HRNet_30", "Lite_HRNet_naive",
    "Lite_HRNet_wider_naive", "LiteHRNet"
]

def Conv2d(in_channels,
           out_channels,
           kernel_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=1,
           bias=True,
           weight_init=Normal(std=0.001),
           bias_init=Constant(0.)):
    weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
    if bias:
        bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
    else:
        bias_attr = False
    conv = nn.Conv2D(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return conv


def channel_shuffle(x, groups):
    x_shape = paddle.shape(x)
    batch_size, height, width = x_shape[0], x_shape[2], x_shape[3]
    num_channels = x.shape[1]
    channels_per_group = num_channels // groups

    x = paddle.reshape(
        x=x, shape=[batch_size, groups, channels_per_group, height, width])
    x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4])
    x = paddle.reshape(x=x, shape=[batch_size, num_channels, height, width])

    return x

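# Illustration (annotation, not part of the commit): with groups=2, the
# reshape/transpose/reshape above interleaves the two halves of the channel
# axis, e.g. channels [0, 1, 2, 3] come out as [0, 2, 1, 3]:
#
#   x = paddle.arange(4, dtype='float32').reshape([1, 4, 1, 1])
#   y = channel_shuffle(x, groups=2)
#   print(y.flatten().tolist())  # [0.0, 2.0, 1.0, 3.0]
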
class ConvNormLayer(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride=1,
                 groups=1,
                 norm_type=None,
                 norm_groups=32,
                 norm_decay=0.,
                 freeze_norm=False,
                 act=None):
        super(ConvNormLayer, self).__init__()
        self.act = act
        norm_lr = 0. if freeze_norm else 1.
        if norm_type is not None:
            assert norm_type in ['bn', 'sync_bn', 'gn'], \
                "norm_type should be one of ['bn', 'sync_bn', 'gn'], but got {}".format(norm_type)
            param_attr = ParamAttr(
                initializer=Constant(1.0),
                learning_rate=norm_lr,
                regularizer=L2Decay(norm_decay), )
            bias_attr = ParamAttr(
                learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
            global_stats = True if freeze_norm else None
            if norm_type in ['bn', 'sync_bn']:
                self.norm = nn.BatchNorm2D(
                    ch_out,
                    weight_attr=param_attr,
                    bias_attr=bias_attr,
                    use_global_stats=global_stats, )
            elif norm_type == 'gn':
                self.norm = nn.GroupNorm(
                    num_groups=norm_groups,
                    num_channels=ch_out,
                    weight_attr=param_attr,
                    bias_attr=bias_attr)
            norm_params = self.norm.parameters()
            if freeze_norm:
                for param in norm_params:
                    param.stop_gradient = True
            conv_bias_attr = False
        else:
            conv_bias_attr = True
            self.norm = None

        self.conv = nn.Conv2D(
            in_channels=ch_in,
            out_channels=ch_out,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(initializer=Normal(
                mean=0., std=0.001)),
            bias_attr=conv_bias_attr)

    def forward(self, inputs):
        out = self.conv(inputs)
        if self.norm is not None:
            out = self.norm(out)

        if self.act == 'relu':
            out = F.relu(out)
        elif self.act == 'sigmoid':
            out = F.sigmoid(out)
        return out


class DepthWiseSeparableConvNormLayer(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride=1,
                 dw_norm_type=None,
                 pw_norm_type=None,
                 norm_decay=0.,
                 freeze_norm=False,
                 dw_act=None,
                 pw_act=None):
        super(DepthWiseSeparableConvNormLayer, self).__init__()
        self.depthwise_conv = ConvNormLayer(
            ch_in=ch_in,
            ch_out=ch_in,
            filter_size=filter_size,
            stride=stride,
            groups=ch_in,
            norm_type=dw_norm_type,
            act=dw_act,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm, )
        self.pointwise_conv = ConvNormLayer(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=1,
            stride=1,
            norm_type=pw_norm_type,
            act=pw_act,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm, )

    def forward(self, x):
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        return x


class CrossResolutionWeightingModule(nn.Layer):
    def __init__(self,
                 channels,
                 ratio=16,
                 norm_type='bn',
                 freeze_norm=False,
                 norm_decay=0.):
        super(CrossResolutionWeightingModule, self).__init__()
        self.channels = channels
        total_channel = sum(channels)
        self.conv1 = ConvNormLayer(
            ch_in=total_channel,
            ch_out=total_channel // ratio,
            filter_size=1,
            stride=1,
            norm_type=norm_type,
            act='relu',
            freeze_norm=freeze_norm,
            norm_decay=norm_decay)
        self.conv2 = ConvNormLayer(
            ch_in=total_channel // ratio,
            ch_out=total_channel,
            filter_size=1,
            stride=1,
            norm_type=norm_type,
            act='sigmoid',
            freeze_norm=freeze_norm,
            norm_decay=norm_decay)

    def forward(self, x):
        out = []
        for idx, xi in enumerate(x[:-1]):
            kernel_size = stride = pow(2, len(x) - idx - 1)
            xi = F.avg_pool2d(xi, kernel_size=kernel_size, stride=stride)
            out.append(xi)
        out.append(x[-1])

        out = paddle.concat(out, 1)
        out = self.conv1(out)
        out = self.conv2(out)
        out = paddle.split(out, self.channels, 1)
        out = [
            s * F.interpolate(
                a, paddle.shape(s)[-2:], mode='nearest') for s, a in zip(x, out)
        ]
        return out

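# Note (annotation, not part of the commit): in the forward above,
# kernel_size = stride = 2 ** (len(x) - idx - 1), so with three branches the
# first is average-pooled 4x and the second 2x, bringing every branch down to
# the lowest resolution before the two 1x1 convs compute per-channel weights;
# the weights are then split back per branch and resized to each branch's own
# spatial size with nearest-neighbour interpolation.
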
class SpatialWeightingModule(nn.Layer):
    def __init__(self, in_channel, ratio=16, freeze_norm=False, norm_decay=0.):
        super(SpatialWeightingModule, self).__init__()
        self.global_avgpooling = nn.AdaptiveAvgPool2D(1)
        self.conv1 = ConvNormLayer(
            ch_in=in_channel,
            ch_out=in_channel // ratio,
            filter_size=1,
            stride=1,
            act='relu',
            freeze_norm=freeze_norm,
            norm_decay=norm_decay)
        self.conv2 = ConvNormLayer(
            ch_in=in_channel // ratio,
            ch_out=in_channel,
            filter_size=1,
            stride=1,
            act='sigmoid',
            freeze_norm=freeze_norm,
            norm_decay=norm_decay)

    def forward(self, x):
        out = self.global_avgpooling(x)
        out = self.conv1(out)
        out = self.conv2(out)
        return x * out


class ConditionalChannelWeightingBlock(nn.Layer):
    def __init__(self,
                 in_channels,
                 stride,
                 reduce_ratio,
                 norm_type='bn',
                 freeze_norm=False,
                 norm_decay=0.):
        super(ConditionalChannelWeightingBlock, self).__init__()
        assert stride in [1, 2]
        branch_channels = [channel // 2 for channel in in_channels]

        self.cross_resolution_weighting = CrossResolutionWeightingModule(
            branch_channels,
            ratio=reduce_ratio,
            norm_type=norm_type,
            freeze_norm=freeze_norm,
            norm_decay=norm_decay)
        self.depthwise_convs = nn.LayerList([
            ConvNormLayer(
                channel,
                channel,
                filter_size=3,
                stride=stride,
                groups=channel,
                norm_type=norm_type,
                freeze_norm=freeze_norm,
                norm_decay=norm_decay) for channel in branch_channels
        ])

        self.spatial_weighting = nn.LayerList([
            SpatialWeightingModule(
                channel,
                ratio=4,
                freeze_norm=freeze_norm,
                norm_decay=norm_decay) for channel in branch_channels
        ])

    def forward(self, x):
        x = [s.chunk(2, axis=1) for s in x]
        x1 = [s[0] for s in x]
        x2 = [s[1] for s in x]

        x2 = self.cross_resolution_weighting(x2)
        x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)]
        x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)]

        out = [paddle.concat([s1, s2], axis=1) for s1, s2 in zip(x1, x2)]
        out = [channel_shuffle(s, groups=2) for s in out]
        return out

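# Note (annotation, not part of the commit): the block above follows the
# ShuffleNet V2 pattern. Each branch's channels are split in half; one half
# passes through untouched, the other gets cross-resolution weighting, a
# depthwise 3x3 and spatial weighting, and the halves are re-joined with
# concat + channel_shuffle so information mixes across the split in the next
# block.
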
class ShuffleUnit(nn.Layer):
    def __init__(self,
                 in_channel,
                 out_channel,
                 stride,
                 norm_type='bn',
                 freeze_norm=False,
                 norm_decay=0.):
        super(ShuffleUnit, self).__init__()
        branch_channel = out_channel // 2
        self.stride = stride
        if self.stride == 1:
            assert in_channel == branch_channel * 2, \
                "when stride=1, in_channel {} should equal to branch_channel*2 {}".format(in_channel, branch_channel * 2)
        if stride > 1:
            self.branch1 = nn.Sequential(
                ConvNormLayer(
                    ch_in=in_channel,
                    ch_out=in_channel,
                    filter_size=3,
                    stride=self.stride,
                    groups=in_channel,
                    norm_type=norm_type,
                    freeze_norm=freeze_norm,
                    norm_decay=norm_decay),
                ConvNormLayer(
                    ch_in=in_channel,
                    ch_out=branch_channel,
                    filter_size=1,
                    stride=1,
                    norm_type=norm_type,
                    act='relu',
                    freeze_norm=freeze_norm,
                    norm_decay=norm_decay), )
        self.branch2 = nn.Sequential(
            ConvNormLayer(
                ch_in=branch_channel if stride == 1 else in_channel,
                ch_out=branch_channel,
                filter_size=1,
                stride=1,
                norm_type=norm_type,
                act='relu',
                freeze_norm=freeze_norm,
                norm_decay=norm_decay),
            ConvNormLayer(
                ch_in=branch_channel,
                ch_out=branch_channel,
                filter_size=3,
                stride=self.stride,
                groups=branch_channel,
                norm_type=norm_type,
                freeze_norm=freeze_norm,
                norm_decay=norm_decay),
            ConvNormLayer(
                ch_in=branch_channel,
                ch_out=branch_channel,
                filter_size=1,
                stride=1,
                norm_type=norm_type,
                act='relu',
                freeze_norm=freeze_norm,
                norm_decay=norm_decay), )

    def forward(self, x):
        if self.stride > 1:
            x1 = self.branch1(x)
            x2 = self.branch2(x)
        else:
            x1, x2 = x.chunk(2, axis=1)
            x2 = self.branch2(x2)
        out = paddle.concat([x1, x2], axis=1)
        out = channel_shuffle(out, groups=2)
        return out


class IterativeHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 norm_type='bn',
                 freeze_norm=False,
                 norm_decay=0.):
        super(IterativeHead, self).__init__()
        num_branches = len(in_channels)
        self.in_channels = in_channels[::-1]

        projects = []
        for i in range(num_branches):
            if i != num_branches - 1:
                projects.append(
                    DepthWiseSeparableConvNormLayer(
                        ch_in=self.in_channels[i],
                        ch_out=self.in_channels[i + 1],
                        filter_size=3,
                        stride=1,
                        dw_act=None,
                        pw_act='relu',
                        dw_norm_type=norm_type,
                        pw_norm_type=norm_type,
                        freeze_norm=freeze_norm,
                        norm_decay=norm_decay))
            else:
                projects.append(
                    DepthWiseSeparableConvNormLayer(
                        ch_in=self.in_channels[i],
                        ch_out=self.in_channels[i],
                        filter_size=3,
                        stride=1,
                        dw_act=None,
                        pw_act='relu',
                        dw_norm_type=norm_type,
                        pw_norm_type=norm_type,
                        freeze_norm=freeze_norm,
                        norm_decay=norm_decay))
        self.projects = nn.LayerList(projects)

    def forward(self, x):
        x = x[::-1]
        y = []
        last_x = None
        for i, s in enumerate(x):
            if last_x is not None:
                last_x = F.interpolate(
                    last_x,
                    size=paddle.shape(s)[-2:],
                    mode='bilinear',
                    align_corners=True)
                s = s + last_x
            s = self.projects[i](s)
            y.append(s)
            last_x = s

        return y[::-1]

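# Note (annotation, not part of the commit): IterativeHead walks the branches
# from lowest to highest resolution (x[::-1]), bilinearly upsampling each
# refined feature and adding it to the next, finer branch before that branch's
# depthwise-separable projection, then returns the list in the original
# high-to-low order.
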
class Stem(nn.Layer):
    def __init__(self,
                 in_channel,
                 stem_channel,
                 out_channel,
                 expand_ratio,
                 norm_type='bn',
                 freeze_norm=False,
                 norm_decay=0.):
        super(Stem, self).__init__()
        self.conv1 = ConvNormLayer(
            in_channel,
            stem_channel,
            filter_size=3,
            stride=2,
            norm_type=norm_type,
            act='relu',
            freeze_norm=freeze_norm,
            norm_decay=norm_decay)
        mid_channel = int(round(stem_channel * expand_ratio))
        branch_channel = stem_channel // 2
        if stem_channel == out_channel:
            inc_channel = out_channel - branch_channel
        else:
            inc_channel = out_channel - stem_channel
        self.branch1 = nn.Sequential(
            ConvNormLayer(
                ch_in=branch_channel,
                ch_out=branch_channel,
                filter_size=3,
                stride=2,
                groups=branch_channel,
                norm_type=norm_type,
                freeze_norm=freeze_norm,
                norm_decay=norm_decay),
            ConvNormLayer(
                ch_in=branch_channel,
                ch_out=inc_channel,
                filter_size=1,
                stride=1,
                norm_type=norm_type,
                act='relu',
                freeze_norm=freeze_norm,
                norm_decay=norm_decay), )
        self.expand_conv = ConvNormLayer(
            ch_in=branch_channel,
            ch_out=mid_channel,
            filter_size=1,
            stride=1,
            norm_type=norm_type,
            act='relu',
            freeze_norm=freeze_norm,
            norm_decay=norm_decay)
        self.depthwise_conv = ConvNormLayer(
            ch_in=mid_channel,
            ch_out=mid_channel,
            filter_size=3,
            stride=2,
            groups=mid_channel,
            norm_type=norm_type,
            freeze_norm=freeze_norm,
            norm_decay=norm_decay)
        self.linear_conv = ConvNormLayer(
            ch_in=mid_channel,
            ch_out=branch_channel
            if stem_channel == out_channel else stem_channel,
            filter_size=1,
            stride=1,
            norm_type=norm_type,
            act='relu',
            freeze_norm=freeze_norm,
            norm_decay=norm_decay)

    def forward(self, x):
        x = self.conv1(x)
        x1, x2 = x.chunk(2, axis=1)
        x1 = self.branch1(x1)
        x2 = self.expand_conv(x2)
        x2 = self.depthwise_conv(x2)
        x2 = self.linear_conv(x2)
        out = paddle.concat([x1, x2], axis=1)
        out = channel_shuffle(out, groups=2)

        return out


class LiteHRNetModule(nn.Layer):
    def __init__(self,
                 num_branches,
                 num_blocks,
                 in_channels,
                 reduce_ratio,
                 module_type,
                 multiscale_output=False,
                 with_fuse=True,
                 norm_type='bn',
                 freeze_norm=False,
                 norm_decay=0.):
        super(LiteHRNetModule, self).__init__()
        assert num_branches == len(in_channels), \
            "num_branches {} should equal to num_in_channels {}".format(num_branches, len(in_channels))
        assert module_type in [
            'LITE', 'NAIVE'
        ], "module_type should be one of ['LITE', 'NAIVE']"
        self.num_branches = num_branches
        self.in_channels = in_channels
        self.multiscale_output = multiscale_output
        self.with_fuse = with_fuse
        self.norm_type = 'bn'
        self.module_type = module_type

        if self.module_type == 'LITE':
            self.layers = self._make_weighting_blocks(
                num_blocks,
                reduce_ratio,
                freeze_norm=freeze_norm,
                norm_decay=norm_decay)
        elif self.module_type == 'NAIVE':
            self.layers = self._make_naive_branches(
                num_branches,
                num_blocks,
                freeze_norm=freeze_norm,
                norm_decay=norm_decay)

        if self.with_fuse:
            self.fuse_layers = self._make_fuse_layers(
                freeze_norm=freeze_norm, norm_decay=norm_decay)
            self.relu = nn.ReLU()

    def _make_weighting_blocks(self,
                               num_blocks,
                               reduce_ratio,
                               stride=1,
                               freeze_norm=False,
                               norm_decay=0.):
        layers = []
        for i in range(num_blocks):
            layers.append(
                ConditionalChannelWeightingBlock(
                    self.in_channels,
                    stride=stride,
                    reduce_ratio=reduce_ratio,
                    norm_type=self.norm_type,
                    freeze_norm=freeze_norm,
                    norm_decay=norm_decay))
        return nn.Sequential(*layers)

    def _make_naive_branches(self,
                             num_branches,
                             num_blocks,
                             freeze_norm=False,
                             norm_decay=0.):
        branches = []
        for branch_idx in range(num_branches):
            layers = []
            for i in range(num_blocks):
                layers.append(
                    ShuffleUnit(
                        self.in_channels[branch_idx],
                        self.in_channels[branch_idx],
                        stride=1,
                        norm_type=self.norm_type,
                        freeze_norm=freeze_norm,
                        norm_decay=norm_decay))
            branches.append(nn.Sequential(*layers))
        return nn.LayerList(branches)

    def _make_fuse_layers(self, freeze_norm=False, norm_decay=0.):
        if self.num_branches == 1:
            return None
        fuse_layers = []
        num_out_branches = self.num_branches if self.multiscale_output else 1
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(self.num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            Conv2d(
                                self.in_channels[j],
                                self.in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False, ),
                            nn.BatchNorm2D(self.in_channels[i]),
                            nn.Upsample(
                                scale_factor=2**(j - i), mode='nearest')))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv_downsamples.append(
                                nn.Sequential(
                                    Conv2d(
                                        self.in_channels[j],
                                        self.in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=self.in_channels[j],
                                        bias=False, ),
                                    nn.BatchNorm2D(self.in_channels[j]),
                                    Conv2d(
                                        self.in_channels[j],
                                        self.in_channels[i],
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False, ),
                                    nn.BatchNorm2D(self.in_channels[i])))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    Conv2d(
                                        self.in_channels[j],
                                        self.in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=self.in_channels[j],
                                        bias=False, ),
                                    nn.BatchNorm2D(self.in_channels[j]),
                                    Conv2d(
                                        self.in_channels[j],
                                        self.in_channels[j],
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False, ),
                                    nn.BatchNorm2D(self.in_channels[j]),
                                    nn.ReLU()))

                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.LayerList(fuse_layer))

        return nn.LayerList(fuse_layers)

    def forward(self, x):
        if self.num_branches == 1:
            return [self.layers[0](x[0])]
        if self.module_type == 'LITE':
            out = self.layers(x)
        elif self.module_type == 'NAIVE':
            for i in range(self.num_branches):
                x[i] = self.layers[i](x[i])
            out = x
        if self.with_fuse:
            out_fuse = []
            for i in range(len(self.fuse_layers)):
                y = out[0] if i == 0 else self.fuse_layers[i][0](out[0])
                for j in range(self.num_branches):
                    if j == 0:
                        y += y
                    elif i == j:
                        y += out[j]
                    else:
                        y += self.fuse_layers[i][j](out[j])
                if i == 0:
                    out[i] = y
                out_fuse.append(self.relu(y))
            out = out_fuse
        elif not self.multiscale_output:
            out = [out[0]]
        return out


class LiteHRNet(nn.Layer):
    """
    @inproceedings{Yulitehrnet21,
    title={Lite-HRNet: A Lightweight High-Resolution Network},
    author={Yu, Changqian and Xiao, Bin and Gao, Changxin and Yuan, Lu and Zhang, Lei and Sang, Nong and Wang, Jingdong},
    booktitle={CVPR}, year={2021}
    }

    Args:
        network_type (str): the network_type should be one of ["lite_18", "lite_30", "naive", "wider_naive"].
            "naive": simply combines the shuffle block in ShuffleNet and the high-resolution design pattern in HRNet.
            "wider_naive": the naive network with wider channels in each block.
            "lite_18": Lite-HRNet-18, which replaces the pointwise convolution in a shuffle block by conditional channel weighting.
            "lite_30": Lite-HRNet-30, with more blocks compared with Lite-HRNet-18.
        freeze_at (int): the stage to freeze.
        freeze_norm (bool): whether to freeze norm in HRNet.
        norm_decay (float): weight decay for normalization layer weights.
        return_idx (List): the stages to return.
    """

    def __init__(self,
                 network_type,
                 freeze_at=0,
                 freeze_norm=True,
                 norm_decay=0.,
                 return_idx=[0, 1, 2, 3],
                 use_head=False,
                 pretrained=None):
        super(LiteHRNet, self).__init__()
        if isinstance(return_idx, Integral):
            return_idx = [return_idx]
        assert network_type in ["lite_18", "lite_30", "naive", "wider_naive"], \
            "the network_type should be one of [lite_18, lite_30, naive, wider_naive]"
        assert len(return_idx) > 0, "need one or more return index"
        self.freeze_at = freeze_at
        self.freeze_norm = freeze_norm
        self.norm_decay = norm_decay
        self.return_idx = return_idx
        self.norm_type = 'bn'
        self.use_head = use_head
        self.pretrained = pretrained

        self.module_configs = {
            "lite_18": {
                "num_modules": [2, 4, 2],
                "num_branches": [2, 3, 4],
                "num_blocks": [2, 2, 2],
                "module_type": ["LITE", "LITE", "LITE"],
                "reduce_ratios": [8, 8, 8],
                "num_channels": [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
            },
            "lite_30": {
                "num_modules": [3, 8, 3],
                "num_branches": [2, 3, 4],
                "num_blocks": [2, 2, 2],
                "module_type": ["LITE", "LITE", "LITE"],
                "reduce_ratios": [8, 8, 8],
                "num_channels": [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
            },
            "naive": {
                "num_modules": [2, 4, 2],
                "num_branches": [2, 3, 4],
                "num_blocks": [2, 2, 2],
                "module_type": ["NAIVE", "NAIVE", "NAIVE"],
                "reduce_ratios": [1, 1, 1],
                "num_channels": [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
            },
            "wider_naive": {
                "num_modules": [2, 4, 2],
                "num_branches": [2, 3, 4],
                "num_blocks": [2, 2, 2],
                "module_type": ["NAIVE", "NAIVE", "NAIVE"],
                "reduce_ratios": [1, 1, 1],
                "num_channels": [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
            },
        }

        self.stages_config = self.module_configs[network_type]

        self.stem = Stem(3, 32, 32, 1)
        num_channels_pre_layer = [32]
        for stage_idx in range(3):
            num_channels = self.stages_config["num_channels"][stage_idx]
            setattr(self, 'transition{}'.format(stage_idx),
                    self._make_transition_layer(num_channels_pre_layer,
                                                num_channels, self.freeze_norm,
                                                self.norm_decay))
            stage, num_channels_pre_layer = self._make_stage(
                self.stages_config, stage_idx, num_channels, True,
                self.freeze_norm, self.norm_decay)
            setattr(self, 'stage{}'.format(stage_idx), stage)

        num_channels = self.stages_config["num_channels"][-1]
        self.feat_channels = num_channels

        if self.use_head:
            self.head_layer = IterativeHead(num_channels_pre_layer, 'bn',
                                            self.freeze_norm, self.norm_decay)

            self.feat_channels = [num_channels[0]]
            for i in range(1, len(num_channels)):
                self.feat_channels.append(num_channels[i] // 2)

        self.init_weight()

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)

    def _make_transition_layer(self,
                               num_channels_pre_layer,
                               num_channels_cur_layer,
                               freeze_norm=False,
                               norm_decay=0.):
        num_branches_pre = len(num_channels_pre_layer)
        num_branches_cur = len(num_channels_cur_layer)
        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_pre_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                groups=num_channels_pre_layer[i],
                                bias=False),
                            nn.BatchNorm2D(num_channels_pre_layer[i]),
                            Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            nn.BatchNorm2D(num_channels_cur_layer[i]),
                            nn.ReLU()))
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    conv_downsamples.append(
                        nn.Sequential(
                            Conv2d(
                                num_channels_pre_layer[-1],
                                num_channels_pre_layer[-1],
                                groups=num_channels_pre_layer[-1],
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                bias=False),
                            nn.BatchNorm2D(num_channels_pre_layer[-1]),
                            Conv2d(
                                num_channels_pre_layer[-1],
                                num_channels_cur_layer[i]
                                if j == i - num_branches_pre else
                                num_channels_pre_layer[-1],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            nn.BatchNorm2D(num_channels_cur_layer[i]
                                           if j == i - num_branches_pre else
                                           num_channels_pre_layer[-1]),
                            nn.ReLU()))
                transition_layers.append(nn.Sequential(*conv_downsamples))
        return nn.LayerList(transition_layers)

    def _make_stage(self,
                    stages_config,
                    stage_idx,
                    in_channels,
                    multiscale_output,
                    freeze_norm=False,
                    norm_decay=0.):
        num_modules = stages_config["num_modules"][stage_idx]
        num_branches = stages_config["num_branches"][stage_idx]
        num_blocks = stages_config["num_blocks"][stage_idx]
        reduce_ratio = stages_config['reduce_ratios'][stage_idx]
        module_type = stages_config['module_type'][stage_idx]

        modules = []
        for i in range(num_modules):
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True
            modules.append(
                LiteHRNetModule(
                    num_branches,
                    num_blocks,
                    in_channels,
                    reduce_ratio,
                    module_type,
                    multiscale_output=reset_multiscale_output,
                    with_fuse=True,
                    freeze_norm=freeze_norm,
                    norm_decay=norm_decay))
            in_channels = modules[-1].in_channels
        return nn.Sequential(*modules), in_channels

    def forward(self, x):
        x = self.stem(x)

        y_list = [x]
        for stage_idx in range(3):
            x_list = []
            transition = getattr(self, 'transition{}'.format(stage_idx))
            for j in range(self.stages_config["num_branches"][stage_idx]):
                if transition[j] is not None:
                    if j >= len(y_list):
                        x_list.append(transition[j](y_list[-1]))
                    else:
                        x_list.append(transition[j](y_list[j]))
                else:
                    x_list.append(y_list[j])
            y_list = getattr(self, 'stage{}'.format(stage_idx))(x_list)

        if self.use_head:
            y_list = self.head_layer(y_list)

        res = []
        for i, layer in enumerate(y_list):
            if i == self.freeze_at:
                layer.stop_gradient = True
            if i in self.return_idx:
                res.append(layer)
        return res


@manager.BACKBONES.add_component
def Lite_HRNet_18(**kwargs):
    model = LiteHRNet(network_type="lite_18", **kwargs)
    return model


@manager.BACKBONES.add_component
def Lite_HRNet_30(**kwargs):
    model = LiteHRNet(network_type="lite_30", **kwargs)
    return model


@manager.BACKBONES.add_component
def Lite_HRNet_naive(**kwargs):
    model = LiteHRNet(network_type="naive", **kwargs)
    return model


@manager.BACKBONES.add_component
def Lite_HRNet_wider_naive(**kwargs):
    model = LiteHRNet(network_type="wider_naive", **kwargs)
    return model
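A minimal usage sketch for the backbones registered above (this sketch is not part of the committed file; it assumes the full paddleseg package from this commit is importable and that `Stem`/`LiteHRNetModule` defined earlier in the file are intact):

import paddle
from paddleseg.models.backbones.lite_hrnet import Lite_HRNet_18

# hypothetical usage, not from the commit: build the registered backbone and
# run a dummy batch through it to inspect the multi-branch feature maps
backbone = Lite_HRNet_18(pretrained=None)
feats = backbone(paddle.randn([1, 3, 512, 512]))
print([f.shape for f in feats])   # one feature map per index in return_idx
print(backbone.feat_channels)     # channel widths a decoder head would consume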
paddleseg/models/backbones/mix_transformer.py
ADDED
@@ -0,0 +1,593 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from functools import partial

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.nn.initializer as paddle_init

from paddleseg.cvlibs import manager
from paddleseg.utils import utils
from paddleseg.models.backbones.transformer_utils import *


class Mlp(nn.Layer):
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        elif isinstance(m, nn.Conv2D):
            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
            fan_out //= m._groups
            paddle_init.Normal(0, math.sqrt(2.0 / fan_out))(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2D(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        elif isinstance(m, nn.Conv2D):
            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
            fan_out //= m._groups
            paddle_init.Normal(0, math.sqrt(2.0 / fan_out))(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def forward(self, x, H, W):
        x_shape = paddle.shape(x)
        B, N = x_shape[0], x_shape[1]
        C = self.dim

        q = self.q(x).reshape([B, N, self.num_heads,
                               C // self.num_heads]).transpose([0, 2, 1, 3])

        if self.sr_ratio > 1:
            x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
            x_ = self.sr(x_).reshape([B, C, -1]).transpose([0, 2, 1])
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(
                [B, -1, 2, self.num_heads,
                 C // self.num_heads]).transpose([2, 0, 3, 1, 4])
        else:
            kv = self.kv(x).reshape(
                [B, -1, 2, self.num_heads,
                 C // self.num_heads]).transpose([2, 0, 3, 1, 4])
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose([0, 2, 1, 3]).reshape([B, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        elif isinstance(m, nn.Conv2D):
            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
            fan_out //= m._groups
            paddle_init.Normal(0, math.sqrt(2.0 / fan_out))(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x


class OverlapPatchEmbed(nn.Layer):
    """ Image to Patch Embedding
    """

    def __init__(self,
                 img_size=224,
                 patch_size=7,
                 stride=4,
                 in_chans=3,
                 embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[
            1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2D(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        elif isinstance(m, nn.Conv2D):
            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
            fan_out //= m._groups
            paddle_init.Normal(0, math.sqrt(2.0 / fan_out))(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def forward(self, x):
        x = self.proj(x)
        x_shape = paddle.shape(x)
        H, W = x_shape[2], x_shape[3]
        x = x.flatten(2).transpose([0, 2, 1])
        x = self.norm(x)

        return x, H, W


class MixVisionTransformer(nn.Layer):
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8],
                 mlp_ratios=[4, 4, 4, 4],
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3],
                 sr_ratios=[8, 4, 2, 1],
                 pretrained=None):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.feat_channels = embed_dims[:]

        # patch_embed
        self.patch_embed1 = OverlapPatchEmbed(
            img_size=img_size,
            patch_size=7,
            stride=4,
            in_chans=in_chans,
            embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(
            img_size=img_size // 4,
            patch_size=3,
            stride=2,
            in_chans=embed_dims[0],
            embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(
            img_size=img_size // 8,
            patch_size=3,
            stride=2,
            in_chans=embed_dims[1],
            embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(
            img_size=img_size // 16,
            patch_size=3,
            stride=2,
            in_chans=embed_dims[2],
            embed_dim=embed_dims[3])

        # transformer encoder
        dpr = [
            x.numpy() for x in paddle.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.LayerList([
            Block(
                dim=embed_dims[0],
                num_heads=num_heads[0],
                mlp_ratio=mlp_ratios[0],
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[cur + i],
                norm_layer=norm_layer,
                sr_ratio=sr_ratios[0]) for i in range(depths[0])
        ])
        self.norm1 = norm_layer(embed_dims[0])

        cur += depths[0]
        self.block2 = nn.LayerList([
            Block(
                dim=embed_dims[1],
                num_heads=num_heads[1],
                mlp_ratio=mlp_ratios[1],
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[cur + i],
                norm_layer=norm_layer,
                sr_ratio=sr_ratios[1]) for i in range(depths[1])
        ])
        self.norm2 = norm_layer(embed_dims[1])

        cur += depths[1]
        self.block3 = nn.LayerList([
            Block(
                dim=embed_dims[2],
                num_heads=num_heads[2],
                mlp_ratio=mlp_ratios[2],
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[cur + i],
                norm_layer=norm_layer,
                sr_ratio=sr_ratios[2]) for i in range(depths[2])
        ])
        self.norm3 = norm_layer(embed_dims[2])

        cur += depths[2]
        self.block4 = nn.LayerList([
            Block(
                dim=embed_dims[3],
                num_heads=num_heads[3],
                mlp_ratio=mlp_ratios[3],
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[cur + i],
                norm_layer=norm_layer,
                sr_ratio=sr_ratios[3]) for i in range(depths[3])
        ])
        self.norm4 = norm_layer(embed_dims[3])

        self.pretrained = pretrained
        self.init_weight()

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_pretrained_model(self, self.pretrained)
        else:
            self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        elif isinstance(m, nn.Conv2D):
            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
            fan_out //= m._groups
            paddle_init.Normal(0, math.sqrt(2.0 / fan_out))(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def reset_drop_path(self, drop_path_rate):
        dpr = [
            x.item()
            for x in paddle.linspace(0, drop_path_rate, sum(self.depths))
        ]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim,
                              num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = paddle.shape(x)[0]
        outs = []

        # stage 1
        x, H, W = self.patch_embed1(x)
        for i, blk in enumerate(self.block1):
            x = blk(x, H, W)

        x = self.norm1(x)
        x = x.reshape([B, H, W, self.feat_channels[0]]).transpose([0, 3, 1, 2])
        outs.append(x)

        # stage 2
        x, H, W = self.patch_embed2(x)
        for i, blk in enumerate(self.block2):
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape([B, H, W, self.feat_channels[1]]).transpose([0, 3, 1, 2])
        outs.append(x)

        # stage 3
        x, H, W = self.patch_embed3(x)
        for i, blk in enumerate(self.block3):
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape([B, H, W, self.feat_channels[2]]).transpose([0, 3, 1, 2])
        outs.append(x)

        # stage 4
        x, H, W = self.patch_embed4(x)
        for i, blk in enumerate(self.block4):
            x = blk(x, H, W)
        x = self.norm4(x)
        x = x.reshape([B, H, W, self.feat_channels[3]]).transpose([0, 3, 1, 2])
        outs.append(x)

        return outs

    def forward(self, x):
        x = self.forward_features(x)
        # x = self.head(x)

        return x


class DWConv(nn.Layer):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dim = dim
        self.dwconv = nn.Conv2D(dim, dim, 3, 1, 1, bias_attr=True, groups=dim)

    def forward(self, x, H, W):
        x_shape = paddle.shape(x)
        B, N = x_shape[0], x_shape[1]
        x = x.transpose([0, 2, 1]).reshape([B, self.dim, H, W])
        x = self.dwconv(x)
        x = x.flatten(2).transpose([0, 2, 1])

        return x


@manager.BACKBONES.add_component
def MixVisionTransformer_B0(**kwargs):
    return MixVisionTransformer(
        patch_size=4,
        embed_dims=[32, 64, 160, 256],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
        norm_layer=partial(
            nn.LayerNorm, epsilon=1e-6),
        depths=[2, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        drop_rate=0.0,
        drop_path_rate=0.1,
        **kwargs)


@manager.BACKBONES.add_component
def MixVisionTransformer_B1(**kwargs):
    return MixVisionTransformer(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
        norm_layer=partial(
            nn.LayerNorm, epsilon=1e-6),
        depths=[2, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        drop_rate=0.0,
        drop_path_rate=0.1,
        **kwargs)


@manager.BACKBONES.add_component
def MixVisionTransformer_B2(**kwargs):
    return MixVisionTransformer(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
        norm_layer=partial(
            nn.LayerNorm, epsilon=1e-6),
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
        drop_rate=0.0,
        drop_path_rate=0.1,
        **kwargs)


@manager.BACKBONES.add_component
def MixVisionTransformer_B3(**kwargs):
    return MixVisionTransformer(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
        norm_layer=partial(
            nn.LayerNorm, epsilon=1e-6),
        depths=[3, 4, 18, 3],
        sr_ratios=[8, 4, 2, 1],
        drop_rate=0.0,
        drop_path_rate=0.1,
        **kwargs)


@manager.BACKBONES.add_component
def MixVisionTransformer_B4(**kwargs):
    return MixVisionTransformer(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
        norm_layer=partial(
            nn.LayerNorm, epsilon=1e-6),
        depths=[3, 8, 27, 3],
        sr_ratios=[8, 4, 2, 1],
        drop_rate=0.0,
        drop_path_rate=0.1,
        **kwargs)


@manager.BACKBONES.add_component
def MixVisionTransformer_B5(**kwargs):
    return MixVisionTransformer(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
        norm_layer=partial(
            nn.LayerNorm, epsilon=1e-6),
        depths=[3, 6, 40, 3],
        sr_ratios=[8, 4, 2, 1],
        drop_rate=0.0,
        drop_path_rate=0.1,
        **kwargs)
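A small sketch of the pyramid this backbone produces (not part of the committed file; it assumes the transformer_utils helpers such as DropPath, Identity, to_2tuple, trunc_normal_, zeros_ and ones_ are available from elsewhere in this commit):

import paddle
from paddleseg.models.backbones.mix_transformer import MixVisionTransformer_B0

# hypothetical usage: the four stages downsample by 4, then 2, 2, 2,
# so a 512x512 input yields features at strides 4/8/16/32 with the
# B0 channel widths [32, 64, 160, 256]
model = MixVisionTransformer_B0()
feats = model(paddle.randn([1, 3, 512, 512]))
print([f.shape for f in feats])
# expected: [1, 32, 128, 128], [1, 64, 64, 64], [1, 160, 32, 32], [1, 256, 16, 16]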
paddleseg/models/backbones/mobilenetv2.py
ADDED
@@ -0,0 +1,264 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D

from paddleseg.cvlibs import manager
from paddleseg import utils

__all__ = [
    "MobileNetV2_x0_25",
    "MobileNetV2_x0_5",
    "MobileNetV2_x0_75",
    "MobileNetV2_x1_0",
    "MobileNetV2_x1_5",
    "MobileNetV2_x2_0",
]


class MobileNetV2(nn.Layer):
    """
    The MobileNetV2 implementation based on PaddlePaddle.

    The original article refers to
    Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen
    "MobileNetV2: Inverted Residuals and Linear Bottlenecks"
    (https://arxiv.org/abs/1801.04381).

    Args:
        scale (float, optional): The scale of channels. Default: 1.0
        pretrained (str, optional): The path or url of the pretrained model. Default: None
    """

    def __init__(self, scale=1.0, pretrained=None):
        super().__init__()
        self.scale = scale
        self.pretrained = pretrained
        prefix_name = ""

        bottleneck_params_list = [
            (1, 16, 1, 1),
            (6, 24, 2, 2),  # x4
            (6, 32, 3, 2),  # x8
            (6, 64, 4, 2),
            (6, 96, 3, 1),  # x16
            (6, 160, 3, 2),
            (6, 320, 1, 1),  # x32
        ]
        self.out_index = [1, 2, 4, 6]

        self.conv1 = ConvBNLayer(
            num_channels=3,
            num_filters=int(32 * scale),
            filter_size=3,
            stride=2,
            padding=1,
            name=prefix_name + "conv1_1")

        self.block_list = []
        i = 1
        in_c = int(32 * scale)
        for layer_setting in bottleneck_params_list:
            t, c, n, s = layer_setting
            i += 1
            block = self.add_sublayer(
                prefix_name + "conv" + str(i),
                sublayer=InvresiBlocks(
                    in_c=in_c,
                    t=t,
                    c=int(c * scale),
                    n=n,
                    s=s,
                    name=prefix_name + "conv" + str(i)))
            self.block_list.append(block)
            in_c = int(c * scale)

        out_channels = [
            bottleneck_params_list[idx][1] for idx in self.out_index
        ]
        self.feat_channels = [int(c * scale) for c in out_channels]

        self.init_weight()

    def forward(self, inputs):
        feat_list = []

        y = self.conv1(inputs, if_act=True)
        for idx, block in enumerate(self.block_list):
            y = block(y)
            if idx in self.out_index:
                feat_list.append(y)

        return feat_list

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 padding,
                 channels=None,
                 num_groups=1,
                 name=None,
                 use_cudnn=True):
        super(ConvBNLayer, self).__init__()

        self._conv = Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)

        self._batch_norm = BatchNorm(
            num_filters,
            param_attr=ParamAttr(name=name + "_bn_scale"),
            bias_attr=ParamAttr(name=name + "_bn_offset"),
            moving_mean_name=name + "_bn_mean",
            moving_variance_name=name + "_bn_variance")

    def forward(self, inputs, if_act=True):
        y = self._conv(inputs)
        y = self._batch_norm(y)
        if if_act:
            y = F.relu6(y)
        return y


class InvertedResidualUnit(nn.Layer):
    def __init__(self, num_channels, num_in_filter, num_filters, stride,
                 filter_size, padding, expansion_factor, name):
        super(InvertedResidualUnit, self).__init__()
        num_expfilter = int(round(num_in_filter * expansion_factor))
        self._expand_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_expfilter,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            name=name + "_expand")

        self._bottleneck_conv = ConvBNLayer(
            num_channels=num_expfilter,
            num_filters=num_expfilter,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            num_groups=num_expfilter,
            use_cudnn=False,
            name=name + "_dwise")

        self._linear_conv = ConvBNLayer(
            num_channels=num_expfilter,
            num_filters=num_filters,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            name=name + "_linear")

    def forward(self, inputs, ifshortcut):
        y = self._expand_conv(inputs, if_act=True)
        y = self._bottleneck_conv(y, if_act=True)
        y = self._linear_conv(y, if_act=False)
        if ifshortcut:
            y = paddle.add(inputs, y)
        return y


class InvresiBlocks(nn.Layer):
    def __init__(self, in_c, t, c, n, s, name):
        super(InvresiBlocks, self).__init__()

        self._first_block = InvertedResidualUnit(
            num_channels=in_c,
            num_in_filter=in_c,
            num_filters=c,
            stride=s,
            filter_size=3,
            padding=1,
            expansion_factor=t,
            name=name + "_1")

        self._block_list = []
        for i in range(1, n):
            block = self.add_sublayer(
                name + "_" + str(i + 1),
                sublayer=InvertedResidualUnit(
                    num_channels=c,
                    num_in_filter=c,
                    num_filters=c,
                    stride=1,
                    filter_size=3,
                    padding=1,
                    expansion_factor=t,
                    name=name + "_" + str(i + 1)))
            self._block_list.append(block)

    def forward(self, inputs):
        y = self._first_block(inputs, ifshortcut=False)
        for block in self._block_list:
            y = block(y, ifshortcut=True)
        return y


@manager.BACKBONES.add_component
def MobileNetV2_x0_25(**kwargs):
    model = MobileNetV2(scale=0.25, **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV2_x0_5(**kwargs):
    model = MobileNetV2(scale=0.5, **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV2_x0_75(**kwargs):
    model = MobileNetV2(scale=0.75, **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV2_x1_0(**kwargs):
    model = MobileNetV2(scale=1.0, **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV2_x1_5(**kwargs):
    model = MobileNetV2(scale=1.5, **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV2_x2_0(**kwargs):
    model = MobileNetV2(scale=2.0, **kwargs)
    return model
paddleseg/models/backbones/mobilenetv3.py
ADDED
@@ -0,0 +1,496 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear

from paddleseg.cvlibs import manager
from paddleseg.utils import utils, logger
from paddleseg.models import layers

__all__ = [
    "MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
    "MobileNetV3_small_x0_75", "MobileNetV3_small_x1_0",
    "MobileNetV3_small_x1_25", "MobileNetV3_large_x0_35",
    "MobileNetV3_large_x0_5", "MobileNetV3_large_x0_75",
    "MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25"
]

MODEL_STAGES_PATTERN = {
    "MobileNetV3_small": ["blocks[0]", "blocks[2]", "blocks[7]", "blocks[10]"],
    "MobileNetV3_large":
    ["blocks[0]", "blocks[2]", "blocks[5]", "blocks[11]", "blocks[14]"]
}

# "large" and "small" correspond to MobileNetV3_large and MobileNetV3_small respectively.
# Each "large"/"small" config is a list. Every element (itself a list) describes one
# depthwise block and is composed of k, exp, c, se, act, s (and optionally d):
# k: kernel_size
# exp: middle channel number in depthwise block
# c: output channel number in depthwise block
# se: whether to use SE block
# act: which activation to use
# s: stride in depthwise block
# d: dilation rate in depthwise block
NET_CONFIG = {
    "large": [
        # k, exp, c, se, act, s
        [3, 16, 16, False, "relu", 1],
        [3, 64, 24, False, "relu", 2],
        [3, 72, 24, False, "relu", 1],  # x4
        [5, 72, 40, True, "relu", 2],
        [5, 120, 40, True, "relu", 1],
        [5, 120, 40, True, "relu", 1],  # x8
        [3, 240, 80, False, "hardswish", 2],
        [3, 200, 80, False, "hardswish", 1],
        [3, 184, 80, False, "hardswish", 1],
        [3, 184, 80, False, "hardswish", 1],
        [3, 480, 112, True, "hardswish", 1],
        [3, 672, 112, True, "hardswish", 1],  # x16
        [5, 672, 160, True, "hardswish", 2],
        [5, 960, 160, True, "hardswish", 1],
        [5, 960, 160, True, "hardswish", 1],  # x32
    ],
    "small": [
        # k, exp, c, se, act, s
        [3, 16, 16, True, "relu", 2],
        [3, 72, 24, False, "relu", 2],
        [3, 88, 24, False, "relu", 1],
        [5, 96, 40, True, "hardswish", 2],
        [5, 240, 40, True, "hardswish", 1],
        [5, 240, 40, True, "hardswish", 1],
        [5, 120, 48, True, "hardswish", 1],
        [5, 144, 48, True, "hardswish", 1],
        [5, 288, 96, True, "hardswish", 2],
        [5, 576, 96, True, "hardswish", 1],
        [5, 576, 96, True, "hardswish", 1],
    ],
    "large_os8": [
        # k, exp, c, se, act, s, {d}
        [3, 16, 16, False, "relu", 1],
        [3, 64, 24, False, "relu", 2],
        [3, 72, 24, False, "relu", 1],  # x4
        [5, 72, 40, True, "relu", 2],
        [5, 120, 40, True, "relu", 1],
        [5, 120, 40, True, "relu", 1],  # x8
        [3, 240, 80, False, "hardswish", 1],
        [3, 200, 80, False, "hardswish", 1, 2],
        [3, 184, 80, False, "hardswish", 1, 2],
        [3, 184, 80, False, "hardswish", 1, 2],
        [3, 480, 112, True, "hardswish", 1, 2],
        [3, 672, 112, True, "hardswish", 1, 2],
        [5, 672, 160, True, "hardswish", 1, 2],
        [5, 960, 160, True, "hardswish", 1, 4],
        [5, 960, 160, True, "hardswish", 1, 4],
    ],
    "small_os8": [
        # k, exp, c, se, act, s, {d}
        [3, 16, 16, True, "relu", 2],
        [3, 72, 24, False, "relu", 2],
        [3, 88, 24, False, "relu", 1],
        [5, 96, 40, True, "hardswish", 1],
        [5, 240, 40, True, "hardswish", 1, 2],
        [5, 240, 40, True, "hardswish", 1, 2],
        [5, 120, 48, True, "hardswish", 1, 2],
        [5, 144, 48, True, "hardswish", 1, 2],
        [5, 288, 96, True, "hardswish", 1, 2],
        [5, 576, 96, True, "hardswish", 1, 4],
        [5, 576, 96, True, "hardswish", 1, 4],
    ]
}

OUT_INDEX = {"large": [2, 5, 11, 14], "small": [0, 2, 7, 10]}


def _make_divisible(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

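# A quick worked illustration of the rounding above (these comments are an
# editorial addition, not part of the committed file): with the default
# divisor of 8, _make_divisible(16 * 0.35) = _make_divisible(5.6) rounds to 8
# (clamped up to min_value), and _make_divisible(40 * 1.25) = _make_divisible(50.0)
# rounds to 48; the `new_v < 0.9 * v` guard bumps any value that lost more than
# 10% of the original width back up by one divisor step.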

def _create_act(act):
    if act == "hardswish":
        return nn.Hardswish()
    elif act == "relu":
        return nn.ReLU()
    elif act is None:
        return None
    else:
        raise RuntimeError(
            "The activation function is not supported: {}".format(act))


class MobileNetV3(nn.Layer):
    """
    MobileNetV3
    Args:
        config: list. MobileNetV3 depthwise blocks config.
        scale: float=1.0. The coefficient that controls the size of network parameters.
    Returns:
        model: nn.Layer. Specific MobileNetV3 model depends on args.
    """

    def __init__(self,
                 config,
                 stages_pattern,
                 out_index,
                 scale=1.0,
                 pretrained=None):
        super().__init__()

        self.cfg = config
        self.out_index = out_index
        self.scale = scale
        self.pretrained = pretrained
        inplanes = 16

        self.conv = ConvBNLayer(
            in_c=3,
            out_c=_make_divisible(inplanes * self.scale),
            filter_size=3,
            stride=2,
            padding=1,
            num_groups=1,
            if_act=True,
            act="hardswish")
        self.blocks = nn.Sequential(*[
            ResidualUnit(
                in_c=_make_divisible(inplanes * self.scale if i == 0 else
                                     self.cfg[i - 1][2] * self.scale),
                mid_c=_make_divisible(self.scale * exp),
                out_c=_make_divisible(self.scale * c),
                filter_size=k,
                stride=s,
                use_se=se,
                act=act,
                dilation=td[0] if td else 1)
            for i, (k, exp, c, se, act, s, *td) in enumerate(self.cfg)
        ])

        out_channels = [config[idx][2] for idx in self.out_index]
        self.feat_channels = [
            _make_divisible(self.scale * c) for c in out_channels
        ]

        self.init_res(stages_pattern)
        self.init_weight()

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)

    def init_res(self, stages_pattern, return_patterns=None,
                 return_stages=None):
        if return_patterns and return_stages:
            msg = "The 'return_patterns' would be ignored when 'return_stages' is set."
            logger.warning(msg)
            return_stages = None

        if return_stages is True:
            return_patterns = stages_pattern
        # return_stages is int or bool
        if type(return_stages) is int:
            return_stages = [return_stages]
        if isinstance(return_stages, list):
            if max(return_stages) > len(stages_pattern) or min(
                    return_stages) < 0:
                msg = f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
                logger.warning(msg)
                return_stages = [
                    val for val in return_stages
                    if val >= 0 and val < len(stages_pattern)
                ]
            return_patterns = [stages_pattern[i] for i in return_stages]

    def forward(self, x):
        x = self.conv(x)

        feat_list = []
        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.out_index:
                feat_list.append(x)

        return feat_list


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_c,
                 out_c,
                 filter_size,
                 stride,
                 padding,
                 num_groups=1,
                 if_act=True,
                 act=None,
                 dilation=1):
        super().__init__()

        self.conv = Conv2D(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            bias_attr=False,
            dilation=dilation)
        self.bn = BatchNorm(
            num_channels=out_c,
            act=None,
            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
        self.if_act = if_act
        self.act = _create_act(act)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.if_act:
            x = self.act(x)
        return x


class ResidualUnit(nn.Layer):
    def __init__(self,
                 in_c,
                 mid_c,
                 out_c,
                 filter_size,
                 stride,
                 use_se,
                 act=None,
                 dilation=1):
        super().__init__()
        self.if_shortcut = stride == 1 and in_c == out_c
        self.if_se = use_se

        self.expand_conv = ConvBNLayer(
            in_c=in_c,
            out_c=mid_c,
            filter_size=1,
            stride=1,
            padding=0,
            if_act=True,
            act=act)
        self.bottleneck_conv = ConvBNLayer(
            in_c=mid_c,
            out_c=mid_c,
            filter_size=filter_size,
            stride=stride,
            padding=int((filter_size - 1) // 2) * dilation,
            num_groups=mid_c,
            if_act=True,
            act=act,
            dilation=dilation)
        if self.if_se:
            self.mid_se = SEModule(mid_c)
        self.linear_conv = ConvBNLayer(
            in_c=mid_c,
            out_c=out_c,
            filter_size=1,
            stride=1,
            padding=0,
            if_act=False,
            act=None)

    def forward(self, x):
        identity = x
        x = self.expand_conv(x)
        x = self.bottleneck_conv(x)
        if self.if_se:
            x = self.mid_se(x)
        x = self.linear_conv(x)
        if self.if_shortcut:
            x = paddle.add(identity, x)
        return x


# nn.Hardsigmoid can't transfer "slope" and "offset" in nn.functional.hardsigmoid
class Hardsigmoid(nn.Layer):
    def __init__(self, slope=0.2, offset=0.5):
        super().__init__()
        self.slope = slope
        self.offset = offset

    def forward(self, x):
        return nn.functional.hardsigmoid(
            x, slope=self.slope, offset=self.offset)


class SEModule(nn.Layer):
    def __init__(self, channel, reduction=4):
        super().__init__()
        self.avg_pool = AdaptiveAvgPool2D(1)
        self.conv1 = Conv2D(
            in_channels=channel,
            out_channels=channel // reduction,
            kernel_size=1,
            stride=1,
            padding=0)
        self.relu = nn.ReLU()
        self.conv2 = Conv2D(
            in_channels=channel // reduction,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            padding=0)
        self.hardsigmoid = Hardsigmoid(slope=0.2, offset=0.5)

    def forward(self, x):
        identity = x
        x = self.avg_pool(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.hardsigmoid(x)
        return paddle.multiply(x=identity, y=x)


@manager.BACKBONES.add_component
def MobileNetV3_small_x0_35(**kwargs):
    model = MobileNetV3(
        config=NET_CONFIG["small"],
        scale=0.35,
        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
        out_index=OUT_INDEX["small"],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV3_small_x0_5(**kwargs):
    model = MobileNetV3(
        config=NET_CONFIG["small"],
        scale=0.5,
        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
        out_index=OUT_INDEX["small"],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV3_small_x0_75(**kwargs):
    model = MobileNetV3(
        config=NET_CONFIG["small"],
        scale=0.75,
        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
        out_index=OUT_INDEX["small"],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV3_small_x1_0(**kwargs):
    model = MobileNetV3(
        config=NET_CONFIG["small"],
        scale=1.0,
        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
        out_index=OUT_INDEX["small"],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV3_small_x1_25(**kwargs):
    model = MobileNetV3(
        config=NET_CONFIG["small"],
        scale=1.25,
        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
        out_index=OUT_INDEX["small"],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV3_large_x0_35(**kwargs):
    model = MobileNetV3(
        config=NET_CONFIG["large"],
        scale=0.35,
        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
        out_index=OUT_INDEX["large"],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV3_large_x0_5(**kwargs):
    model = MobileNetV3(
        config=NET_CONFIG["large"],
        scale=0.5,
        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
        out_index=OUT_INDEX["large"],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV3_large_x0_75(**kwargs):
    model = MobileNetV3(
        config=NET_CONFIG["large"],
        scale=0.75,
        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
        out_index=OUT_INDEX["large"],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV3_large_x1_0(**kwargs):
    model = MobileNetV3(
        config=NET_CONFIG["large"],
        scale=1.0,
        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
        out_index=OUT_INDEX["large"],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV3_large_x1_25(**kwargs):
    model = MobileNetV3(
        config=NET_CONFIG["large"],
        scale=1.25,
        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
        out_index=OUT_INDEX["large"],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV3_large_x1_0_os8(**kwargs):
    model = MobileNetV3(
        config=NET_CONFIG["large_os8"],
        scale=1.0,
        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_large"],
        out_index=OUT_INDEX["large"],
        **kwargs)
    return model


@manager.BACKBONES.add_component
def MobileNetV3_small_x1_0_os8(**kwargs):
    model = MobileNetV3(
        config=NET_CONFIG["small_os8"],
        scale=1.0,
        stages_pattern=MODEL_STAGES_PATTERN["MobileNetV3_small"],
        out_index=OUT_INDEX["small"],
        **kwargs)
    return model