Add files using upload-large-folder tool
- ktda/models/segmentors/__init__.py +3 -0
- tools/analysis_tools/browse_dataset.py +77 -0
- tools/dataset_converters/drive.py +114 -0
- tools/dataset_converters/vaihingen.py +156 -0
- tools/dataset_tools/analysis_dataset.py +58 -0
- tools/dataset_tools/process_water.py +152 -0
- tools/deployment/pytorch2torchscript.py +185 -0
- tools/model_converters/beit2mmseg.py +56 -0
ktda/models/segmentors/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .distill_encoder_decoder import DistillEncoderDecoder

__all__ = ['DistillEncoderDecoder']
tools/analysis_tools/browse_dataset.py
ADDED
@@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

from mmengine.config import Config, DictAction
from mmengine.utils import ProgressBar

from mmseg.registry import DATASETS, VISUALIZERS
from mmseg.utils import register_all_modules


def parse_args():
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--output-dir',
        default=None,
        type=str,
        help='If there is no display interface, you can save it')
    parser.add_argument('--not-show', default=False, action='store_true')
    parser.add_argument(
        '--show-interval',
        type=float,
        default=2,
        help='the interval of show (s)')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # register all modules in mmseg into the registries
    register_all_modules()

    dataset = DATASETS.build(cfg.train_dataloader.dataset)
    visualizer = VISUALIZERS.build(cfg.visualizer)
    visualizer.dataset_meta = dataset.metainfo

    progress_bar = ProgressBar(len(dataset))
    for item in dataset:
        img = item['inputs'].permute(1, 2, 0).numpy()
        img = img[..., [2, 1, 0]]  # BGR to RGB
        data_sample = item['data_samples'].numpy()
        img_path = osp.basename(item['data_samples'].img_path)

        out_file = osp.join(
            args.output_dir,
            osp.basename(img_path)) if args.output_dir is not None else None

        visualizer.add_datasample(
            name=osp.basename(img_path),
            image=img,
            data_sample=data_sample,
            draw_gt=True,
            draw_pred=False,
            wait_time=args.show_interval,
            out_file=out_file,
            show=not args.not_show)
        progress_bar.update()


if __name__ == '__main__':
    main()
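For orientation, a minimal sketch of what the loop above consumes (assumes mmseg is installed; 'configs/example.py' is a hypothetical placeholder for a real config):

register_all_modules()
cfg = Config.fromfile('configs/example.py')  # hypothetical config path
dataset = DATASETS.build(cfg.train_dataloader.dataset)
item = dataset[0]
print(item['inputs'].shape)           # CHW tensor in BGR channel order
print(item['data_samples'].img_path)  # path of the source image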
tools/dataset_converters/drive.py
ADDED
@@ -0,0 +1,114 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import tempfile
import zipfile

import cv2
import mmcv
from mmengine.utils import mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert DRIVE dataset to mmsegmentation format')
    parser.add_argument(
        'training_path', help='the training part of DRIVE dataset')
    parser.add_argument(
        'testing_path', help='the testing part of DRIVE dataset')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    training_path = args.training_path
    testing_path = args.testing_path
    if args.out_dir is None:
        out_dir = osp.join('data', 'DRIVE')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(out_dir)
    mkdir_or_exist(osp.join(out_dir, 'images'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
    mkdir_or_exist(osp.join(out_dir, 'annotations'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
    mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        print('Extracting training.zip...')
        zip_file = zipfile.ZipFile(training_path)
        zip_file.extractall(tmp_dir)

        print('Generating training dataset...')
        now_dir = osp.join(tmp_dir, 'training', 'images')
        for img_name in os.listdir(now_dir):
            img = mmcv.imread(osp.join(now_dir, img_name))
            mmcv.imwrite(
                img,
                osp.join(
                    out_dir, 'images', 'training',
                    osp.splitext(img_name)[0].replace('_training', '') +
                    '.png'))

        now_dir = osp.join(tmp_dir, 'training', '1st_manual')
        for img_name in os.listdir(now_dir):
            # the manual annotations are .gif files; cv2.VideoCapture is
            # used to decode their first (and only) frame
            cap = cv2.VideoCapture(osp.join(now_dir, img_name))
            ret, img = cap.read()
            mmcv.imwrite(
                img[:, :, 0] // 128,
                osp.join(out_dir, 'annotations', 'training',
                         osp.splitext(img_name)[0] + '.png'))

        print('Extracting test.zip...')
        zip_file = zipfile.ZipFile(testing_path)
        zip_file.extractall(tmp_dir)

        print('Generating validation dataset...')
        now_dir = osp.join(tmp_dir, 'test', 'images')
        for img_name in os.listdir(now_dir):
            img = mmcv.imread(osp.join(now_dir, img_name))
            mmcv.imwrite(
                img,
                osp.join(
                    out_dir, 'images', 'validation',
                    osp.splitext(img_name)[0].replace('_test', '') + '.png'))

        now_dir = osp.join(tmp_dir, 'test', '1st_manual')
        if osp.exists(now_dir):
            for img_name in os.listdir(now_dir):
                cap = cv2.VideoCapture(osp.join(now_dir, img_name))
                ret, img = cap.read()
                # The annotation img should be divided by 128, because some
                # of the annotation imgs are not standard. We should set a
                # threshold to convert the nonstandard annotation imgs. The
                # value divided by 128 is equivalent to '1 if value >= 128
                # else 0'.
                mmcv.imwrite(
                    img[:, :, 0] // 128,
                    osp.join(out_dir, 'annotations', 'validation',
                             osp.splitext(img_name)[0] + '.png'))

        now_dir = osp.join(tmp_dir, 'test', '2nd_manual')
        if osp.exists(now_dir):
            for img_name in os.listdir(now_dir):
                cap = cv2.VideoCapture(osp.join(now_dir, img_name))
                ret, img = cap.read()
                mmcv.imwrite(
                    img[:, :, 0] // 128,
                    osp.join(out_dir, 'annotations', 'validation',
                             osp.splitext(img_name)[0] + '.png'))

    print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()
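A quick NumPy check of the '// 128' binarization used above: for uint8 pixels, integer division by 128 is exactly 'value >= 128'.

import numpy as np

v = np.array([0, 5, 127, 128, 200, 255], dtype=np.uint8)
print(v // 128)                     # [0 0 0 1 1 1]
print((v >= 128).astype(np.uint8))  # identical result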
tools/dataset_converters/vaihingen.py
ADDED
@@ -0,0 +1,156 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import math
import os
import os.path as osp
import tempfile
import zipfile

import mmcv
import numpy as np
from mmengine.utils import ProgressBar, mkdir_or_exist


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert vaihingen dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='vaihingen folder path')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--clip_size',
        type=int,
        help='clipped size of image after preparation',
        default=512)
    parser.add_argument(
        '--stride_size',
        type=int,
        help='stride of clipping original images',
        default=256)
    args = parser.parse_args()
    return args


def clip_big_image(image_path, clip_save_dir, to_label=False):
    # The original images of the Vaihingen dataset are very large, so they
    # are clipped into fixed-size windows here. A grid of window origins is
    # laid over the image; windows that would cross the right or bottom
    # border are shifted back inside so every clip keeps the full clip size.
    image = mmcv.imread(image_path)

    h, w, c = image.shape
    cs = args.clip_size  # args is set in the __main__ block below
    ss = args.stride_size

    num_rows = math.ceil((h - cs) / ss) if math.ceil(
        (h - cs) / ss) * ss + cs >= h else math.ceil((h - cs) / ss) + 1
    num_cols = math.ceil((w - cs) / ss) if math.ceil(
        (w - cs) / ss) * ss + cs >= w else math.ceil((w - cs) / ss) + 1

    x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
    xmin = x * cs
    ymin = y * cs

    xmin = xmin.ravel()
    ymin = ymin.ravel()
    xmin_offset = np.where(xmin + cs > w, w - xmin - cs, np.zeros_like(xmin))
    ymin_offset = np.where(ymin + cs > h, h - ymin - cs, np.zeros_like(ymin))
    boxes = np.stack([
        xmin + xmin_offset, ymin + ymin_offset,
        np.minimum(xmin + cs, w),
        np.minimum(ymin + cs, h)
    ],
                     axis=1)

    if to_label:
        # map each palette color to a class index by hashing the color
        # triple with the dot product [2, 3, 4] (distinct per palette entry)
        color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
                              [255, 255, 0], [0, 255, 0], [0, 255, 255],
                              [0, 0, 255]])
        flatten_v = np.matmul(
            image.reshape(-1, c),
            np.array([2, 3, 4]).reshape(3, 1))
        out = np.zeros_like(flatten_v)
        for idx, class_color in enumerate(color_map):
            value_idx = np.matmul(class_color,
                                  np.array([2, 3, 4]).reshape(3, 1))
            out[flatten_v == value_idx] = idx
        image = out.reshape(h, w)

    for box in boxes:
        start_x, start_y, end_x, end_y = box
        clipped_image = image[start_y:end_y,
                              start_x:end_x] if to_label else image[
                                  start_y:end_y, start_x:end_x, :]
        area_idx = osp.basename(image_path).split('_')[3].strip('.tif')
        mmcv.imwrite(
            clipped_image.astype(np.uint8),
            osp.join(clip_save_dir,
                     f'{area_idx}_{start_x}_{start_y}_{end_x}_{end_y}.png'))


def main():
    splits = {
        'train': [
            'area1', 'area11', 'area13', 'area15', 'area17', 'area21',
            'area23', 'area26', 'area28', 'area3', 'area30', 'area32',
            'area34', 'area37', 'area5', 'area7'
        ],
        'val': [
            'area6', 'area24', 'area35', 'area16', 'area14', 'area22',
            'area10', 'area4', 'area2', 'area20', 'area8', 'area31', 'area33',
            'area27', 'area38', 'area12', 'area29'
        ],
    }

    dataset_path = args.dataset_path
    if args.out_dir is None:
        out_dir = osp.join('data', 'vaihingen')
    else:
        out_dir = args.out_dir

    print('Making directories...')
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
    mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))

    zipp_list = glob.glob(os.path.join(dataset_path, '*.zip'))
    print('Find the data', zipp_list)

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        for zipp in zipp_list:
            zip_file = zipfile.ZipFile(zipp)
            zip_file.extractall(tmp_dir)
            src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
            if 'ISPRS_semantic_labeling_Vaihingen' in zipp:
                src_path_list = glob.glob(
                    os.path.join(os.path.join(tmp_dir, 'top'), '*.tif'))
            if 'ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE' in zipp:  # noqa
                src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
                # drop the unused area9 ground truth (filter with a
                # comprehension instead of removing items while iterating,
                # which can skip entries)
                src_path_list = [
                    p for p in src_path_list if 'area9' not in p
                ]
            prog_bar = ProgressBar(len(src_path_list))
            for i, src_path in enumerate(src_path_list):
                area_idx = osp.basename(src_path).split('_')[3].strip('.tif')
                data_type = 'train' if area_idx in splits['train'] else 'val'
                if 'noBoundary' in src_path:
                    dst_dir = osp.join(out_dir, 'ann_dir', data_type)
                    clip_big_image(src_path, dst_dir, to_label=True)
                else:
                    dst_dir = osp.join(out_dir, 'img_dir', data_type)
                    clip_big_image(src_path, dst_dir, to_label=False)
                prog_bar.update()

        print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    args = parse_args()
    main()
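A small sanity check of the color-to-label trick in clip_big_image: dotting each palette color with [2, 3, 4] yields a distinct scalar per class, so per-pixel comparison against these scalars recovers class indices.

import numpy as np

color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
                      [255, 255, 0], [0, 255, 0], [0, 255, 255],
                      [0, 0, 255]])
hashes = color_map @ np.array([2, 3, 4])
print(hashes.tolist())  # [0, 2295, 510, 1275, 765, 1785, 1020]
assert len(set(hashes.tolist())) == len(hashes)  # collision-free palette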
tools/dataset_tools/analysis_dataset.py
ADDED
@@ -0,0 +1,58 @@
import argparse
import os
from glob import glob
from typing import List, Tuple

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image


def get_args() -> Tuple[str, str]:
    """
    Returns:
        dataset_dir: dataset directory.
        save_dir: path of the output figure.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, default='data/grass')
    parser.add_argument(
        '--save_dir',
        type=str,
        default='dataset_num_analysis.png',
        help='path of the output figure')
    args = parser.parse_args()
    return args.dataset_dir, args.save_dir


def get_mask_files(dataset_dir: str) -> List[str]:
    """
    Get mask files from the dataset directory.

    Args:
        dataset_dir: dataset directory.
    Returns:
        mask_filenames: list of mask filenames.
    """
    return glob(os.path.join(dataset_dir, 'ann_dir', '*', '*.png'))


def main():
    dataset_dir, save_dir = get_args()
    mask_filenames = get_mask_files(dataset_dir)
    # count the number of pixels per class over all masks
    statistic = {}
    for mask_filename in mask_filenames:
        mask = np.array(Image.open(mask_filename))
        for class_ in np.unique(mask):
            class_ = int(class_)
            if class_ not in statistic:
                statistic[class_] = 0
            statistic[class_] += int(np.sum(mask == class_))

    classes = list(statistic.keys())
    classes_num = list(statistic.values())

    plt.title('Dataset Analysis')
    bars = plt.bar(classes, classes_num)
    for bar in bars:
        height = bar.get_height()
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            height + 5,
            str(height),
            ha='center',
            va='bottom')
    plt.savefig(save_dir, dpi=300)


if __name__ == '__main__':
    main()
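As an aside, the per-class pixel counting above can also be done in one call with np.bincount (equivalent for small non-negative integer labels, as in these masks):

import numpy as np

mask = np.array([[0, 0, 1], [2, 1, 1]])
counts = np.bincount(mask.ravel())
print({i: int(c) for i, c in enumerate(counts)})  # {0: 2, 1: 3, 2: 1}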
tools/dataset_tools/process_water.py
ADDED
@@ -0,0 +1,152 @@
import argparse
import os
import shutil
from glob import glob
from typing import List

import numpy as np
from PIL import Image
from rich.progress import track
from sklearn.model_selection import train_test_split
from vegseg.datasets import WaterDataset


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--raw_path", type=str)
    parser.add_argument("--tmp_dir", type=str)
    parser.add_argument("--save_path", type=str)
    args = parser.parse_args()
    return args.raw_path, args.tmp_dir, args.save_path


def get_palette() -> List[int]:
    """
    Get the flattened palette of the dataset.

    Returns:
        palette: flat list of RGB values, as expected by PIL putpalette.
    """
    palette = []
    for palette_item in WaterDataset.METAINFO["palette"]:
        palette.extend(palette_item)
    return palette


def create_dataset(image_list, ann_list, image_dir, ann_dir, description="Working..."):
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(ann_dir, exist_ok=True)
    for image_path, ann_path in track(
        zip(image_list, ann_list), total=len(image_list), description=description
    ):
        base_name = os.path.basename(image_path)

        new_image_path = os.path.join(image_dir, base_name)
        new_ann_path = os.path.join(ann_dir, base_name)

        shutil.move(image_path, new_image_path)
        shutil.move(ann_path, new_ann_path)

        # save the annotation as a paletted image so it renders with the
        # dataset colors
        mask = Image.open(new_ann_path).convert("P")
        mask.putpalette(get_palette())
        mask.save(new_ann_path)


def main():
    classes_mapping = {
        "CDUWD-1": 1,
        "CDUWD-2": 2,
        "CDUWD-3": 3,
        "CDUWD-4": 4,
        "CDUWD-5": 5,
        "CDUWD-6": 0,
    }

    raw_path, tmp_dir, save_path = get_args()

    all_images = glob(os.path.join(raw_path, "*", "images", "*.png"))
    all_labels = [image_path.replace("images", "labels") for image_path in all_images]

    target_image_dir = os.path.join(tmp_dir, "images")
    target_label_dir = os.path.join(tmp_dir, "labels")

    os.makedirs(target_image_dir, exist_ok=True)
    os.makedirs(target_label_dir, exist_ok=True)

    for image_path, label_path in track(
        zip(all_images, all_labels), total=len(all_images), description="fuse dataset"
    ):
        base_name = os.path.basename(image_path)
        target_image_path = os.path.join(target_image_dir, base_name)
        target_label_path = os.path.join(target_label_dir, base_name)

        mask = np.array(Image.open(label_path))
        assert set(np.unique(mask)) <= {0, 1}, (
            f"The mask image should be binary (only 0s and 1s), "
            f"actually contains {set(np.unique(mask))}"
        )

        classes_str = image_path.split(os.path.sep)[-3]
        classes = classes_mapping[classes_str]

        # check the target file directly; the original membership test
        # compared a source path against target-dir paths, so the fuse
        # branch could never run
        if not os.path.exists(target_image_path):
            # first time this tile is seen: remap foreground to its class id
            mask = np.where(mask == 1, classes, mask)
            Image.fromarray(mask).save(target_label_path)
            shutil.copy(image_path, target_image_path)
        else:
            # the tile already exists: fuse this class into the stored mask
            exists_mask = np.array(Image.open(target_label_path))
            exists_mask = np.where(mask == 1, classes, exists_mask)
            Image.fromarray(exists_mask).save(target_label_path)

    exists_images = glob(os.path.join(target_image_dir, "*.png"))
    exists_labels = [
        image_path.replace("images", "labels") for image_path in exists_images
    ]
    X_train, X_test, y_train, y_test = train_test_split(
        exists_images, exists_labels, test_size=0.2, random_state=42, shuffle=True
    )

    create_dataset(
        X_train,
        y_train,
        os.path.join(save_path, "img_dir", "train"),
        os.path.join(save_path, "ann_dir", "train"),
        description="train dataset",
    )
    create_dataset(
        X_test,
        y_test,
        os.path.join(save_path, "img_dir", "val"),
        os.path.join(save_path, "ann_dir", "val"),
        description="val dataset",
    )

    os.rmdir(target_image_dir)
    os.rmdir(target_label_dir)


if __name__ == "__main__":
    # example:
    # python tools/dataset_tools/process_water.py --raw_path data/raw_water_dataset/1024 --tmp_dir data/raw_water_dataset/1024/all_dataset --save_path data/water_1024_1024
    main()
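A toy illustration of the fuse step above (plain NumPy, no repo dependencies): foreground pixels of each per-class binary mask are rewritten to that folder's class id and merged into the accumulated mask without touching other classes.

import numpy as np

accumulated = np.array([[0, 2], [0, 0]], dtype=np.uint8)  # mask fused so far
binary = np.array([[1, 0], [1, 0]], dtype=np.uint8)       # e.g. from 'CDUWD-3'
class_id = 3
accumulated = np.where(binary == 1, class_id, accumulated)
print(accumulated)  # [[3 2]
                    #  [3 0]]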
tools/deployment/pytorch2torchscript.py
ADDED
@@ -0,0 +1,185 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import numpy as np
import torch
import torch._C
import torch.serialization
from mmengine import Config
from mmengine.runner import load_checkpoint
from torch import nn

from mmseg.models import build_segmentor

torch.manual_seed(3)


def digit_version(version_str):
    digit_version = []
    for x in version_str.split('.'):
        if x.isdigit():
            digit_version.append(int(x))
        elif x.find('rc') != -1:
            patch_version = x.split('rc')
            digit_version.append(int(patch_version[0]) - 1)
            digit_version.append(int(patch_version[1]))
    return digit_version


def check_torch_version():
    torch_minimum_version = '1.8.0'
    torch_version = digit_version(torch.__version__)

    assert (torch_version >= digit_version(torch_minimum_version)), \
        f'Torch=={torch.__version__} is not supported for converting to ' \
        f'torchscript. Please install pytorch>={torch_minimum_version}.'


def _convert_batchnorm(module):
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple):
            input batch dimensions
        num_classes (int):
            number of semantic classes
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    segs = rng.randint(
        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
    } for _ in range(N)]
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs


def pytorch2libtorch(model,
                     input_shape,
                     show=False,
                     output_file='tmp.pt',
                     verify=False):
    """Export a PyTorch model to a TorchScript model and verify that the
    outputs are the same between PyTorch and TorchScript.

    Args:
        model (nn.Module): PyTorch model we want to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (string): The path where we store the
            output TorchScript model. Default: `tmp.pt`.
        verify (bool): Whether to compare the outputs between
            PyTorch and TorchScript. Default: False.
    """
    if isinstance(model.decode_head, nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')

    # replace the original forward with forward_dummy
    model.forward = model.forward_dummy
    model.eval()
    traced_model = torch.jit.trace(
        model,
        example_inputs=imgs,
        check_trace=verify,
    )

    if show:
        print(traced_model.graph)

    traced_model.save(output_file)
    print(f'Successfully exported TorchScript model: {output_file}')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMSeg to TorchScript')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument(
        '--show', action='store_true', help='show TorchScript graph')
    parser.add_argument(
        '--verify', action='store_true', help='verify the TorchScript model')
    parser.add_argument('--output-file', type=str, default='tmp.pt')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[512, 512],
        help='input image size (height, width)')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    check_torch_version()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (
            1,
            3,
        ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    cfg.model.pretrained = None

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    segmentor = build_segmentor(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    # convert SyncBN to BN
    segmentor = _convert_batchnorm(segmentor)

    if args.checkpoint:
        load_checkpoint(segmentor, args.checkpoint, map_location='cpu')

    # convert the PyTorch model to a LibTorch model
    pytorch2libtorch(
        segmentor,
        input_shape,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify)
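A worked example of how digit_version above orders release candidates (assumes digit_version from this file is in scope): the 'rc' segment is split so that the pre-release sorts below the final release.

print(digit_version('1.8.0'))     # [1, 8, 0]
print(digit_version('1.8.0rc1'))  # [1, 8, -1, 1]
assert digit_version('1.8.0rc1') < digit_version('1.8.0')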
tools/model_converters/beit2mmseg.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_beit(ckpt):
    new_ckpt = OrderedDict()

    for k, v in ckpt.items():
        if k.startswith('patch_embed'):
            new_key = k.replace('patch_embed.proj', 'patch_embed.projection')
            new_ckpt[new_key] = v
        # elif (not a second if) so that patch_embed keys are not also
        # copied under their original names by the else branch
        elif k.startswith('blocks'):
            new_key = k.replace('blocks', 'layers')
            if 'norm' in new_key:
                new_key = new_key.replace('norm', 'ln')
            elif 'mlp.fc1' in new_key:
                new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in new_key:
                new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
            new_ckpt[new_key] = v
        else:
            new_key = k
            new_ckpt[new_key] = v

    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained beit models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_beit(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
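For reference, a toy check of the key remapping (assumes convert_beit from this file is in scope):

import torch

toy = {
    'patch_embed.proj.weight': torch.zeros(1),
    'blocks.0.norm1.weight': torch.zeros(1),
    'blocks.0.mlp.fc1.bias': torch.zeros(1),
}
print(list(convert_beit(toy).keys()))
# ['patch_embed.projection.weight', 'layers.0.ln1.weight',
#  'layers.0.ffn.layers.0.0.bias']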