File size: 11,906 Bytes
3953219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
# Create transforms for the training, validation and test datasets.

## TODO: Make transforms more dynamic by building them directly from the config args.
## Maybe like this:
## TFM_NAME=next(iter(config.transforms.keys()))
## tfm_fun=getattr(monai.transforms, TFM_NAME)
## tfms+=[tfm_fun(keys=image_cols, **config.transforms[TFM_NAME], prob=prob, mode=mode)]


## ---------- imports ----------
import os
# only import of base transforms, others are imported as needed
from monai.utils.enums import CommonKeys
from monai.transforms import (
    Activationsd,
    AsDiscreted,
    Compose,
    ConcatItemsd,
    KeepLargestConnectedComponentd,
    LoadImaged,
    EnsureChannelFirstd,
    EnsureTyped,
    SaveImaged,
    ScaleIntensityd,
    NormalizeIntensityd
)
# Images should be interpolated with `bilinear`, but masks with `nearest`.

## ---------- base transforms ----------
# applied every time (shared by the train, validation and test pipelines)
def get_base_transforms(
    config: dict,
    minv: int = 0,
    maxv: int = 1
) -> list:
    """Build the deterministic preprocessing transforms shared by every split.

    The pipeline is: load -> channel-first -> optional resampling
    (``config.transforms.spacing``) -> optional reorientation
    (``config.transforms.orientation``) -> intensity scaling to
    ``[minv, maxv]`` -> intensity normalization. Loading, channel handling and
    the spatial steps act on image and label columns; the intensity steps
    act on image columns only.

    Args:
        config: experiment configuration. Although annotated as ``dict``, it
            is read with attribute syntax (``config.data.image_cols``), so an
            attribute-style mapping is expected.
        minv: lower bound passed to ``ScaleIntensityd``.
        maxv: upper bound passed to ``ScaleIntensityd``.

    Returns:
        A plain list of MONAI dictionary transforms (not yet composed), so
        callers can append further transforms.
    """
    image_keys = config.data.image_cols
    # Images and masks must stay spatially aligned, so spatial steps use both.
    every_key = image_keys + config.data.label_cols

    pipeline = [
        LoadImaged(keys=every_key),
        EnsureChannelFirstd(keys=every_key),
    ]

    tfm_cfg = config.transforms
    if tfm_cfg.spacing:
        from monai.transforms import Spacingd
        # `mode` is taken from the config so images and masks can use
        # different interpolation modes.
        pipeline.append(
            Spacingd(keys=every_key, pixdim=tfm_cfg.spacing, mode=tfm_cfg.mode)
        )
    if tfm_cfg.orientation:
        from monai.transforms import Orientationd
        pipeline.append(
            Orientationd(keys=every_key, axcodes=tfm_cfg.orientation)
        )

    pipeline.append(ScaleIntensityd(keys=image_keys, minv=minv, maxv=maxv))
    pipeline.append(NormalizeIntensityd(keys=image_keys))
    return pipeline

## ---------- train transforms ----------

