XavierJiezou committed (verified)
Commit 31c0288 · Parent(s): 918db92

Add files using upload-large-folder tool

ktda/models/segmentors/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .distill_encoder_decoder import DistillEncoderDecoder
+
+ __all__ = ['DistillEncoderDecoder']
tools/analysis_tools/browse_dataset.py ADDED
@@ -0,0 +1,77 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import argparse
+ import os.path as osp
+
+ from mmengine.config import Config, DictAction
+ from mmengine.utils import ProgressBar
+
+ from mmseg.registry import DATASETS, VISUALIZERS
+ from mmseg.utils import register_all_modules
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Browse a dataset')
+     parser.add_argument('config', help='train config file path')
+     parser.add_argument(
+         '--output-dir',
+         default=None,
+         type=str,
+         help='If there is no display interface, you can save it')
+     parser.add_argument('--not-show', default=False, action='store_true')
+     parser.add_argument(
+         '--show-interval',
+         type=float,
+         default=2,
+         help='the display interval in seconds')
+     parser.add_argument(
+         '--cfg-options',
+         nargs='+',
+         action=DictAction,
+         help='override some settings in the used config, the key-value pair '
+         'in xxx=yyy format will be merged into the config file. If the value '
+         'to be overwritten is a list, it should be like key="[a,b]" or '
+         'key=a,b. It also allows nested list/tuple values, e.g. '
+         'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
+         'and that no white space is allowed.')
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = parse_args()
+     cfg = Config.fromfile(args.config)
+     if args.cfg_options is not None:
+         cfg.merge_from_dict(args.cfg_options)
+
+     # register all modules in mmseg into the registries
+     register_all_modules()
+
+     dataset = DATASETS.build(cfg.train_dataloader.dataset)
+     visualizer = VISUALIZERS.build(cfg.visualizer)
+     visualizer.dataset_meta = dataset.metainfo
+
+     progress_bar = ProgressBar(len(dataset))
+     for item in dataset:
+         img = item['inputs'].permute(1, 2, 0).numpy()
+         img = img[..., [2, 1, 0]]  # BGR to RGB
+         data_sample = item['data_samples'].numpy()
+         img_path = osp.basename(item['data_samples'].img_path)
+
+         out_file = osp.join(
+             args.output_dir,
+             osp.basename(img_path)) if args.output_dir is not None else None
+
+         visualizer.add_datasample(
+             name=osp.basename(img_path),
+             image=img,
+             data_sample=data_sample,
+             draw_gt=True,
+             draw_pred=False,
+             wait_time=args.show_interval,
+             out_file=out_file,
+             show=not args.not_show)
+         progress_bar.update()
+
+
+ if __name__ == '__main__':
+     main()
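
For reference, the per-item conversion in the loop above goes from the CHW tensor the dataset yields to the HWC array the visualizer expects, then reorders the channels. A minimal sketch with a dummy tensor (shapes chosen only for illustration):

import torch

# fake 3-channel 2x2 image in CHW layout, standing in for item['inputs']
chw = torch.arange(12).reshape(3, 2, 2)
hwc = chw.permute(1, 2, 0).numpy()   # CHW -> HWC
rgb = hwc[..., [2, 1, 0]]            # BGR -> RGB by reindexing the last axis
print(rgb.shape)                     # (2, 2, 3)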
tools/dataset_converters/drive.py ADDED
@@ -0,0 +1,114 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import argparse
+ import os
+ import os.path as osp
+ import tempfile
+ import zipfile
+
+ import cv2
+ import mmcv
+ from mmengine.utils import mkdir_or_exist
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description='Convert DRIVE dataset to mmsegmentation format')
+     parser.add_argument(
+         'training_path', help='the training part of DRIVE dataset')
+     parser.add_argument(
+         'testing_path', help='the testing part of DRIVE dataset')
+     parser.add_argument('--tmp_dir', help='path of the temporary directory')
+     parser.add_argument('-o', '--out_dir', help='output path')
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = parse_args()
+     training_path = args.training_path
+     testing_path = args.testing_path
+     if args.out_dir is None:
+         out_dir = osp.join('data', 'DRIVE')
+     else:
+         out_dir = args.out_dir
+
+     print('Making directories...')
+     mkdir_or_exist(out_dir)
+     mkdir_or_exist(osp.join(out_dir, 'images'))
+     mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
+     mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
+     mkdir_or_exist(osp.join(out_dir, 'annotations'))
+     mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
+     mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
+
+     with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
+         print('Extracting training.zip...')
+         zip_file = zipfile.ZipFile(training_path)
+         zip_file.extractall(tmp_dir)
+
+         print('Generating training dataset...')
+         now_dir = osp.join(tmp_dir, 'training', 'images')
+         for img_name in os.listdir(now_dir):
+             img = mmcv.imread(osp.join(now_dir, img_name))
+             mmcv.imwrite(
+                 img,
+                 osp.join(
+                     out_dir, 'images', 'training',
+                     osp.splitext(img_name)[0].replace('_training', '') +
+                     '.png'))
+
+         now_dir = osp.join(tmp_dir, 'training', '1st_manual')
+         for img_name in os.listdir(now_dir):
+             cap = cv2.VideoCapture(osp.join(now_dir, img_name))
+             ret, img = cap.read()
+             mmcv.imwrite(
+                 img[:, :, 0] // 128,
+                 osp.join(out_dir, 'annotations', 'training',
+                          osp.splitext(img_name)[0] + '.png'))
+
+         print('Extracting test.zip...')
+         zip_file = zipfile.ZipFile(testing_path)
+         zip_file.extractall(tmp_dir)
+
+         print('Generating validation dataset...')
+         now_dir = osp.join(tmp_dir, 'test', 'images')
+         for img_name in os.listdir(now_dir):
+             img = mmcv.imread(osp.join(now_dir, img_name))
+             mmcv.imwrite(
+                 img,
+                 osp.join(
+                     out_dir, 'images', 'validation',
+                     osp.splitext(img_name)[0].replace('_test', '') + '.png'))
+
+         now_dir = osp.join(tmp_dir, 'test', '1st_manual')
+         if osp.exists(now_dir):
+             for img_name in os.listdir(now_dir):
+                 cap = cv2.VideoCapture(osp.join(now_dir, img_name))
+                 ret, img = cap.read()
+                 # The annotation img should be divided by 128, because some
+                 # of the annotation imgs are not standard. We should set a
+                 # threshold to convert the nonstandard annotation imgs. The
+                 # value divided by 128 is equivalent to '1 if value >= 128
+                 # else 0'.
+                 mmcv.imwrite(
+                     img[:, :, 0] // 128,
+                     osp.join(out_dir, 'annotations', 'validation',
+                              osp.splitext(img_name)[0] + '.png'))
+
+         now_dir = osp.join(tmp_dir, 'test', '2nd_manual')
+         if osp.exists(now_dir):
+             for img_name in os.listdir(now_dir):
+                 cap = cv2.VideoCapture(osp.join(now_dir, img_name))
+                 ret, img = cap.read()
+                 mmcv.imwrite(
+                     img[:, :, 0] // 128,
+                     osp.join(out_dir, 'annotations', 'validation',
+                              osp.splitext(img_name)[0] + '.png'))
+
+         print('Removing the temporary files...')
+
+     print('Done!')
+
+
+ if __name__ == '__main__':
+     main()
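
The `// 128` trick in the converter thresholds grayscale annotation values at 128, as the inline comment explains. A quick numpy check (values chosen only for illustration):

import numpy as np

gray = np.array([0, 1, 127, 128, 200, 255], dtype=np.uint8)
print(gray // 128)  # -> [0 0 0 1 1 1], i.e. '1 if value >= 128 else 0'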
tools/dataset_converters/vaihingen.py ADDED
@@ -0,0 +1,156 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import argparse
+ import glob
+ import math
+ import os
+ import os.path as osp
+ import tempfile
+ import zipfile
+
+ import mmcv
+ import numpy as np
+ from mmengine.utils import ProgressBar, mkdir_or_exist
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description='Convert vaihingen dataset to mmsegmentation format')
+     parser.add_argument('dataset_path', help='vaihingen folder path')
+     parser.add_argument('--tmp_dir', help='path of the temporary directory')
+     parser.add_argument('-o', '--out_dir', help='output path')
+     parser.add_argument(
+         '--clip_size',
+         type=int,
+         help='clipped size of image after preparation',
+         default=512)
+     parser.add_argument(
+         '--stride_size',
+         type=int,
+         help='stride of clipping original images',
+         default=256)
+     args = parser.parse_args()
+     return args
+
+
+ def clip_big_image(image_path, clip_save_dir, to_label=False):
+     # The original images of the Vaihingen dataset are very large, so they
+     # are pre-processed here. Given a fixed clip size and stride size, the
+     # grid of clipping windows over the width and height is determined.
+     # For example, for one 5120 x 5120 original image with clip size 512
+     # and stride size 256, this generates roughly 20 x 20 = 400 clipped
+     # images of size 512 x 512 each.
+     image = mmcv.imread(image_path)
+
+     h, w, c = image.shape
+     cs = args.clip_size
+     ss = args.stride_size
+
+     num_rows = math.ceil((h - cs) / ss) if math.ceil(
+         (h - cs) / ss) * ss + cs >= h else math.ceil((h - cs) / ss) + 1
+     num_cols = math.ceil((w - cs) / ss) if math.ceil(
+         (w - cs) / ss) * ss + cs >= w else math.ceil((w - cs) / ss) + 1
+
+     x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
+     xmin = x * cs
+     ymin = y * cs
+
+     xmin = xmin.ravel()
+     ymin = ymin.ravel()
+     xmin_offset = np.where(xmin + cs > w, w - xmin - cs, np.zeros_like(xmin))
+     ymin_offset = np.where(ymin + cs > h, h - ymin - cs, np.zeros_like(ymin))
+     boxes = np.stack([
+         xmin + xmin_offset, ymin + ymin_offset,
+         np.minimum(xmin + cs, w),
+         np.minimum(ymin + cs, h)
+     ],
+                      axis=1)
+
+     if to_label:
+         color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
+                               [255, 255, 0], [0, 255, 0], [0, 255, 255],
+                               [0, 0, 255]])
+         flatten_v = np.matmul(
+             image.reshape(-1, c),
+             np.array([2, 3, 4]).reshape(3, 1))
+         out = np.zeros_like(flatten_v)
+         for idx, class_color in enumerate(color_map):
+             value_idx = np.matmul(class_color,
+                                   np.array([2, 3, 4]).reshape(3, 1))
+             out[flatten_v == value_idx] = idx
+         image = out.reshape(h, w)
+
+     for box in boxes:
+         start_x, start_y, end_x, end_y = box
+         clipped_image = image[start_y:end_y,
+                               start_x:end_x] if to_label else image[
+                                   start_y:end_y, start_x:end_x, :]
+         area_idx = osp.basename(image_path).split('_')[3].strip('.tif')
+         mmcv.imwrite(
+             clipped_image.astype(np.uint8),
+             osp.join(clip_save_dir,
+                      f'{area_idx}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
+
+
+ def main():
+     splits = {
+         'train': [
+             'area1', 'area11', 'area13', 'area15', 'area17', 'area21',
+             'area23', 'area26', 'area28', 'area3', 'area30', 'area32',
+             'area34', 'area37', 'area5', 'area7'
+         ],
+         'val': [
+             'area6', 'area24', 'area35', 'area16', 'area14', 'area22',
+             'area10', 'area4', 'area2', 'area20', 'area8', 'area31',
+             'area33', 'area27', 'area38', 'area12', 'area29'
+         ],
+     }
+
+     dataset_path = args.dataset_path
+     if args.out_dir is None:
+         out_dir = osp.join('data', 'vaihingen')
+     else:
+         out_dir = args.out_dir
+
+     print('Making directories...')
+     mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
+     mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
+     mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
+     mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
+
+     zipp_list = glob.glob(os.path.join(dataset_path, '*.zip'))
+     print('Found the data', zipp_list)
+
+     with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
+         for zipp in zipp_list:
+             zip_file = zipfile.ZipFile(zipp)
+             zip_file.extractall(tmp_dir)
+             src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
+             if 'ISPRS_semantic_labeling_Vaihingen' in zipp:
+                 src_path_list = glob.glob(
+                     os.path.join(os.path.join(tmp_dir, 'top'), '*.tif'))
+             if 'ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE' in zipp:  # noqa
+                 src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
+                 # delete unused area9 ground truth (filter instead of
+                 # removing from the list while iterating over it)
+                 src_path_list = [
+                     p for p in src_path_list if 'area9' not in p
+                 ]
+             prog_bar = ProgressBar(len(src_path_list))
+             for i, src_path in enumerate(src_path_list):
+                 area_idx = osp.basename(src_path).split('_')[3].strip('.tif')
+                 data_type = 'train' if area_idx in splits['train'] else 'val'
+                 if 'noBoundary' in src_path:
+                     dst_dir = osp.join(out_dir, 'ann_dir', data_type)
+                     clip_big_image(src_path, dst_dir, to_label=True)
+                 else:
+                     dst_dir = osp.join(out_dir, 'img_dir', data_type)
+                     clip_big_image(src_path, dst_dir, to_label=False)
+                 prog_bar.update()
+
+     print('Removing the temporary files...')
+
+     print('Done!')
+
+
+ if __name__ == '__main__':
+     args = parse_args()
+     main()
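
The offset computation in clip_big_image shifts any window that would cross the image border back so it ends flush with the border (edge windows therefore overlap, and overflowing grid positions collapse onto the same clip). A small sketch of that clamping step with made-up sizes:

import numpy as np

w, cs = 1000, 512                                    # image width, clip size
xmin = np.array([0, 512, 1024])                      # candidate left edges
offset = np.where(xmin + cs > w, w - xmin - cs, 0)   # pull back if overflowing
print(xmin + offset)  # -> [  0 488 488]; each window [x, x + 512) fits in [0, 1000)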
tools/dataset_tools/analysis_dataset.py ADDED
@@ -0,0 +1,58 @@
+ from glob import glob
+ from typing import List, Tuple
+ import os
+ import argparse
+ from matplotlib import pyplot as plt
+ import numpy as np
+ from PIL import Image
+
+
+ def get_args() -> Tuple[str, str]:
+     """
+     Returns:
+         dataset_dir: dataset directory.
+         save_dir: path of the output figure.
+     """
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--dataset_dir', type=str, default='data/grass')
+     parser.add_argument(
+         '--save_dir', type=str, default='dataset_num_analysis.png')
+     args = parser.parse_args()
+     return args.dataset_dir, args.save_dir
+
+
+ def get_mask_files(dataset_dir: str) -> List[str]:
+     """
+     Get mask files from the dataset directory.
+     Args:
+         dataset_dir: dataset directory.
+     Returns:
+         mask_filenames: list of mask filenames.
+     """
+     mask_filenames = glob(os.path.join(dataset_dir, "ann_dir", "*", "*.png"))
+     return mask_filenames
+
+
+ def main():
+     dataset_dir, save_dir = get_args()
+     mask_filenames = get_mask_files(dataset_dir)
+     statistic = {}
+     for mask_filename in mask_filenames:
+         mask = np.array(Image.open(mask_filename))
+         classes = np.unique(mask)
+         for class_ in classes:
+             class_ = int(class_)
+             if class_ not in statistic:
+                 statistic[class_] = 0
+             statistic[class_] += int(np.sum(mask == class_))
+
+     classes = list(statistic.keys())
+     classes_num = list(statistic.values())
+
+     plt.title("Dataset Analysis")
+     bars = plt.bar(classes, classes_num)
+     for bar in bars:
+         height = bar.get_height()
+         plt.text(
+             bar.get_x() + bar.get_width() / 2,
+             height + 5,
+             str(height),
+             ha='center',
+             va='bottom')
+     plt.savefig(save_dir, dpi=300)
+
+
+ if __name__ == "__main__":
+     main()
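
The counting loop above accumulates per-class pixel totals across all masks; on a single toy mask the same logic reduces to the following (synthetic values):

import numpy as np

mask = np.array([[0, 0, 1],
                 [2, 1, 1]])
statistic = {int(c): int(np.sum(mask == c)) for c in np.unique(mask)}
print(statistic)  # -> {0: 2, 1: 3, 2: 1}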
tools/dataset_tools/process_water.py ADDED
@@ -0,0 +1,152 @@
+ import shutil
+ from glob import glob
+ import os
+ import argparse
+ import numpy as np
+ from rich.progress import track
+ from PIL import Image
+ from typing import List
+ from vegseg.datasets import WaterDataset
+ from sklearn.model_selection import train_test_split
+
+
+ def get_args():
+     parse = argparse.ArgumentParser()
+     parse.add_argument("--raw_path", type=str)
+     parse.add_argument("--tmp_dir", type=str)
+     parse.add_argument("--save_path", type=str)
+     args = parse.parse_args()
+     return args.raw_path, args.tmp_dir, args.save_path
+
+
+ def get_palette() -> List[int]:
+     """
+     Get the palette of the dataset.
+     Returns:
+         palette: flat list of RGB values.
+     """
+     palette = []
+     palette_list = WaterDataset.METAINFO["palette"]
+     for palette_item in palette_list:
+         palette.extend(palette_item)
+     return palette
+
+
+ def create_dataset(image_list, ann_list, image_dir, ann_dir,
+                    description="Working..."):
+     os.makedirs(image_dir, exist_ok=True)
+     os.makedirs(ann_dir, exist_ok=True)
+     for image_path, ann_path in track(
+             zip(image_list, ann_list), total=len(image_list),
+             description=description):
+         base_name = os.path.basename(image_path)
+
+         new_image_path = os.path.join(image_dir, base_name)
+         new_ann_path = os.path.join(ann_dir, base_name)
+
+         shutil.move(image_path, new_image_path)
+         shutil.move(ann_path, new_ann_path)
+
+         mask = Image.open(new_ann_path).convert("P")
+         palette = get_palette()
+         mask.putpalette(palette)
+         mask.save(new_ann_path)
+
+
+ def main():
+     classes_mapping = {
+         "CDUWD-1": 1,
+         "CDUWD-2": 2,
+         "CDUWD-3": 3,
+         "CDUWD-4": 4,
+         "CDUWD-5": 5,
+         "CDUWD-6": 0,
+     }
+
+     raw_path, tmp_dir, save_path = get_args()
+
+     all_images = glob(os.path.join(raw_path, "*", "images", "*.png"))
+     all_labels = [
+         image_path.replace("images", "labels") for image_path in all_images
+     ]
+
+     target_image_dir = os.path.join(tmp_dir, "images")
+     target_label_dir = os.path.join(tmp_dir, "labels")
+
+     os.makedirs(target_image_dir, exist_ok=True)
+     os.makedirs(target_label_dir, exist_ok=True)
+
+     for image_path, label_path in track(
+             zip(all_images, all_labels), total=len(all_images),
+             description="fuse dataset"):
+         exists_images = glob(os.path.join(target_image_dir, "*.png"))
+
+         base_name = os.path.basename(image_path)
+         # compare against the target path, since exists_images holds paths
+         # under target_image_dir, not the raw dataset paths
+         target_image_path = os.path.join(target_image_dir, base_name)
+         if target_image_path not in exists_images:
+             mask = np.array(Image.open(label_path))
+
+             assert list(np.unique(mask)) in [[0], [1], [0, 1]], (
+                 "The mask image is not binary (it should only contain 0s "
+                 f"and 1s), actually is {set(np.unique(mask))}")
+
+             classes_str = image_path.split(os.path.sep)[-3]
+             classes = classes_mapping[classes_str]
+             mask = np.where(mask == 1, classes, mask)
+
+             mask = Image.fromarray(mask)
+             mask.save(os.path.join(target_label_dir, base_name))
+             shutil.copy(image_path, target_image_path)
+         else:
+             exists_label_path = os.path.join(target_label_dir, base_name)
+             exists_mask = np.array(Image.open(exists_label_path))
+
+             mask = np.array(Image.open(label_path))
+             assert list(np.unique(mask)) in [[0], [1], [0, 1]], (
+                 "The mask image is not binary (it should only contain 0s "
+                 f"and 1s), actually is {set(np.unique(mask))}")
+             classes_str = image_path.split(os.path.sep)[-3]
+             classes = classes_mapping[classes_str]
+
+             exists_mask = np.where(mask == 1, classes, exists_mask)
+
+             exists_mask = Image.fromarray(exists_mask)
+             exists_mask.save(exists_label_path)
+
+     exists_images = glob(os.path.join(target_image_dir, "*.png"))
+     exists_labels = [
+         image_path.replace("images", "labels") for image_path in exists_images
+     ]
+     X_train, X_test, y_train, y_test = train_test_split(
+         exists_images, exists_labels, test_size=0.2, random_state=42,
+         shuffle=True)
+
+     create_dataset(
+         X_train,
+         y_train,
+         os.path.join(save_path, "img_dir", "train"),
+         os.path.join(save_path, "ann_dir", "train"),
+         description="train dataset",
+     )
+     create_dataset(
+         X_test,
+         y_test,
+         os.path.join(save_path, "img_dir", "val"),
+         os.path.join(save_path, "ann_dir", "val"),
+         description="val dataset",
+     )
+
+     os.rmdir(target_image_dir)
+     os.rmdir(target_label_dir)
+
+
+ if __name__ == "__main__":
+     # example:
+     # python tools/dataset_tools/process_water.py \
+     #     --raw_path data/raw_water_dataset/1024 \
+     #     --tmp_dir data/raw_water_dataset/1024/all_dataset \
+     #     --save_path data/water_1024_1024
+     main()
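
The relabelling step maps every foreground pixel of a binary mask to the class id of its CDUWD subset while leaving background untouched; a toy example (class id 3 assumed):

import numpy as np

binary_mask = np.array([[0, 1],
                        [1, 0]], dtype=np.uint8)
relabelled = np.where(binary_mask == 1, 3, binary_mask)  # e.g. "CDUWD-3" -> 3
print(relabelled)  # -> [[0 3]
                   #     [3 0]]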
tools/deployment/pytorch2torchscript.py ADDED
@@ -0,0 +1,185 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import argparse
+
+ import numpy as np
+ import torch
+ import torch._C
+ import torch.serialization
+ from mmengine import Config
+ from mmengine.runner import load_checkpoint
+ from torch import nn
+
+ from mmseg.models import build_segmentor
+
+ torch.manual_seed(3)
+
+
+ def digit_version(version_str):
+     digit_version = []
+     for x in version_str.split('.'):
+         if x.isdigit():
+             digit_version.append(int(x))
+         elif x.find('rc') != -1:
+             patch_version = x.split('rc')
+             digit_version.append(int(patch_version[0]) - 1)
+             digit_version.append(int(patch_version[1]))
+     return digit_version
+
+
+ def check_torch_version():
+     torch_minimum_version = '1.8.0'
+     torch_version = digit_version(torch.__version__)
+
+     assert (torch_version >= digit_version(torch_minimum_version)), \
+         f'Torch=={torch.__version__} is not supported for converting to ' \
+         f'torchscript. Please install pytorch>={torch_minimum_version}.'
+
+
+ def _convert_batchnorm(module):
+     module_output = module
+     if isinstance(module, torch.nn.SyncBatchNorm):
+         module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
+                                              module.momentum, module.affine,
+                                              module.track_running_stats)
+         if module.affine:
+             module_output.weight.data = module.weight.data.clone().detach()
+             module_output.bias.data = module.bias.data.clone().detach()
+             # keep requires_grad unchanged
+             module_output.weight.requires_grad = module.weight.requires_grad
+             module_output.bias.requires_grad = module.bias.requires_grad
+         module_output.running_mean = module.running_mean
+         module_output.running_var = module.running_var
+         module_output.num_batches_tracked = module.num_batches_tracked
+     for name, child in module.named_children():
+         module_output.add_module(name, _convert_batchnorm(child))
+     del module
+     return module_output
+
+
+ def _demo_mm_inputs(input_shape, num_classes):
+     """Create a superset of inputs needed to run test or train batches.
+
+     Args:
+         input_shape (tuple):
+             input batch dimensions
+         num_classes (int):
+             number of semantic classes
+     """
+     (N, C, H, W) = input_shape
+     rng = np.random.RandomState(0)
+     imgs = rng.rand(*input_shape)
+     segs = rng.randint(
+         low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
+     img_metas = [{
+         'img_shape': (H, W, C),
+         'ori_shape': (H, W, C),
+         'pad_shape': (H, W, C),
+         'filename': '<demo>.png',
+         'scale_factor': 1.0,
+         'flip': False,
+     } for _ in range(N)]
+     mm_inputs = {
+         'imgs': torch.FloatTensor(imgs).requires_grad_(True),
+         'img_metas': img_metas,
+         'gt_semantic_seg': torch.LongTensor(segs)
+     }
+     return mm_inputs
+
+
+ def pytorch2libtorch(model,
+                      input_shape,
+                      show=False,
+                      output_file='tmp.pt',
+                      verify=False):
+     """Export a PyTorch model to a TorchScript model and verify that the
+     outputs are the same between PyTorch and TorchScript.
+
+     Args:
+         model (nn.Module): PyTorch model we want to export.
+         input_shape (tuple): Use this input shape to construct
+             the corresponding dummy input and execute the model.
+         show (bool): Whether to print the computation graph. Default: False.
+         output_file (string): The path where we store the
+             output TorchScript model. Default: `tmp.pt`.
+         verify (bool): Whether to compare the outputs between
+             PyTorch and TorchScript. Default: False.
+     """
+     if isinstance(model.decode_head, nn.ModuleList):
+         num_classes = model.decode_head[-1].num_classes
+     else:
+         num_classes = model.decode_head.num_classes
+
+     mm_inputs = _demo_mm_inputs(input_shape, num_classes)
+
+     imgs = mm_inputs.pop('imgs')
+
+     # replace the original forward with forward_dummy
+     model.forward = model.forward_dummy
+     model.eval()
+     traced_model = torch.jit.trace(
+         model,
+         example_inputs=imgs,
+         check_trace=verify,
+     )
+
+     if show:
+         print(traced_model.graph)
+
+     traced_model.save(output_file)
+     print(f'Successfully exported TorchScript model: {output_file}')
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description='Convert MMSeg to TorchScript')
+     parser.add_argument('config', help='test config file path')
+     parser.add_argument('--checkpoint', help='checkpoint file', default=None)
+     parser.add_argument(
+         '--show', action='store_true', help='show TorchScript graph')
+     parser.add_argument(
+         '--verify', action='store_true', help='verify the TorchScript model')
+     parser.add_argument('--output-file', type=str, default='tmp.pt')
+     parser.add_argument(
+         '--shape',
+         type=int,
+         nargs='+',
+         default=[512, 512],
+         help='input image size (height, width)')
+     args = parser.parse_args()
+     return args
+
+
+ if __name__ == '__main__':
+     args = parse_args()
+     check_torch_version()
+
+     if len(args.shape) == 1:
+         input_shape = (1, 3, args.shape[0], args.shape[0])
+     elif len(args.shape) == 2:
+         input_shape = (1, 3) + tuple(args.shape)
+     else:
+         raise ValueError('invalid input shape')
+
+     cfg = Config.fromfile(args.config)
+     cfg.model.pretrained = None
+
+     # build the model and load checkpoint
+     cfg.model.train_cfg = None
+     segmentor = build_segmentor(
+         cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
+     # convert SyncBN to BN
+     segmentor = _convert_batchnorm(segmentor)
+
+     if args.checkpoint:
+         load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
+
+     # convert the PyTorch model to a LibTorch model
+     pytorch2libtorch(
+         segmentor,
+         input_shape,
+         show=args.show,
+         output_file=args.output_file,
+         verify=args.verify)
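
Once exported, the TorchScript file can be loaded without any mmseg dependency; a minimal sketch, assuming the default `tmp.pt` output and the default 512x512 input shape:

import torch

ts_model = torch.jit.load('tmp.pt', map_location='cpu')
dummy = torch.randn(1, 3, 512, 512)  # matches the default --shape
with torch.no_grad():
    out = ts_model(dummy)
print(out.shape)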
tools/model_converters/beit2mmseg.py ADDED
@@ -0,0 +1,56 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import argparse
+ import os.path as osp
+ from collections import OrderedDict
+
+ import mmengine
+ import torch
+ from mmengine.runner import CheckpointLoader
+
+
+ def convert_beit(ckpt):
+     new_ckpt = OrderedDict()
+
+     for k, v in ckpt.items():
+         if k.startswith('patch_embed'):
+             new_key = k.replace('patch_embed.proj', 'patch_embed.projection')
+             new_ckpt[new_key] = v
+         # elif (rather than a second if) keeps patch_embed keys from also
+         # being copied under their original names by the final else branch
+         elif k.startswith('blocks'):
+             new_key = k.replace('blocks', 'layers')
+             if 'norm' in new_key:
+                 new_key = new_key.replace('norm', 'ln')
+             elif 'mlp.fc1' in new_key:
+                 new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
+             elif 'mlp.fc2' in new_key:
+                 new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
+             new_ckpt[new_key] = v
+         else:
+             new_key = k
+             new_ckpt[new_key] = v
+
+     return new_ckpt
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description='Convert keys in official pretrained beit models to '
+         'MMSegmentation style.')
+     parser.add_argument('src', help='src model path or url')
+     # The dst path must be a full path of the new checkpoint.
+     parser.add_argument('dst', help='save path')
+     args = parser.parse_args()
+
+     checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+     if 'state_dict' in checkpoint:
+         state_dict = checkpoint['state_dict']
+     elif 'model' in checkpoint:
+         state_dict = checkpoint['model']
+     else:
+         state_dict = checkpoint
+     weight = convert_beit(state_dict)
+     mmengine.mkdir_or_exist(osp.dirname(args.dst))
+     torch.save(weight, args.dst)
+
+
+ if __name__ == '__main__':
+     main()
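
A quick sanity check of the key mapping on a toy state dict, with convert_beit in scope (dummy tensor values, hypothetical keys):

from collections import OrderedDict

import torch

ckpt = OrderedDict({
    'patch_embed.proj.weight': torch.zeros(1),
    'blocks.0.mlp.fc1.weight': torch.zeros(1),
    'blocks.0.norm1.weight': torch.zeros(1),
})
print(list(convert_beit(ckpt)))
# -> ['patch_embed.projection.weight',
#     'layers.0.ffn.layers.0.0.weight',
#     'layers.0.ln1.weight']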