lingchmao commited on
Commit
75b23cf
1 Parent(s): fc3517c

Delete utils/data_transforms.py

Browse files
Files changed (1) hide show
  1. utils/data_transforms.py +0 -267
utils/data_transforms.py DELETED
@@ -1,267 +0,0 @@
1
- import monai
2
- import cv2
3
- from monai.transforms import MapTransform
4
- import math
5
- import numpy as np
6
- import torch
7
- import morphsnakes as ms
8
- import monai
9
- import nrrd
10
- import torchvision.transforms as transforms
11
- from monai.transforms import (
12
- Activations, AsDiscreteD, AsDiscrete, Compose, CastToTypeD, RandSpatialCropd,
13
- ToTensorD, CropForegroundD, Resized, GaussianSmoothD,
14
- LoadImageD, TransposeD, OrientationD, ScaleIntensityRangeD,
15
- RandAffineD, ResizeWithPadOrCropd, ToTensor,
16
- FillHoles, KeepLargestConnectedComponent, HistogramNormalizeD, NormalizeIntensityD
17
- )
18
-
19
-
20
-
21
- def define_transforms_loadonly():
22
- transformations = Compose([
23
- LoadImageD(keys=["mask"], reader="NrrdReader", ensure_channel_first=True),
24
- ConvertMaskValues(keys=["mask"], keep_classes=["liver", "tumor"]),
25
- ToTensor()
26
- ])
27
- return transformations
28
-
29
-
30
- def define_post_processing(config):
31
- # Post-processing transforms
32
- post_processing = [
33
- # Apply softmax activation to convert logits to probabilities
34
- Activations(sigmoid=True),
35
- # Convert predicted probabilities to discrete values (0 or 1)
36
- AsDiscrete(argmax=True, to_onehot=None if len(config['KEEP_CLASSES']) <= 2 else len(config['KEEP_CLASSES'])),
37
- # Remove small connected components for 1=liver and 2=tumor
38
- KeepLargestConnectedComponent(applied_labels=[1]),
39
- # Fill holes in the binary mask for 1=liver and 2=tumor
40
- FillHoles(applied_labels=[1]),
41
- ToTensor()
42
- ]
43
-
44
- return Compose(post_processing)
45
-
46
- def define_transforms(config):
47
-
48
- transformations_test = [
49
- LoadImageD(keys=["image", "mask"], reader="NrrdReader", ensure_channel_first=True),
50
- # Orient up and down
51
- OrientationD(keys=["image", "mask"], axcodes="PLI"),
52
- ToTensorD(keys=["image", "mask"])
53
- # histogram equilization or normalization
54
- # HistogramNormalizeD(keys=["image"], num_bins=256, min=0, max=1),
55
- # Intensity normalization
56
- # NormalizeIntensityD(keys=["image"]),
57
- #CastToTypeD(keys=["image"], dtype=torch.float32),
58
- #CastToTypeD(keys=["mask"], dtype=torch.int32),
59
- ]
60
-
61
- if config['MASKNONLIVER']:
62
- transformations_test.extend(
63
- [
64
- MaskOutNonliver(mask_key="mask"),
65
- CropForegroundD(keys=["image", "mask"], source_key="image", allow_smaller=True),
66
- ]
67
- )
68
-
69
- transformations_test.append(
70
- # Windowing based on liver parameters
71
- ScaleIntensityRangeD(keys=["image"],
72
- a_min=config['HU_RANGE'][0],
73
- a_max=config['HU_RANGE'][1],
74
- b_min=0.0, b_max=1.0, clip=True
75
- )
76
- )
77
-
78
- if config['PREPROCESSING'] == "clihe":
79
- transformations_test.append(CLIHE(keys=["image"]))
80
-
81
- elif config['PREPROCESSING'] == "gaussian":
82
- transformations_test.append(GaussianSmoothD(keys=["image"], sigma=0.5))
83
-
84
- # convert labels to 0,1,2 instead of 0,1,2,3,4
85
- transformations_test.append(ConvertMaskValues(keys=["mask"], keep_classes=config['KEEP_CLASSES']))
86
-
87
- if len(config['KEEP_CLASSES']) > 2: # NEEDED FOR MULTICLASS https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_brats21_segmentation_3d.ipynb
88
- transformations_test.append(AsDiscreteD(keys=["mask"], to_onehot=len(config['KEEP_CLASSES']))) # (N, C, H, W) 2d; (1, C, H, W, Z)
89
-
90
- if "3D" not in config['MODEL_NAME']:
91
- transformations_test.append(TransposeD(keys=["image", "mask"], indices=(3,0,1,2)))
92
-
93
- # training transforms include data augmentation
94
- transformations_train = transformations_test.copy()
95
- if config['MASKNONLIVER']: transformations_test = transformations_test[:4] + transformations_test[5:] # do not crop to liver foregroudn
96
-
97
- if config['DATA_AUGMENTATION']:
98
- if "3D" in config["MODEL_NAME"]:
99
- transformations_train.append(
100
- RandAffineD(keys=["image", "mask"], prob=0.2, padding_mode="border",
101
- mode="bilinear", spatial_size=config['ROI_SIZE'],
102
- rotate_range=(0.15,0.15,0.15), #translate_range=(30,30,30),
103
- scale_range=(0.1,0.1,0.1)))
104
- else:
105
- transformations_train.append(
106
- RandAffineD(keys=["image", "mask"], prob=0.2, padding_mode="border",
107
- mode="bilinear", #spatial_size=(512, 512),
108
- rotate_range=(0.15,0.15), #translate_range=(30,30),
109
- scale_range=(0.1,0.1)))
110
-
111
- transformations_train.extend(
112
- [
113
- RandSpatialCropd(keys=["image", "mask"], roi_size=config['ROI_SIZE'], random_size=False),
114
- ResizeWithPadOrCropd(keys=["image", "mask"], spatial_size=config['ROI_SIZE'], method="end", mode='constant', value=0)
115
- ]
116
- )
117
-
118
- postprocessing_transforms = define_post_processing(config)
119
- preprocessing_transforms_test = Compose(transformations_test)
120
- preprocessing_transforms_train = Compose(transformations_train)
121
- preprocessing_transforms_train.set_random_state(seed=1)
122
- preprocessing_transforms_test.set_random_state(seed=1)
123
-
124
- return preprocessing_transforms_train, preprocessing_transforms_test, postprocessing_transforms
125
-
126
-
127
-
128
- class CLIHE(MapTransform):
129
- def __init__(self, keys, allow_missing_keys=False):
130
- super().__init__(allow_missing_keys)
131
- self.keys = keys
132
-
133
- def __call__(self, data):
134
- for key in self.keys:
135
- if len(data['image'].shape) > 3: # 3D image
136
- data[key] = self.apply_clahe_3d(data[key]) # [B, 1, H, W, Z]
137
- else:
138
- data[key] = self.apply_clahe_2d(data[key]) # [B, 1, H, W, Z]
139
- return data
140
-
141
- def apply_clahe_3d(self, image):
142
- image = np.asarray(image)
143
- clahe_slices = []
144
- for slice_idx in range(image.shape[-1]):
145
- # Extract the current slice
146
- slice_2d = image[0, :, :, slice_idx]
147
-
148
- # Apply CLAHE to the current slice
149
- # slice_2d = cv2.medianBlur(slice_2d, 5)
150
- # slice_2d = cv2.anisotropicDiffusion(slice_2d, alpha=0.1, K=1, iterations=50)
151
- # slice_2d = anisotropic_diffusion(slice_2d)
152
- # slice_2d = cv2.Sobel(slice_2d, cv2.CV_64F, dx=1, dy=1, ksize=5)
153
- clahe = cv2.createCLAHE(clipLimit=1, tileGridSize=(16,16))
154
- slice_2d = clahe.apply(slice_2d.astype(np.uint8))
155
- #cv2.threshold(clahe_slice, 155, 255, cv2.THRESH_BINARY)
156
- kernel = np.ones((2,2), np.float32)/4
157
- slice_2d = cv2.filter2D(slice_2d, -1, kernel)
158
- #t = anisodiff2D(delta_t=0.2,kappa=50)
159
- #slice_2d = t.fit(slice_2d)
160
-
161
- # Append the CLAHE enhanced slice to the list
162
- clahe_slices.append(slice_2d)
163
-
164
- # Stack the CLAHE enhanced slices along the slice axis to form the 3D image
165
- clahe_image = np.stack(clahe_slices, axis=-1)
166
-
167
- return torch.from_numpy(clahe_image[None,:])
168
-
169
- def apply_clahe_2d(self, image):
170
- image = np.asarray(image)
171
-
172
- clahe = cv2.createCLAHE(clipLimit=5)
173
- clahe_slice = clahe.apply(image[0].astype(np.uint8))
174
-
175
- return torch.from_numpy(clahe_slice)
176
-
177
-
178
-
179
- class GaussianFilter(MapTransform):
180
- def __init__(self, keys, allow_missing_keys=False):
181
- super().__init__(allow_missing_keys)
182
- self.keys = keys
183
-
184
- def __call__(self, data):
185
- for key in self.keys:
186
- if len(data['image'].shape) > 3: # 3D image
187
- data[key] = self.apply_clahe_3d(data[key]) # [B, 1, H, W, Z]
188
- else:
189
- data[key] = self.apply_clahe_2d(data[key]) # [B, 1, H, W, Z]
190
- return data
191
-
192
- def apply_clahe_3d(self, image):
193
- image = np.asarray(image)
194
- clahe_slices = []
195
- for slice_idx in range(image.shape[-1]):
196
- # Extract the current slice
197
- slice_2d = image[0, :, :, slice_idx]
198
-
199
- # Apply CLAHE to the current slice
200
- kernel = np.ones((3,3), np.float32)/9
201
- slice_2d = cv2.filter2D(slice_2d, -1, kernel)
202
-
203
- # Append the CLAHE enhanced slice to the list
204
- clahe_slices.append(slice_2d)
205
-
206
- # Stack the CLAHE enhanced slices along the slice axis to form the 3D image
207
- clahe_image = np.stack(clahe_slices, axis=-1)
208
-
209
- return torch.from_numpy(clahe_image[None,:])
210
-
211
- def apply_clahe_2d(self, image):
212
- image = np.asarray(image)
213
-
214
- kernel = np.ones((3,3), np.float32)/9
215
- slice_2d = cv2.filter2D(image, -1, kernel)
216
-
217
- return torch.from_numpy(slice_2d)
218
-
219
-
220
- class Morphsnakes(MapTransform):
221
- # https://github.com/pmneila/morphsnakes/blob/master/morphsnakes.py
222
- def __init__(self, allow_missing_keys=False):
223
- super().__init__(allow_missing_keys)
224
-
225
- def __call__(self, data):
226
- if np.sum(data['mask'][-1]) > 0:
227
- res = ms.morphological_chan_vese(data['image'][0], iterations=2, init_level_set=data['mask'][-1])
228
- data['mask'] = res
229
- return data
230
-
231
-
232
- class MaskOutNonliver(MapTransform):
233
- def __init__(self, allow_missing_keys=False, mask_key="mask"):
234
- super().__init__(allow_missing_keys)
235
- self.mask_key = mask_key
236
-
237
- def __call__(self, data):
238
- # mask out non-liver regions of an image
239
- # non-liver regions are liver, tumor, or portal vein
240
- if data[self.mask_key].shape != data['image'].shape:
241
- return data
242
- data['image'][data[self.mask_key] >= 4] = -1000
243
- data['image'][data[self.mask_key] <= 0] = -1000
244
- return data
245
-
246
-
247
- class ConvertMaskValues(MapTransform):
248
- def __init__(self, keys, allow_missing_keys=False, keep_classes=["normal", "liver", "tumor"]):
249
- super().__init__(keys, allow_missing_keys)
250
- self.keep_classes = keep_classes
251
-
252
- def __call__(self, data):
253
- # original labels: 0 for normal region, 1 for liver, 2 for tumor mass, 3 for portal vein, and 4 for abdominal aorta.
254
- # converted labels: 0 for normal region and abdominal aorta, 1 for liver and portal vein, 2 for tumor mass
255
-
256
- for key in self.keys:
257
- data[key][data[key] > 4] = 4 # one patient had class label = 5, converted to 4
258
- if key in data:
259
- if "liver" not in self.keep_classes:
260
- data[key][data[key] == 1] = 0
261
- if "tumor" not in self.keep_classes:
262
- data[key][data[key] == 2] = 1
263
- if "portal vein" not in self.keep_classes:
264
- data[key][data[key] == 3] = 1
265
- if "abdominal aorta" not in self.keep_classes:
266
- data[key][data[key] >= 4] = 0
267
- return data