def get_train_transforms(config: dict):
    """Build the randomized augmentation pipeline used for training.

    Starts from ``get_base_transforms`` and appends every augmentation whose
    key is present in ``config.transforms``. Each augmentation reads its
    parameters from the sub-config of the same name and, where it is random,
    the shared probability ``config.transforms.prob``. Spatial augmentations
    are applied to image and label columns; intensity augmentations only to
    image columns. Finally, image columns are concatenated into
    ``CommonKeys.IMAGE`` and label columns into ``CommonKeys.LABEL`` along
    dim 0.

    Args:
        config: experiment configuration. Read with attribute syntax
            (``config.data.image_cols``, ``config.transforms.prob``) and, for
            ``ndim``, item syntax (``config['ndim']``).

    Returns:
        A ``Compose`` of all configured transforms.

    Raises:
        ValueError: if neither ``rand_crop_pos_neg_label`` nor
            ``rand_spatial_crop_samples`` is configured, or if
            ``rand_elastic`` is configured while ``config['ndim']`` is
            neither 2 nor 3.
    """
    tfms = get_base_transforms(config=config)

    image_keys = config.data.image_cols
    label_keys = config.data.label_cols
    # Spatial augmentations must move images and masks together.
    all_keys = image_keys + label_keys
    tfm_cfg = config.transforms
    enabled = tfm_cfg.keys()

    # ---------- specific transforms for MRI ----------
    if 'rand_bias_field' in enabled:
        from monai.transforms import RandBiasFieldd
        args = tfm_cfg.rand_bias_field
        tfms += [
            RandBiasFieldd(
                keys=image_keys,
                degree=args['degree'],
                coeff_range=args['coeff_range'],
                prob=tfm_cfg.prob,
            )
        ]

    if 'rand_gaussian_smooth' in enabled:
        from monai.transforms import RandGaussianSmoothd
        args = tfm_cfg.rand_gaussian_smooth
        tfms += [
            RandGaussianSmoothd(
                keys=image_keys,
                sigma_x=args['sigma_x'],
                sigma_y=args['sigma_y'],
                sigma_z=args['sigma_z'],
                prob=tfm_cfg.prob,
            )
        ]

    # The existing config key is spelled `rand_gibbs_nose`; the correctly
    # spelled `rand_gibbs_noise` is accepted too (and wins if both exist).
    gibbs_key = next(
        (k for k in ('rand_gibbs_noise', 'rand_gibbs_nose') if k in enabled),
        None,
    )
    if gibbs_key is not None:
        from monai.transforms import RandGibbsNoised
        args = getattr(tfm_cfg, gibbs_key)
        tfms += [
            RandGibbsNoised(
                keys=image_keys,
                alpha=args['alpha'],
                prob=tfm_cfg.prob,
            )
        ]

    # ---------- affine transforms ----------
    if 'rand_affine' in enabled:
        from monai.transforms import RandAffined
        args = tfm_cfg.rand_affine
        tfms += [
            RandAffined(
                keys=all_keys,
                rotate_range=args['rotate_range'],
                shear_range=args['shear_range'],
                translate_range=args['translate_range'],
                mode=tfm_cfg.mode,
                prob=tfm_cfg.prob,
            )
        ]

    if 'rand_rotate90' in enabled:
        from monai.transforms import RandRotate90d
        args = tfm_cfg.rand_rotate90
        tfms += [
            RandRotate90d(
                keys=all_keys,
                spatial_axes=args['spatial_axes'],
                prob=tfm_cfg.prob,
            )
        ]

    if 'rand_rotate' in enabled:
        from monai.transforms import RandRotated
        args = tfm_cfg.rand_rotate
        tfms += [
            RandRotated(
                keys=all_keys,
                range_x=args['range_x'],
                range_y=args['range_y'],
                range_z=args['range_z'],
                mode=tfm_cfg.mode,
                prob=tfm_cfg.prob,
            )
        ]

    if 'rand_elastic' in enabled:
        args = tfm_cfg.rand_elastic
        ndim = config['ndim']
        if ndim == 3:
            from monai.transforms import Rand3DElasticd
            elastic = Rand3DElasticd(
                keys=all_keys,
                sigma_range=args['sigma_range'],
                magnitude_range=args['magnitude_range'],
                rotate_range=args['rotate_range'],
                shear_range=args['shear_range'],
                translate_range=args['translate_range'],
                mode=tfm_cfg.mode,
                prob=tfm_cfg.prob,
            )
        elif ndim == 2:
            from monai.transforms import Rand2DElasticd
            # Rand2DElasticd has no `sigma_range` parameter (passing it raises
            # TypeError); it requires `spacing` (distance between the
            # deformation control points) instead, read from `args['spacing']`.
            elastic = Rand2DElasticd(
                keys=all_keys,
                spacing=args['spacing'],
                magnitude_range=args['magnitude_range'],
                rotate_range=args['rotate_range'],
                shear_range=args['shear_range'],
                translate_range=args['translate_range'],
                mode=tfm_cfg.mode,
                prob=tfm_cfg.prob,
            )
        else:
            raise ValueError(
                f'`rand_elastic` supports only 2D or 3D data, got ndim={ndim!r}'
            )
        tfms += [elastic]

    if 'rand_zoom' in enabled:
        from monai.transforms import RandZoomd
        args = tfm_cfg.rand_zoom
        tfms += [
            RandZoomd(
                keys=all_keys,
                min_zoom=args['min'],
                max_zoom=args['max'],
                # `bilinear` entries are swapped for `area`; presumably because
                # torch interpolation's `bilinear` only supports 2D spatial
                # inputs while `area` works for 2D and 3D — TODO confirm.
                mode=['area' if m == 'bilinear' else m for m in tfm_cfg.mode],
                prob=tfm_cfg.prob,
            )
        ]

    # ---------- random cropping, very effective for large images ----------
    # RandCropByPosNegLabeld is not advisable for data with missing labels,
    # e.g., segmentation of carcinomas that are not present on all images,
    # so RandSpatialCropSamplesd is offered as a fallback. Completely replacing
    # cropping by resizing could be discussed, but is probably not beneficial.
    # For the first version this is an ugly hack; a later version should
    # restructure how transforms are built.
    if 'rand_crop_pos_neg_label' in enabled:
        from monai.transforms import RandCropByPosNegLabeld
        args = tfm_cfg.rand_crop_pos_neg_label
        tfms += [
            RandCropByPosNegLabeld(
                keys=all_keys,
                label_key=label_keys[0],
                spatial_size=args['spatial_size'],
                pos=args['pos'],
                neg=args['neg'],
                num_samples=args['num_samples'],
                image_key=image_keys[0],
                image_threshold=0,
            )
        ]
    elif 'rand_spatial_crop_samples' in enabled:
        from monai.transforms import RandSpatialCropSamplesd
        args = tfm_cfg.rand_spatial_crop_samples
        tfms += [
            RandSpatialCropSamplesd(
                keys=all_keys,
                roi_size=args['roi_size'],
                random_size=False,
                num_samples=args['num_samples'],
            )
        ]
    else:
        raise ValueError('Either `rand_crop_pos_neg_label` or `rand_spatial_crop_samples` '
                         'need to be specified')

    # ---------- intensity transforms ----------
    if 'gaussian_noise' in enabled:
        from monai.transforms import RandGaussianNoised
        args = tfm_cfg.gaussian_noise
        tfms += [
            RandGaussianNoised(
                keys=image_keys,
                mean=args['mean'],
                std=args['std'],
                prob=tfm_cfg.prob,
            )
        ]

    if 'shift_intensity' in enabled:
        from monai.transforms import RandShiftIntensityd
        args = tfm_cfg.shift_intensity
        tfms += [
            RandShiftIntensityd(
                keys=image_keys,
                offsets=args['offsets'],
                prob=tfm_cfg.prob,
            )
        ]

    if 'gaussian_sharpen' in enabled:
        from monai.transforms import RandGaussianSharpend
        args = tfm_cfg.gaussian_sharpen
        tfms += [
            RandGaussianSharpend(
                keys=image_keys,
                sigma1_x=args['sigma1_x'],
                sigma1_y=args['sigma1_y'],
                sigma1_z=args['sigma1_z'],
                sigma2_x=args['sigma2_x'],
                sigma2_y=args['sigma2_y'],
                sigma2_z=args['sigma2_z'],
                alpha=args['alpha'],
                prob=tfm_cfg.prob,
            )
        ]

    if 'adjust_contrast' in enabled:
        from monai.transforms import RandAdjustContrastd
        args = tfm_cfg.adjust_contrast
        tfms += [
            RandAdjustContrastd(
                keys=image_keys,
                gamma=args['gamma'],
                prob=tfm_cfg.prob,
            )
        ]

    # Concatenate multi-sequence data into single tensors along the channel
    # dim and store them under `CommonKeys.IMAGE` / `CommonKeys.LABEL` for
    # compatibility with monai.engines.
    tfms += [
        ConcatItemsd(keys=image_keys, name=CommonKeys.IMAGE, dim=0),
        ConcatItemsd(keys=label_keys, name=CommonKeys.LABEL, dim=0),
    ]

    return Compose(tfms)

