|
import argparse |
|
import os.path as osp |
|
from functools import partial |
|
|
|
import mmcv |
|
import numpy as np |
|
from PIL import Image |
|
from scipy.io import loadmat |
|
|
|
# Expected number of names in the generated trainaug list. main() checks
# len(aug_train_list) against this value and derives the expected size of
# the aug-only list as AUG_LEN - len(ori_train_list).
AUG_LEN = 10582
|
|
|
|
|
def convert_mat(mat_file, in_dir, out_dir):
    """Convert one SBD ``.mat`` class annotation into a PNG label map.

    Args:
        mat_file (str): file name of the ``.mat`` annotation, relative to
            ``in_dir``.
        in_dir (str): directory containing ``mat_file``.
        out_dir (str): directory the PNG is written to. The PNG keeps the
            base name of ``mat_file`` with its extension replaced by ``.png``.
    """
    data = loadmat(osp.join(in_dir, mat_file))
    # The [0] indexes unwrap the MATLAB struct array as returned by
    # scipy's loadmat to reach the ``Segmentation`` field of ``GTcls``.
    mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8)
    # Swap only the trailing extension; str.replace would also rewrite any
    # '.mat' appearing earlier in the file name.
    seg_filename = osp.join(out_dir, osp.splitext(mat_file)[0] + '.png')
    Image.fromarray(mask).save(seg_filename, 'PNG')
|
|
|
|
|
def generate_aug_list(merged_list, excluded_list):
    """Return the names in ``merged_list`` that are not in ``excluded_list``.

    Duplicates are removed, and the result is sorted so that split files
    written from it are identical across runs. Iterating a plain ``set`` of
    strings yields an order that changes with hash randomization.

    Args:
        merged_list (Iterable[str]): candidate sample names; may contain
            duplicates.
        excluded_list (Iterable[str]): sample names to drop.

    Returns:
        list[str]: sorted, de-duplicated names from ``merged_list`` minus
        ``excluded_list``.
    """
    excluded = set(excluded_list)
    return sorted({name for name in merged_list if name not in excluded})
|
|
|
|
|
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: carries ``devkit_path`` and ``aug_path``
        (positional), ``out_dir`` (``None`` when ``-o`` is omitted) and
        ``nproc`` (int, defaults to 1).
    """
    cli = argparse.ArgumentParser(
        description='Convert PASCAL VOC annotations to mmsegmentation format')

    # Positional arguments, registered in the order they must be given.
    positional_specs = (
        ('devkit_path', 'pascal voc devkit path'),
        ('aug_path', 'pascal voc aug path'),
    )
    for arg_name, arg_help in positional_specs:
        cli.add_argument(arg_name, help=arg_help)

    # Optional arguments.
    cli.add_argument('-o', '--out_dir', help='output path')
    cli.add_argument('--nproc', type=int, default=1,
                     help='number of process')

    return cli.parse_args()
|
|
|
|
|
def _read_name_list(path):
    """Return the non-empty, whitespace-stripped lines of the file at path."""
    with open(path) as f:
        # Skip blank lines (e.g. a trailing empty line) so they do not
        # turn into a bogus '' sample name that skews the length checks.
        return [line.strip() for line in f if line.strip()]


def _write_name_list(path, names):
    """Write names to path, one per line, each terminated by a newline."""
    with open(path, 'w') as f:
        f.writelines(name + '\n' for name in names)


def _check_len(names, expected, label):
    """Raise RuntimeError if len(names) != expected.

    An explicit raise is used instead of ``assert`` so the check still runs
    under ``python -O``.
    """
    if len(names) != expected:
        raise RuntimeError('len({}) != {} (got {})'.format(
            label, expected, len(names)))


def main():
    """Convert SBD annotations to PNG and write the trainaug/aug split files.

    Steps:
      1. Convert every ``<aug_path>/dataset/cls/*.mat`` into a PNG label map
         in ``out_dir`` (default: ``<devkit_path>/VOC2012/
         SegmentationClassAug``), using ``nproc`` worker processes.
      2. Write ``trainaug.txt``: (VOC train + SBD train + SBD val) minus
         VOC val.
      3. Write ``aug.txt``: (SBD train + SBD val) minus (VOC train + VOC val).
    Both split files are written to
    ``<devkit_path>/VOC2012/ImageSets/Segmentation``.

    Raises:
        RuntimeError: if a generated list does not have the expected length
            (``AUG_LEN`` for trainaug, ``AUG_LEN - len(VOC train)`` for aug).
    """
    args = parse_args()
    devkit_path = args.devkit_path
    aug_path = args.aug_path
    nproc = args.nproc

    out_dir = args.out_dir
    if out_dir is None:
        out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug')
    mmcv.mkdir_or_exist(out_dir)

    # Step 1: .mat -> .png label maps.
    sbd_dir = osp.join(aug_path, 'dataset')
    in_dir = osp.join(sbd_dir, 'cls')
    mmcv.track_parallel_progress(
        partial(convert_mat, in_dir=in_dir, out_dir=out_dir),
        list(mmcv.scandir(in_dir, suffix='.mat')),
        nproc=nproc)

    # Load the SBD splits and the original VOC2012 segmentation splits.
    full_aug_list = (_read_name_list(osp.join(sbd_dir, 'train.txt')) +
                     _read_name_list(osp.join(sbd_dir, 'val.txt')))
    split_dir = osp.join(devkit_path, 'VOC2012', 'ImageSets', 'Segmentation')
    ori_train_list = _read_name_list(osp.join(split_dir, 'train.txt'))
    val_list = _read_name_list(osp.join(split_dir, 'val.txt'))

    # Step 2: everything except the VOC val images.
    aug_train_list = generate_aug_list(ori_train_list + full_aug_list,
                                       val_list)
    _check_len(aug_train_list, AUG_LEN, 'aug_train_list')
    _write_name_list(osp.join(split_dir, 'trainaug.txt'), aug_train_list)

    # Step 3: SBD-only images, i.e. not in either original VOC split.
    aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list)
    _check_len(aug_list, AUG_LEN - len(ori_train_list), 'aug_list')
    _write_name_list(osp.join(split_dir, 'aug.txt'), aug_list)

    print('Done!')
|
|
|
|
|
if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    main()
|
|