|
|
|
import argparse |
|
import os.path as osp |
|
|
|
import nibabel as nib |
|
import numpy as np |
|
from mmengine.utils import mkdir_or_exist |
|
from PIL import Image |
|
|
|
|
|
def read_files_from_txt(txt_path):
    """Read a text file and return its non-empty lines, stripped.

    Args:
        txt_path (str): Path to a text file with one entry per line.

    Returns:
        list[str]: The whitespace-stripped lines in file order. Lines that are
        empty after stripping (e.g. stray blank lines at the end of the file)
        are skipped, so they never turn into empty entries downstream.
    """
    with open(txt_path, encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]
|
|
|
|
|
def read_nii_file(nii_path):
    """Load a NIfTI file with nibabel and return its voxel data.

    Args:
        nii_path (str): Path to the ``.nii`` / ``.nii.gz`` file.

    Returns:
        np.ndarray: The array produced by nibabel's ``get_fdata()``.
    """
    nifti_image = nib.load(nii_path)
    voxels = nifti_image.get_fdata()
    return voxels
|
|
|
|
|
def split_3d_image(img):
    """Split a 3-D array into a list of 2-D slices taken along axis 0.

    Args:
        img (np.ndarray): Array with exactly three dimensions. Any other
            number of dimensions raises ValueError from the shape unpacking.

    Returns:
        list[np.ndarray]: ``img.shape[0]`` slices, where entry ``k`` is
        ``img[k, :, :]`` (a view into ``img``, not a copy).
    """
    depth, _, _ = img.shape  # unpacking also enforces a 3-D input
    return [img[k, :, :] for k in range(depth)]
|
|
|
|
|
def label_mapping(label):
    """Remap original Synapse label ids to the 9-class TransUNet setting.

    The output ids 0..8 stand for 'background', 'aorta', 'gallbladder',
    'left_kidney', 'right_kidney', 'liver', 'pancreas', 'spleen', 'stomach',
    in that order. Every original id that is not listed in the remap table
    below is set to 0 (background).

    More details could be found here: https://arxiv.org/abs/2102.04306

    Args:
        label (np.ndarray): Array of original label ids.

    Returns:
        np.ndarray: Array with the same shape and dtype as ``label`` holding
        the remapped ids.
    """
    # (original id, TransUNet id) pairs, applied in this order.
    remap_table = (
        (8, 1),   # -> aorta
        (4, 2),   # -> gallbladder
        (3, 3),   # -> left_kidney
        (2, 4),   # -> right_kidney
        (6, 5),   # -> liver
        (11, 6),  # -> pancreas
        (1, 7),   # -> spleen
        (7, 8),   # -> stomach
    )
    remapped = np.zeros_like(label)
    for source_id, target_id in remap_table:
        remapped[label == source_id] = target_id
    return remapped
|
|
|
|
|
def pares_args(argv=None):
    """Parse command-line arguments for the Synapse converter.

    The historical (misspelled) name is kept so existing callers keep working.

    Args:
        argv (list[str], optional): Arguments to parse. Defaults to None,
            in which case argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Has ``dataset_path`` (str) and ``save_path``
        (str, default ``'data/synapse'``).

    Raises:
        SystemExit: If ``--dataset-path`` is missing or arguments are invalid
            (standard argparse behavior).
    """
    parser = argparse.ArgumentParser(
        description='Convert synapse dataset to mmsegmentation format')
    # Required: the converter cannot do anything without a dataset location,
    # and leaving it optional would silently produce dataset_path=None.
    parser.add_argument(
        '--dataset-path',
        type=str,
        required=True,
        help='synapse dataset path.')
    parser.add_argument(
        '--save-path',
        default='data/synapse',
        type=str,
        help='save path of the dataset.')
    return parser.parse_args(argv)
|
|
|
|
|
def _read_case_ids(txt_path):
    """Return the case ids listed in one split file.

    Each entry is sliced to characters ``[3:7]``. This presumably assumes
    entries look like a 3-character prefix followed by a 4-digit case number
    (e.g. ``img0001...``). TODO: confirm against the dataset's split files.
    """
    return [name[3:7] for name in read_files_from_txt(txt_path)]