## ---------- valid transforms ----------

def get_val_transforms(config: dict):
    """Build the deterministic pipeline used for validation.

    Applies the shared preprocessing from ``get_base_transforms``, then
    ``EnsureTyped`` on all image and label columns. Finally, image columns
    are concatenated into ``CommonKeys.IMAGE`` and label columns into
    ``CommonKeys.LABEL`` along dim 0. No random augmentation is applied.

    Args:
        config: experiment configuration, read with attribute syntax.

    Returns:
        A ``Compose`` of the transforms.
    """
    pipeline = get_base_transforms(config=config)

    image_keys = config.data.image_cols
    label_keys = config.data.label_cols
    pipeline.extend([
        EnsureTyped(keys=image_keys + label_keys),
        ConcatItemsd(keys=image_keys, name=CommonKeys.IMAGE, dim=0),
        ConcatItemsd(keys=label_keys, name=CommonKeys.LABEL, dim=0),
    ])

    return Compose(pipeline)

## ---------- test transforms ----------
# same as valid transforms

def get_test_transforms(config: dict):
    """Build the deterministic pipeline used for testing.

    The test pipeline is intended to be identical to the validation pipeline
    (base preprocessing, ``EnsureTyped``, then concatenation of image/label
    columns into ``CommonKeys.IMAGE`` / ``CommonKeys.LABEL``). It therefore
    delegates to ``get_val_transforms`` rather than duplicating the
    definition, so the two pipelines cannot drift apart.

    Args:
        config: experiment configuration, read with attribute syntax.

    Returns:
        A ``Compose`` of the transforms.
    """
    return get_val_transforms(config=config)


