File size: 7,813 Bytes
c642393
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
'''
Usage:
    python3 split_data.py -bp BASE_DIR -ip IMAGE_DIR -sp LABEL_DIR \
        -sl VALUE NAME [VALUE NAME ...] -ti TASK_ID -tn TASK_NAME [-kf K]

Splits the registered images (and their matching label files) into K
cross-validation folds (default K=5). For each fold it creates one nnU-Net
task folder under the raw-data base directory: the held-out fold becomes the
test set (imagesTs) and the remaining folds become the training set
(imagesTr / labelsTr), and a dataset.json is generated per task.
'''
import os
import glob
import random
import shutil

from typing import Tuple
import numpy as np
from collections import OrderedDict
import json
import argparse


"""
creates a folder at a specified folder path if it does not exists
folder_path : relative path of the folder (from cur_dir) which needs to be created
over_write :(default: False) if True overwrite the existing folder 
 """
def parse_command_line(args=None):
    """Parse command-line arguments for the dataset-split pipeline.

    Parameters
    ----------
    args : list[str] | None
        Argument strings to parse. ``None`` (the default) parses
        ``sys.argv[1:]``, preserving the original call signature.

    Returns
    -------
    argparse.Namespace
        Parsed arguments: bp, ip, sp, sl, ti, tn, kf.
    """
    print('---'*10)
    print('Parsing Command Line Arguments')
    parser = argparse.ArgumentParser(
        description='pipeline for dataset split')
    parser.add_argument('-bp', metavar='base path', type=str,
                        help="Absolute path of the base directory")
    parser.add_argument('-ip', metavar='image path', type=str,
                        help="Relative path of the image directory")
    # BUG FIX: help text previously said "image directory" (copy-paste error).
    parser.add_argument('-sp', metavar='segmentation path', type=str,
                        help="Relative path of the segmentation (label) directory")
    parser.add_argument('-sl', metavar='segmentation information list', type=str, nargs='+',
                        help='a list of label name and corresponding value')
    parser.add_argument('-ti', metavar='task id', type=int,
                        help='task id number')
    parser.add_argument('-tn', metavar='task name', type=str,
                        help='task name')
    parser.add_argument('-kf', metavar='k-fold validation', type=int, default=5,
                        help='k-fold validation')
    argv = parser.parse_args(args)
    return argv


def make_if_dont_exist(folder_path, overwrite=False):
    """Ensure *folder_path* exists as a directory.

    If the folder is missing it is created. If it already exists, it is
    left untouched unless *overwrite* is True, in which case the folder
    (and everything in it) is deleted and recreated empty.
    A status line is printed in every case.
    """
    already_there = os.path.exists(folder_path)

    if not already_there:
        os.makedirs(folder_path)
        print(f"{folder_path} created!")
        return

    if overwrite:
        print(f"{folder_path} overwritten")
        shutil.rmtree(folder_path)
        os.makedirs(folder_path)
    else:
        print(f'{folder_path} exists.')


def rename(location, oldname, newname):
    """Rename the file *oldname* to *newname* inside directory *location*."""
    source = os.path.join(location, oldname)
    destination = os.path.join(location, newname)
    os.rename(source, destination)