def _normalize_intensity(volume, low=-125, high=275):
    """Clip ``volume`` to ``[low, high]`` and rescale it linearly to [0, 255].

    Returns a new float array; the input is not modified.
    """
    volume = np.clip(volume, low, high)
    return (volume - low) / (high - low) * 255


def _reorient(volume):
    """Move the last axis to the front, then flip the new last axis.

    Afterwards ``volume[c]`` is the c-th 2-D slice along the input's last
    axis. Returns a view; no data is copied.
    """
    return np.flip(np.transpose(volume, [2, 0, 1]), 2)


def _save_slices(img_3d, label_3d, out_img_dir, out_ann_dir, case_id):
    """Write each slice of one case as ``.jpg`` (image) and ``.png`` (label).

    Files are named ``case<id:4>_slice<c:03>.<ext>``.

    Raises:
        ValueError: If the image and label volumes differ in shape, which
            would otherwise produce misaligned or missing annotation slices.
    """
    if img_3d.shape != label_3d.shape:
        raise ValueError(
            f'Image and label volumes of case {case_id} differ in shape: '
            f'{img_3d.shape} vs {label_3d.shape}.')
    for c in range(img_3d.shape[0]):
        stem = f'case{case_id.zfill(4)}_slice{c:03d}'
        Image.fromarray(img_3d[c]).convert('RGB').save(
            osp.join(out_img_dir, stem + '.jpg'))
        Image.fromarray(label_3d[c]).convert('L').save(
            osp.join(out_ann_dir, stem + '.png'))


def _convert_split(dataset_path, save_path, case_ids, split):
    """Convert every case of one split (``'train'`` or ``'val'``)."""
    out_img_dir = osp.join(save_path, 'img_dir', split)
    out_ann_dir = osp.join(save_path, 'ann_dir', split)
    mkdir_or_exist(out_img_dir)
    mkdir_or_exist(out_ann_dir)

    for case_id in case_ids:
        img_3d = read_nii_file(
            osp.join(dataset_path, 'img', 'img' + case_id + '.nii.gz'))
        label_3d = read_nii_file(
            osp.join(dataset_path, 'label', 'label' + case_id + '.nii.gz'))

        img_3d = _reorient(_normalize_intensity(img_3d))
        label_3d = label_mapping(_reorient(label_3d))

        _save_slices(img_3d, label_3d, out_img_dir, out_ann_dir, case_id)


def main():
    """Convert the Synapse dataset into mmsegmentation's directory layout.

    Case ids come from ``train.txt`` and ``val.txt`` in the dataset root.
    For each id, ``img/img<id>.nii.gz`` and ``label/label<id>.nii.gz`` are
    loaded and processed. Per-slice images go to
    ``<save_path>/img_dir/<split>`` and annotations to
    ``<save_path>/ann_dir/<split>``.

    Raises:
        ValueError: If the dataset path is missing or does not exist.
        FileNotFoundError: If ``img/`` or ``label/`` is missing under it.
    """
    args = pares_args()
    dataset_path = args.dataset_path
    save_path = args.save_path

    # Also reject None/'' explicitly. Otherwise osp.exists(None) raises an
    # opaque TypeError when --dataset-path was not given.
    if not dataset_path or not osp.exists(dataset_path):
        raise ValueError('The dataset path does not exist. '
                         'Please enter a correct dataset path.')
    if not (osp.exists(osp.join(dataset_path, 'img'))
            and osp.exists(osp.join(dataset_path, 'label'))):
        raise FileNotFoundError('The dataset structure is incorrect. '
                                'Please check your dataset.')

    # Read both split files up front so that a missing val.txt fails before
    # any conversion work starts.
    split_ids = {
        'train': _read_case_ids(osp.join(dataset_path, 'train.txt')),
        'val': _read_case_ids(osp.join(dataset_path, 'val.txt')),
    }
    for split, case_ids in split_ids.items():
        _convert_split(dataset_path, save_path, case_ids, split)


if __name__ == '__main__':
    main()
|
|