def get_val_post_transforms(config: dict):
    """Build the post-processing applied to predictions and labels in validation.

    Steps:
      1. ``EnsureTyped`` on the prediction and label entries.
      2. The prediction is reduced with argmax over channels and one-hot
         encoded into ``config.model.out_channels`` channels.
      3. The label is one-hot encoded into the same number of channels.
      4. For the prediction, only the largest connected component is kept
         for labels ``1 .. out_channels - 1``; label 0 is excluded
         (presumably the background class — confirm against the model setup).

    Args:
        config: experiment configuration; only ``config.model.out_channels``
            is read.

    Returns:
        A ``Compose`` of the post-processing transforms.
    """
    n_classes = config.model.out_channels

    # `num_classes` is passed alongside `to_onehot`; presumably kept for
    # compatibility with older MONAI releases where `to_onehot` was a bool —
    # NOTE(review): confirm against the installed MONAI version.
    pred_to_onehot = AsDiscreted(
        keys=CommonKeys.PRED,
        argmax=True,
        to_onehot=n_classes,
        num_classes=n_classes,
    )
    label_to_onehot = AsDiscreted(
        keys=CommonKeys.LABEL,
        to_onehot=n_classes,
        num_classes=n_classes,
    )
    largest_component = KeepLargestConnectedComponentd(
        keys=CommonKeys.PRED,
        applied_labels=list(range(1, n_classes)),
    )

    return Compose([
        EnsureTyped(keys=[CommonKeys.PRED, CommonKeys.LABEL]),
        pred_to_onehot,
        label_to_onehot,
        largest_component,
    ])