CVRPDataset commited on
Commit
1060621
·
verified ·
1 Parent(s): 1e04206

Upload 7 files

Browse files
run/json2png.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import os
4
+ import os.path as osp
5
+
6
+ import numpy as np
7
+ import PIL.Image
8
+ from labelme import utils
9
+
10
+
11
+ def json2mask(root, classes)
12
+ before_path = os.path.join(root, 'json') # labelme files
13
+ jpgs_path = os.path.join(root, 'jpgs') # images
14
+ pngs_path = os.path.join(root, 'masks') # annotations
15
+ if not os.path.exists(jpgs_path):
16
+ os.makedirs(jpgs_path)
17
+ if not os.path.exists(pngs_path):
18
+ os.makedirs(pngs_path)
19
+
20
+ path = before_path
21
+ for file in os.listdir(path):
22
+ if file.endswith('json'):
23
+ data = json.load(open(os.path.join(path, file)))
24
+
25
+ if data['imageData']:
26
+ imageData = data['imageData']
27
+ else:
28
+ imagePath = os.path.join(os.path.dirname(path), data['imagePath'])
29
+ with open(imagePath, 'rb') as f:
30
+ imageData = f.read()
31
+ imageData = base64.b64encode(imageData).decode('utf-8')
32
+
33
+ img = utils.img_b64_to_arr(imageData)
34
+ label_name_to_value = {classes[0]: 0, classes[2]:0, classes[1]: 1 }
35
+
36
+ lbl = utils.shapes_to_label(img.shape, data['shapes'], label_name_to_value)
37
+
38
+ PIL.Image.fromarray(img).save(osp.join(jpgs_path, file.split(".")[0] + '.jpg'))
39
+
40
+ utils.lblsave(osp.join(pngs_path, file.split(".")[0] + '.png'), lbl)
41
+
42
+ if __name__ == '__main__':
43
+ root = 'J:/dataset_panicle/2023/only_plant/images'
44
+ classes = ["_background_", "panicle", "other"]
45
+ json2mask(root,classes)
run/labelme2mask.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ import cv2
5
+ import shutil
6
+ from tqdm import tqdm
7
+
8
+
9
+ def labelme2mask_single_img(img_path, labelme_json_path, class_info):
10
+ '''
11
+ Convert a single image's LabelMe annotation to a mask.
12
+ '''
13
+ img_bgr = cv2.imread(img_path)
14
+ img_mask = np.zeros(img_bgr.shape[:2], dtype=np.uint8) # Create an empty mask image (0 - background)
15
+
16
+ with open(labelme_json_path, 'r', encoding='utf-8') as f:
17
+ labelme = json.load(f)
18
+
19
+ for one_class in class_info: # Iterate over each class in class_info
20
+ for each in labelme['shapes']: # Iterate over all shapes in the annotation
21
+ if each['label'] == one_class['label']:
22
+ if one_class['type'] == 'polygon': # Handle polygon annotation
23
+ points = [np.array(each['points'], dtype=np.int32).reshape((-1, 1, 2))] # Ensure correct shape
24
+ img_mask = cv2.fillPoly(img_mask, points, color=one_class['color'])
25
+ elif one_class['type'] == 'line' or one_class['type'] == 'linestrip': # Handle line annotation
26
+ points = [np.array(each['points'], dtype=np.int32).reshape((-1, 1, 2))]
27
+ img_mask = cv2.polylines(img_mask, points, isClosed=False, color=one_class['color'],
28
+ thickness=one_class.get('thickness', 1))
29
+ elif one_class['type'] == 'circle': # Handle circle annotation
30
+ points = np.array(each['points'], dtype=np.int32)
31
+ center_x, center_y = points[0][0], points[0][1]
32
+ edge_x, edge_y = points[1][0], points[1][1]
33
+ radius = int(np.linalg.norm([center_x - edge_x, center_y - edge_y]))
34
+ img_mask = cv2.circle(img_mask, (center_x, center_y), radius, one_class['color'], -1)
35
+ else:
36
+ print('Unknown annotation type:', one_class['type'])
37
+
38
+ return img_mask
39
+
40
+
41
+ def convert_labelme_to_mask(Dataset_Path):
42
+ '''
43
+ Convert all LabelMe annotations in the dataset to mask images.
44
+ '''
45
+ # Dataset directories
46
+ img_dir = os.path.join(Dataset_Path, 'images')
47
+ ann_dir = os.path.join(Dataset_Path, 'labelme_jsons')
48
+
49
+ # Class information for mask generation
50
+ class_info = [
51
+ {'label': 'panicle', 'type': 'polygon', 'color': 1}
52
+ ]
53
+
54
+ # Create target directories
55
+ images_target_dir = os.path.join(Dataset_Path, 'img_dir')
56
+ ann_target_dir = os.path.join(Dataset_Path, 'ann_dir')
57
+
58
+ # Create target directories if they do not exist
59
+ os.makedirs(images_target_dir, exist_ok=True)
60
+ os.makedirs(ann_target_dir, exist_ok=True)
61
+
62
+ # Process each image in the images directory
63
+ for img_name in tqdm(os.listdir(img_dir), desc="Converting images to masks"):
64
+ try:
65
+ img_path = os.path.join(img_dir, img_name)
66
+ labelme_json_path = os.path.join(ann_dir, f'{os.path.splitext(img_name)[0]}.json')
67
+
68
+ if os.path.exists(labelme_json_path):
69
+ # Convert LabelMe annotations to mask
70
+ img_mask = labelme2mask_single_img(img_path, labelme_json_path, class_info)
71
+
72
+ # Save the mask to the target directory
73
+ mask_path = os.path.join(ann_target_dir, f'{os.path.splitext(img_name)[0]}.png')
74
+ cv2.imwrite(mask_path, img_mask)
75
+
76
+ # Move the image to the target directory
77
+ shutil.move(img_path, os.path.join(images_target_dir, img_name))
78
+ else:
79
+ print(f"Annotation file missing for {img_name}")
80
+
81
+ except Exception as e:
82
+ print(f"Failed to convert {img_name}: {e}")
83
+
84
+ # Optionally remove the original directories if they are empty
85
+ shutil.rmtree(img_dir, ignore_errors=True)
86
+ shutil.rmtree(ann_dir, ignore_errors=True)
87
+
88
+ print("Conversion completed.")
89
+
90
+
91
+ if __name__ == '__main__':
92
+ Dataset_Path = 'CVRP' # Update this to the path of your dataset
93
+ convert_labelme_to_mask(Dataset_Path)
run/mask2json.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import json
4
+ import numpy as np
5
+ from pycococreatortools import pycococreatortools
6
+ from PIL import Image
7
+ import base64
8
+ import cv2
9
+ from tqdm import tqdm
10
+ def img_tobyte(img_pil):
11
+
12
+ ENCODING = 'utf-8'
13
+ img_byte = io.BytesIO()
14
+ img_pil.save(img_byte, format='PNG')
15
+ binary_str2 = img_byte.getvalue()
16
+ imageData = base64.b64encode(binary_str2)
17
+ base64_string = imageData.decode(ENCODING)
18
+ return base64_string
19
+
20
+ def mask2json(ROOT_DIR):
21
+ Image_DIR = os.path.join(ROOT_DIR, "pngs")
22
+ Label_DIR = os.path.join(ROOT_DIR, "pre")
23
+
24
+ Label_files = os.listdir(Label_DIR)
25
+ # classs_names
26
+ class_names = ['_background_', 'panicle']
27
+ for Label_filename in tqdm(Label_files):
28
+
29
+ Json_output = {
30
+ "version": "3.16.7",
31
+ "flags": {},
32
+ "fillColor": [255, 0, 0, 128],
33
+ "lineColor": [0, 255, 0, 128],
34
+ "imagePath": {},
35
+ "shapes": [],
36
+ "imageData": {}}
37
+ name = Label_filename.split('.', 3)[0]
38
+ name1 = name + '.png'
39
+ Json_output["imagePath"] = name1
40
+
41
+ image = Image.open(Image_DIR + '/' + name1)
42
+ imageData = img_tobyte(image)
43
+ Json_output["imageData"] = imageData
44
+ binary_mask = np.asarray(np.array(Image.open(Label_DIR + '/' + Label_filename))).astype(np.uint8)
45
+ mask_image = cv2.imread(Label_DIR + '/' + Label_filename, cv2.IMREAD_GRAYSCALE)
46
+ temp_mask = np.asarray((mask_image != 0), dtype=np.uint8)
47
+
48
+ segmentation = pycococreatortools.binary_mask_to_polygon(temp_mask, tolerance=2)
49
+ for item in segmentation:
50
+ if len(item) > 10:
51
+ list1 = []
52
+ for j in range(0, len(item), 2):
53
+ list1.append([item[j], item[j + 1]])
54
+ # There is only one non-background class, so just use class_names[1]
55
+ label = class_names[1]
56
+ seg_info = {'points': list1, "fill_color": None, "line_color": None, "label": label,
57
+ "shape_type": "polygon", "flags": {}}
58
+ Json_output["shapes"].append(seg_info)
59
+
60
+ Json_output["imageHeight"] = binary_mask.shape[0]
61
+ Json_output["imageWidth"] = binary_mask.shape[1]
62
+ json_path = os.path.join(ROOT_DIR,'json')
63
+ if not os.path.exists(json_path):
64
+ os.makedirs(json_path)
65
+ full_path = os.path.join(json_path, '{}.json'.format(name))
66
+ with open(full_path, 'w') as output_json_file:
67
+ json.dump(Json_output, output_json_file)
68
+
69
+
70
+ if __name__ == '__main__':
71
+
72
+ ROOT_DIR = ''
73
+ mask2json(ROOT_DIR)
74
+
run/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ opencv-python
2
+ pillow
3
+ matplotlib
4
+ seaborn
5
+ tqdm
6
+ ftfy
7
+ regex
8
+ pytorch-lightning
9
+ mmdet>=3.1.0
run/run_configs.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from mmengine import Config
4
+
5
+ def create_deeplabv3plus_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir, batch_size, max_iters, val_interval):
6
+ cfg = Config.fromfile(model_config_path)
7
+ dataset_cfg = Config.fromfile(dataset_config_path)
8
+ cfg.merge_from_dict(dataset_cfg)
9
+
10
+ # Set crop size
11
+ cfg.crop_size = (512, 512)
12
+ cfg.model.data_preprocessor.size = cfg.crop_size
13
+
14
+ # Configure normalization
15
+ cfg.norm_cfg = dict(type='BN', requires_grad=True)
16
+ cfg.model.backbone.norm_cfg = cfg.norm_cfg
17
+ cfg.model.decode_head.norm_cfg = cfg.norm_cfg
18
+ cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
19
+
20
+ cfg.model.decode_head.num_classes = num_class
21
+ cfg.model.auxiliary_head.num_classes = num_class
22
+
23
+ cfg.train_dataloader.batch_size = batch_size
24
+
25
+ # Set training configurations
26
+ cfg.train_cfg.max_iters = max_iters
27
+ cfg.train_cfg.val_interval = val_interval
28
+ cfg.default_hooks.logger.interval = 100
29
+ cfg.default_hooks.checkpoint.interval = 2500
30
+ cfg.default_hooks.checkpoint.max_keep_ckpts = 1
31
+ cfg.default_hooks.checkpoint.save_best = 'mIoU'
32
+
33
+ cfg['randomness'] = dict(seed=0)
34
+ # Set work directory
35
+ cfg.save_dir = save_dir
36
+ name = os.path.basename(dataset_config_path).split('_')[0] + "_" + os.path.dirname(model_config_path).split(os.sep)[1]
37
+ cfg.work_dir = os.path.join(work_dir,name)
38
+ os.makedirs(cfg.work_dir, exist_ok=True)
39
+ save_config_file = os.path.join(save_dir, f"{name}.py")
40
+ cfg.dump(save_config_file)
41
+ print(f"Configuration saved to: {save_config_file}")
42
+
43
+ def create_knet_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir, batch_size, max_iters, val_interval):
44
+
45
+ cfg = Config.fromfile(model_config_path)
46
+ dataset_cfg = Config.fromfile(dataset_config_path)
47
+
48
+ cfg.merge_from_dict(dataset_cfg)
49
+
50
+ cfg.norm_cfg = dict(type='BN', requires_grad=True)
51
+ cfg.model.data_preprocessor.size = cfg.crop_size
52
+
53
+ cfg.model.decode_head.kernel_generate_head.num_classes = num_class
54
+ cfg.model.auxiliary_head.num_classes = num_class
55
+
56
+ cfg.train_dataloader.batch_size = batch_size
57
+ cfg.work_dir = work_dir
58
+
59
+ cfg.train_cfg.max_iters = max_iters
60
+ cfg.train_cfg.val_interval = val_interval
61
+ cfg.default_hooks.logger.interval = 100
62
+ cfg.default_hooks.checkpoint.interval = 2500
63
+ cfg.default_hooks.checkpoint.max_keep_ckpts = 1
64
+ cfg.default_hooks.checkpoint.save_best = 'mIoU'
65
+
66
+ cfg['randomness'] = dict(seed=0)
67
+
68
+ cfg.save_dir = save_dir
69
+ name = os.path.basename(dataset_config_path).split('_')[0] + "_" + os.path.dirname(model_config_path).split(os.sep)[1]
70
+ cfg.work_dir = os.path.join(work_dir, name)
71
+ os.makedirs(cfg.work_dir, exist_ok=True)
72
+ save_config_file = os.path.join(save_dir, f"{name}.py")
73
+ cfg.dump(save_config_file)
74
+ print(f"Configuration saved to: {save_config_file}")
75
+
76
+ def create_mask2former_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir, batch_size, max_iters, val_interval):
77
+ cfg = Config.fromfile(model_config_path)
78
+ dataset_cfg = Config.fromfile(dataset_config_path)
79
+ cfg.merge_from_dict(dataset_cfg)
80
+
81
+ # Set crop size
82
+ cfg.crop_size = (512, 512)
83
+ cfg.model.data_preprocessor.size = cfg.crop_size
84
+
85
+ # Configure normalization
86
+ cfg.norm_cfg = dict(type='BN', requires_grad=True)
87
+
88
+ cfg.model.decode_head.num_classes = num_class
89
+ cfg.model.decode_head.loss_cls.class_weight = [1.0] * num_class + [0.1]
90
+
91
+ cfg.train_dataloader.batch_size = batch_size
92
+
93
+ # Set training configurations
94
+ cfg.train_cfg.max_iters = max_iters
95
+ cfg.train_cfg.val_interval = val_interval
96
+ cfg.default_hooks.logger.interval = 100
97
+ cfg.default_hooks.checkpoint.interval = 2500
98
+ cfg.default_hooks.checkpoint.max_keep_ckpts = 1
99
+ cfg.default_hooks.checkpoint.save_best = 'mIoU'
100
+
101
+ cfg['randomness'] = dict(seed=0)
102
+ # Set work directory
103
+ cfg.save_dir = save_dir
104
+ name = os.path.basename(dataset_config_path).split('_')[0] + "_" + os.path.dirname(model_config_path).split(os.sep)[1]
105
+ cfg.work_dir = os.path.join(work_dir,name)
106
+ os.makedirs(cfg.work_dir, exist_ok=True)
107
+ save_config_file = os.path.join(save_dir, f"{name}.py")
108
+ cfg.dump(save_config_file)
109
+ print(f"Configuration saved to: {save_config_file}")
110
+
111
+ def create_segformer_config(model_config_path, dataset_config_path, num_class, work_dir, save_dir, batch_size, max_iters, val_interval):
112
+ cfg = Config.fromfile(model_config_path)
113
+ dataset_cfg = Config.fromfile(dataset_config_path)
114
+ cfg.merge_from_dict(dataset_cfg)
115
+
116
+ # Configure normalization
117
+ cfg.norm_cfg = dict(type='BN', requires_grad=True)
118
+ cfg.model.data_preprocessor.size = cfg.crop_size
119
+ cfg.model.decode_head.norm_cfg = cfg.norm_cfg
120
+
121
+ cfg.model.decode_head.num_classes = num_class
122
+
123
+ cfg.train_dataloader.batch_size = batch_size
124
+
125
+ # Set training configurations
126
+ cfg.train_cfg.max_iters = max_iters
127
+ cfg.train_cfg.val_interval = val_interval
128
+ cfg.default_hooks.logger.interval = 100
129
+ cfg.default_hooks.checkpoint.interval = 2500
130
+ cfg.default_hooks.checkpoint.max_keep_ckpts = 1
131
+ cfg.default_hooks.checkpoint.save_best = 'mIoU'
132
+
133
+ cfg['randomness'] = dict(seed=0)
134
+ # Set work directory
135
+ cfg.save_dir = save_dir
136
+ name = os.path.basename(dataset_config_path).split('_')[0] + "_" + os.path.dirname(model_config_path).split(os.sep)[1]
137
+ cfg.work_dir = os.path.join(work_dir,name)
138
+ os.makedirs(cfg.work_dir, exist_ok=True)
139
+ save_config_file = os.path.join(save_dir, f"{name}.py")
140
+ cfg.dump(save_config_file)
141
+ print(f"Configuration saved to: {save_config_file}")
142
+
143
+ def main():
144
+ parser = argparse.ArgumentParser(description='Train configuration setup for different models.')
145
+
146
+ parser.add_argument('--model_name', type=str, required=True, choices=['deeplabv3plus', 'knet', 'mask2former', 'segformer'],
147
+ help='Model name to generate the config for.')
148
+ parser.add_argument('-m', '--model_config', type=str, required=True, help="Path to the model config file")
149
+ parser.add_argument('-d', '--dataset_config', type=str, required=True, help='Path to the dataset config file.')
150
+ parser.add_argument('-c', '--num_class', type=int, required=True, help="Number of classes in the dataset")
151
+ parser.add_argument('-w','--work_dir', type=str, required=True, help='Directory to save the train result.')
152
+ parser.add_argument('-s', '--save_dir', type=str, required=True, help="Directory to save the generated config file")
153
+ parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training.')
154
+ parser.add_argument('--max_iters', type=int, default=20000, help='Number of training iterations.')
155
+ parser.add_argument('--val_interval', type=int, default=500, help='Interval for validation during training.')
156
+
157
+
158
+ args = parser.parse_args()
159
+
160
+ if args.model_name == 'deeplabv3plus':
161
+ create_deeplabv3plus_config(
162
+ model_config_path=args.model_config,
163
+ dataset_config_path=args.dataset_config,
164
+ num_class=args.num_class,
165
+ work_dir=args.work_dir,
166
+ save_dir =args.save_dir,
167
+ batch_size=args.batch_size,
168
+ max_iters=args.max_iters,
169
+ val_interval=args.val_interval
170
+ )
171
+ if args.model_name == 'knet':
172
+ create_knet_config(
173
+ model_config_path=args.model_config,
174
+ dataset_config_path=args.dataset_config,
175
+ num_class=args.num_class,
176
+ work_dir=args.work_dir,
177
+ save_dir =args.save_dir,
178
+ batch_size=args.batch_size,
179
+ max_iters=args.max_iters,
180
+ val_interval=args.val_interval
181
+ )
182
+ if args.model_name == 'mask2former':
183
+ create_mask2former_config(
184
+ model_config_path=args.model_config,
185
+ dataset_config_path=args.dataset_config,
186
+ num_class=args.num_class,
187
+ work_dir=args.work_dir,
188
+ save_dir =args.save_dir,
189
+ batch_size=args.batch_size,
190
+ max_iters=args.max_iters,
191
+ val_interval=args.val_interval
192
+ )
193
+ elif args.model_name == 'segformer':
194
+ create_segformer_config(
195
+ model_config_path=args.model_config,
196
+ dataset_config_path=args.dataset_config,
197
+ num_class=args.num_class,
198
+ work_dir=args.work_dir,
199
+ save_dir =args.save_dir,
200
+ batch_size=args.batch_size,
201
+ max_iters=args.max_iters,
202
+ val_interval=args.val_interval
203
+ )
204
+
205
+ if __name__ == '__main__':
206
+ main()
run/split_dataset.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import random
4
+ from tqdm import tqdm
5
+
6
+ def main():
7
+ # Set the dataset path
8
+ Dataset_Path = 'CVRPDataset'
9
+ img_dir = os.path.join(Dataset_Path, 'img_dir')
10
+ ann_dir = os.path.join(Dataset_Path, 'ann_dir')
11
+
12
+ # Set the ratio of training set to test set
13
+ test_frac = 0.1
14
+ random.seed(123)
15
+
16
+ # Create directories if not exist
17
+ os.makedirs(os.path.join(Dataset_Path, 'train'), exist_ok=True)
18
+ os.makedirs(os.path.join(Dataset_Path, 'val'), exist_ok=True)
19
+
20
+ # Get image file paths
21
+ img_paths = os.listdir(img_dir)
22
+ random.shuffle(img_paths)
23
+
24
+ val_number = int(len(img_paths) * test_frac)
25
+ train_files = img_paths[val_number:]
26
+ val_files = img_paths[:val_number]
27
+
28
+ print(f"Total images: {len(img_paths)}")
29
+ print(f"Training set images: {len(train_files)}")
30
+ print(f"Test set images: {len(val_files)}")
31
+
32
+ # Move the training set images to the train directory
33
+ for each in tqdm(train_files, desc="Move the training set images"):
34
+ src_path = os.path.join(img_dir, each)
35
+ dst_path = os.path.join(Dataset_Path, 'train', each)
36
+ shutil.move(src_path, dst_path)
37
+
38
+ # Move the test set images to the test directory
39
+ for each in tqdm(val_files, desc="Move the test set images"):
40
+ src_path = os.path.join(img_dir, each)
41
+ dst_path = os.path.join(Dataset_Path, 'val', each)
42
+ shutil.move(src_path, dst_path)
43
+
44
+ # Move the train and val directories into img_dir
45
+ shutil.move(os.path.join(Dataset_Path, 'train'), os.path.join(img_dir, 'train'))
46
+ shutil.move(os.path.join(Dataset_Path, 'val'), os.path.join(img_dir, 'val'))
47
+
48
+ # Process annotation files
49
+ # Ensure the annotation directories exist
50
+ os.makedirs(os.path.join(Dataset_Path, 'train'), exist_ok=True)
51
+ os.makedirs(os.path.join(Dataset_Path, 'val'), exist_ok=True)
52
+
53
+ # Move the training set annotation files to the train directory
54
+ for each in tqdm(train_files, desc="Move the training set annotations"):
55
+ src_path = os.path.join(ann_dir, each.split('.')[0] + '.png')
56
+ dst_path = os.path.join(Dataset_Path, 'train', each.split('.')[0] + '.png')
57
+ shutil.move(src_path, dst_path)
58
+
59
+ # Move the test set annotation files to the test directory
60
+ for each in tqdm(val_files, desc="Move the test set annotations"):
61
+ src_path = os.path.join(ann_dir, each.split('.')[0] + '.png')
62
+ dst_path = os.path.join(Dataset_Path, 'val', each.split('.')[0] + '.png')
63
+ shutil.move(src_path, dst_path)
64
+
65
+ # Move the train and val annotation directories into ann_dir
66
+ shutil.move(os.path.join(Dataset_Path, 'train'), os.path.join(ann_dir, 'train'))
67
+ shutil.move(os.path.join(Dataset_Path, 'val'), os.path.join(ann_dir, 'val'))
68
+
69
+ if __name__ == '__main__':
70
+ main()
run/test.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
+ from tqdm import tqdm
5
+ import argparse
6
+ from mmseg.apis import init_model, inference_model
7
+
8
+
9
+ def process_single_img(img_path, model, outpath, palette_dict):
10
+
11
+ img_bgr = cv2.imread(img_path)
12
+
13
+ result = inference_model(model, img_bgr)
14
+ pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
15
+
16
+ # Map the predicted integer ID to the color of the corresponding category
17
+ pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
18
+ for idx in palette_dict.keys():
19
+ pred_mask_bgr[np.where(pred_mask==idx)] = palette_dict[idx]
20
+ pred_mask_bgr = pred_mask_bgr.astype('uint8')
21
+
22
+ save_path = os.path.join(outpath, os.path.basename(img_path))
23
+ cv2.imwrite(save_path, pred_mask_bgr)
24
+
25
+
26
+
27
+ def main(args):
28
+ # Initialize model
29
+ model = init_model(args.config_file, args.checkpoint_file, device=args.device)
30
+
31
+ # Define class palette
32
+ palette = [
33
+ ['background', [0, 0, 0]],
34
+ ['red', [0, 0, 255]]
35
+ ]
36
+ palette_dict = {idx: each[1] for idx, each in enumerate(palette)}
37
+
38
+ # Create output directory if not exists
39
+ if not os.path.exists(args.outpath):
40
+ os.mkdir(args.outpath)
41
+
42
+ # Process each image in the given directory
43
+ for img_name in tqdm(os.listdir(args.data_folder)):
44
+ img_path = os.path.join(args.data_folder, img_name)
45
+ process_single_img(img_path, model, args.outpath, palette_dict)
46
+
47
+
48
+ if __name__ == '__main__':
49
+ parser = argparse.ArgumentParser(description="Process images for semantic segmentation inference.")
50
+ parser.add_argument('-d','--data_folder', type=str, required=True, help="Path to the folder containing input images.")
51
+ parser.add_argument('-m','--config_file', type=str, required=True, help="Path to the model config file.")
52
+ parser.add_argument('-pth','--checkpoint_file', type=str, required=True, help="Path to the model checkpoint file.")
53
+ parser.add_argument('-o','--outpath', type=str, help="Path to save the output images.")
54
+ parser.add_argument('--device', type=str, default='cuda:0', help="Device to run the model (e.g., 'cuda:0', 'cpu').")
55
+
56
+ args = parser.parse_args()
57
+ main(args)
58
+