def main():
    """Build k nnU-Net task folders from a k-fold split of the input scans.

    Workflow:
      1. Parse CLI args (base path, image/label dirs, label list, task id/name, k).
      2. Prepare the nnU-Net directory layout and environment variables.
      3. Shuffle the images (fixed seed) and partition them into k folds,
         distributing any remainder one-per-fold.
      4. For each fold j: create Task folder j, copy the other k-1 folds into
         imagesTr/labelsTr and fold j into imagesTs, renaming files to the
         nnU-Net convention (<Name>_<number>_0000.nii.gz), then write
         dataset.json describing the task.
    """
    args = parse_command_line()
    base = args.bp
    reg_data_path = args.ip
    lab_data_path = args.sp
    task_id = args.ti
    Name = args.tn
    k_fold = args.kf
    seg_list = args.sl
    # NOTE(review): hard-coded home directory — TODO make this configurable
    # (e.g. derive from args.bp or an environment variable).
    base_dir = "/home/ameen"
    nnunet_dir = "nnUNet/nnunet/nnUNet_raw_data_base/nnUNet_raw_data"
    main_dir = os.path.join(base_dir, 'nnUNet/nnunet')
    make_if_dont_exist(os.path.join(main_dir, 'nnUNet_preprocessed'))
    make_if_dont_exist(os.path.join(main_dir, 'nnUNet_trained_models'))

    # nnU-Net locates its data/result folders through these env variables.
    os.environ['nnUNet_raw_data_base'] = os.path.join(
        main_dir, 'nnUNet_raw_data_base')
    os.environ['nnUNet_preprocessed'] = os.path.join(
        main_dir, 'nnUNet_preprocessed')
    os.environ['RESULTS_FOLDER'] = os.path.join(
        main_dir, 'nnUNet_trained_models')

    random.seed(19)  # fixed seed so the shuffle (and thus the split) is reproducible

    image_list = glob.glob(os.path.join(base, reg_data_path) + "/*.nii.gz")
    label_list = glob.glob(os.path.join(base, lab_data_path) + "/*.nii.gz")
    num_images = len(image_list)
    # Distribute images across folds as evenly as possible: every fold gets
    # num_images // k_fold, and the first (num_images % k_fold) folds get one extra.
    num_each_fold, num_remain = divmod(num_images, k_fold)
    fold_num = np.repeat(num_each_fold, k_fold)
    count = 0
    while num_remain > 0:
        fold_num[count] += 1
        # BUG FIX: was "(count + 1) % 5" — hard-coded 5 mis-assigned the
        # remainder (or indexed out of range) whenever k_fold != 5.
        count = (count + 1) % k_fold
        num_remain -= 1

    random.shuffle(image_list)
    piece_data = {}
    start_point = 0
    # Slice the shuffled list into k contiguous folds of the sizes computed above.
    for m in range(k_fold):
        piece_data[f'fold_{m}'] = image_list[start_point:start_point+fold_num[m]]
        start_point += fold_num[m]

    for j in range(k_fold):
        # NOTE(review): nnU-Net conventionally uses 3-digit task ids
        # ("Task005_..."); this produces e.g. "Task07_..." for task_id=7.
        # Kept as-is for backward compatibility — confirm before changing.
        task_name = f"Task0{task_id}_{Name}_fold{j}"  # MODIFY
        task_id += 1
        task_folder_name = os.path.join(base_dir, nnunet_dir, task_name)
        train_image_dir = os.path.join(task_folder_name, 'imagesTr')
        train_label_dir = os.path.join(task_folder_name, 'labelsTr')
        test_dir = os.path.join(task_folder_name, 'imagesTs')

        make_if_dont_exist(task_folder_name)
        make_if_dont_exist(train_image_dir)
        make_if_dont_exist(train_label_dir)
        make_if_dont_exist(test_dir)

        # Fold j is held out for testing; the other k-1 folds form the training set.
        num_test = fold_num[j]
        num_train = np.sum(fold_num) - num_test
        print("Number of training subjects: ", num_train,
            "\nNumber of testing subjects:", num_test, "\nTotal:", num_images)
        p = 0
        train_images = []
        # concat all folds except j for training
        while p < len(piece_data):
            if p != j:
                train_images.extend(piece_data[f'fold_{p}'])
            p += 1
        # select one fold for testing
        test_images = piece_data[f'fold_{j}']

        # Copy training scans (and their labels) and rename to nnU-Net convention.
        for i in range(len(train_images)):
            filename1 = os.path.basename(train_images[i]).split(".")[0]
            # Subject number = all digit characters of the stem, concatenated.
            number = ''.join(filter(lambda x: x.isdigit(), filename1))
            shutil.copy(train_images[i], train_image_dir)
            filename = os.path.basename(train_images[i])
            rename(train_image_dir, filename, Name + "_" + number + "_0000.nii.gz")

            # Find the label whose basename matches this image's basename.
            for label_dir in label_list:
                if label_dir.endswith(os.path.basename(train_images[i])):
                    shutil.copy(label_dir, train_label_dir)
                    rename(train_label_dir, filename, Name + "_" + number + '.nii.gz')
                    break

        # Copy test scans (labels are not copied for the test set).
        for i in range(len(test_images)):
            shutil.copy(test_images[i], test_dir)
            filename = os.path.basename(test_images[i])
            filename1 = os.path.basename(test_images[i]).split(".")[0]
            number = ''.join(filter(lambda x: x.isdigit(), filename1))
            rename(test_dir, filename, Name + "_" + number + "_0000.nii.gz")

        # Build dataset.json describing the task for nnU-Net.
        json_dict = OrderedDict()
        json_dict['name'] = task_name
        json_dict['description'] = Name
        json_dict['tensorImageSize'] = "4D"
        json_dict['reference'] = "MODIFY"
        json_dict['licence'] = "MODIFY"
        json_dict['release'] = "0.0"
        json_dict['modality'] = {
            "0": "CT"
        }
        json_dict['labels'] = {
            "0": "background",
        }
        # seg_list alternates value, name: e.g. ['1', 'liver', '2', 'tumor'].
        for i in range(0, len(seg_list), 2):
            assert seg_list[i].isdigit()        # even positions: label values
            assert not seg_list[i + 1].isdigit()  # odd positions: label names
            json_dict['labels'].update({
                seg_list[i]: seg_list[i + 1]
            })
        train_ids = os.listdir(train_image_dir)
        test_ids = os.listdir(test_dir)
        json_dict['numTraining'] = len(train_ids)
        json_dict['numTest'] = len(test_ids)
        # Strip the "_0000" modality suffix to recover label/list entry names.
        json_dict['training'] = [{'image': "./imagesTr/%s" % (i[:i.find(
            "_0000")]+'.nii.gz'), "label": "./labelsTr/%s" % (i[:i.find("_0000")]+'.nii.gz')} for i in train_ids]
        json_dict['test'] = ["./imagesTs/%s" %
                            (i[:i.find("_0000")]+'.nii.gz') for i in test_ids]

        with open(os.path.join(task_folder_name, "dataset.json"), 'w') as f:
            json.dump(json_dict, f, indent=4, sort_keys=True)

        if os.path.exists(os.path.join(task_folder_name, 'dataset.json')):
            print("new json file created!")


# Script entry point: parse CLI arguments and build the k-fold nnU-Net tasks.
if __name__ == '__main__':
    main()