# create dataloaders from csv file

## ---------- imports ----------
import os
import torch
import shutil
import numpy as np
import pandas as pd
from typing import Union
from monai.utils import first
from functools import partial
from collections import namedtuple
from monai.data import DataLoader as MonaiDataLoader

from . import transforms
from .utils import num_workers


def import_dataset(config: dict):
    "Select the dataset class specified in config.data.dataset_type"
    if config.data.dataset_type == 'persistent':
        from monai.data import PersistentDataset
        if os.path.exists(config.data.cache_dir):
            shutil.rmtree(config.data.cache_dir)  # remove cache of previous dataset
        os.makedirs(config.data.cache_dir, exist_ok=True)
        Dataset = partial(PersistentDataset, cache_dir=config.data.cache_dir)
    elif config.data.dataset_type == 'cache':
        from monai.data import CacheDataset
        raise NotImplementedError('CacheDataset not yet implemented')
    else:
        from monai.data import Dataset
    return Dataset
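
# Example (a sketch, not part of the pipeline): for dataset_type == 'persistent' the
# returned object is a `partial` with `cache_dir` already bound, so every branch is
# called the same way later on, e.g.:
#
#   Dataset = import_dataset(config)
#   ds = Dataset(data=data_dict, transform=train_transforms)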


class DataLoader(MonaiDataLoader):
    "Overwrites the monai DataLoader for enhanced viewing capabilities"

    def show_batch(self,
                   image_key: str = 'image',
                   label_key: str = 'label',
                   image_transform=lambda x: x.squeeze().transpose(0, 2).flip(-2),
                   label_transform=lambda x: x.squeeze().transpose(0, 2).flip(-2)):
        """Args:
            image_key: dict key name for the image to view
            label_key: dict key name for the corresponding label. Can be a tensor or str
            image_transform: transform applied to the input before it is passed to the viewer,
                to ensure the image has ndim == 3 and is oriented correctly
            label_transform: transform applied to the labels before they are passed to the viewer,
                to ensure segmentation masks have the same shape and orientation as the images.
                Should be the identity function if the labels are str.
        """
        from .viewer import ListViewer

        batch = first(self)
        image = torch.unbind(batch[image_key], 0)
        label = torch.unbind(batch[label_key], 0)

        ListViewer([image_transform(im) for im in image],
                   [label_transform(lb) for lb in label]).show()
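
# Example usage of `show_batch` (a sketch; assumes a `config` object with the keys read
# by `segmentation_dataloaders` below and an environment that can display plots):
#
#   train_loader = segmentation_dataloaders(config, train=True, valid=False, test=False)
#   train_loader.show_batch()  # view the first batch via ListViewer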

# TODO
## Work with 3 dataloaders
        
def segmentation_dataloaders(config: dict,
                             train: Union[bool, None] = None,
                             valid: Union[bool, None] = None,
                             test: Union[bool, None] = None,
                             ):
    """Create segmentation dataloaders
    Args:
        config: config object providing the parameters listed below
        train: whether to return a train DataLoader
        valid: whether to return a valid DataLoader
        test: whether to return a test DataLoader
    Args from config:
        data_dir: base directory for the data
        train_csv, valid_csv, test_csv: paths to csv files containing filenames and paths
        image_cols: columns in csv containing paths to images
        label_cols: columns in csv containing paths to label files
        dataset_type: PersistentDataset, CacheDataset and Dataset are supported
        cache_dir: cache directory to be used by PersistentDataset
        batch_size: batch size for training. Valid and test loaders always use batch size 1
        debug: run with a reduced number of images
    Returns:
        a single DataLoader if only one of train/valid/test is requested, otherwise a
        namedtuple of DataLoaders with one field per requested split:
            train_loader: DataLoader (if train == True)
            valid_loader: DataLoader (if valid == True)
            test_loader: DataLoader (if test == True)
    """
    
    ## parse needed arguments from config
    if train is None: train = config.data.train
    if valid is None: valid = config.data.valid
    if test is None: test = config.data.test
    
    data_dir = config.data.data_dir
    train_csv = config.data.train_csv
    valid_csv = config.data.valid_csv
    test_csv = config.data.test_csv
    image_cols = config.data.image_cols
    label_cols = config.data.label_cols
    dataset_type = config.data.dataset_type
    cache_dir = config.data.cache_dir
    batch_size = config.data.batch_size
    debug = config.debug
            
    ## ---------- data dicts ----------

    # First a global data dict, containing only the filepaths from image_cols and label_cols, is created.
    # For this, the dataframe is reduced to the relevant columns. Then the rows are iterated,
    # converting each row into an individual dict, as expected by monai.

    if not isinstance(image_cols, (tuple, list)): image_cols = [image_cols]
    if not isinstance(label_cols, (tuple, list)): label_cols = [label_cols]

    train_df = pd.read_csv(train_csv)
    valid_df = pd.read_csv(valid_csv)
    test_df = pd.read_csv(test_csv)
    if debug: 
        train_df = train_df.sample(25)
        valid_df = valid_df.sample(5)
    
    train_df['split']='train'
    valid_df['split']='valid'
    test_df['split']='test'
    whole_df = []
    if train: whole_df += [train_df]
    if valid: whole_df += [valid_df]
    if test: whole_df += [test_df]
    df = pd.concat(whole_df)
    cols = image_cols + label_cols
    for col in cols:
        # create absolute file names from the relative paths in the df and data_dir
        df[col] = [os.path.join(data_dir, fn) for fn in df[col]]
        # sanity check: verify that at least the first file of each column exists
        if not os.path.exists(df[col].iloc[0]):
            raise FileNotFoundError(df[col].iloc[0])


    data_dict = [dict(row[1]) for row in df[cols].iterrows()]
    # data_dict is not the correct name, list_of_data_dicts would be more accurate, but also longer.
    # The data_dict looks like this:
    # [
    #  {'image_col_1': 'data_dir/path/to/image1',
    #   'image_col_2': 'data_dir/path/to/image2',
    #   'label_col_1': 'data_dir/path/to/label1'},
    #  {'image_col_1': 'data_dir/path/to/image1',
    #   'image_col_2': 'data_dir/path/to/image2',
    #   'label_col_1': 'data_dir/path/to/label1'},
    #  ...]
    # Filenames should now be absolute, or at least relative to the working directory

    # now we create separate data dicts for train, valid and test data respectively
    assert train or valid or test, 'No dataset split specified (train, valid or test)'

    if test:
        test_files = [data_dict[i] for i in np.where(df.split == 'test')[0]]

    if valid:
        val_files = [data_dict[i] for i in np.where(df.split == 'valid')[0]]

    if train:
        train_files = [data_dict[i] for i in np.where(df.split == 'train')[0]]

    # transforms are specified in transforms.py and are just loaded here
    if train: train_transforms = transforms.get_train_transforms(config)
    if valid: val_transforms = transforms.get_val_transforms(config)
    if test: test_transforms = transforms.get_test_transforms(config)
    
    
    ## ---------- construct dataloaders ----------
    Dataset = import_dataset(config)
    data_loaders = []
    if train:
        train_ds = Dataset(
            data=train_files,
            transform=train_transforms
        )
        train_loader = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers(),
            shuffle=True
        )
        data_loaders.append(train_loader)

    if valid:
        val_ds = Dataset(
            data=val_files,
            transform=val_transforms
        )
        val_loader = DataLoader(
            val_ds,
            batch_size=1,
            num_workers=num_workers(),
            shuffle=False
        )
        data_loaders.append(val_loader)

    if test:
        test_ds = Dataset(
            data=test_files,
            transform=test_transforms
        )
        test_loader = DataLoader(
            test_ds,
            batch_size=1,
            num_workers=num_workers(),
            shuffle=False
        )
        data_loaders.append(test_loader)

    # if only one dataloader is constructed, return just this dataloader, else return a
    # namedtuple of dataloaders, so it is clear which DataLoader is train/valid/test

    if len(data_loaders) == 1:
        return data_loaders[0]
    else:
        DataLoaders = namedtuple(
            'DataLoaders',
            # create a str naming the requested loaders, e.g. if train and test are True
            # but valid is False, the string will be 'train test'
            ' '.join(
                [
                    'train' if train else '',
                    'valid' if valid else '',
                    'test' if test else ''
                ]
            ).strip()
        )
        return DataLoaders(*data_loaders)
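
# Example usage (a sketch; the exact config layout is an assumption, any attribute-style
# config object with these fields under `config.data`, plus whatever keys the
# transforms.get_*_transforms builders read, should work):
#
#   from types import SimpleNamespace
#
#   config = SimpleNamespace(
#       debug=False,
#       data=SimpleNamespace(
#           data_dir='/path/to/data',
#           train_csv='/path/to/train.csv',
#           valid_csv='/path/to/valid.csv',
#           test_csv='/path/to/test.csv',
#           image_cols=['image'],
#           label_cols=['label'],
#           dataset_type='persistent',
#           cache_dir='/path/to/cache',
#           batch_size=4,
#           train=True, valid=True, test=False,
#       ),
#   )
#   train_loader, valid_loader = segmentation_dataloaders(config)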