lingchmao committed on
Commit
6ffe23f
1 Parent(s): 52a9229

Upload 12 files

model development/run_best_model_notebook.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
utils/data_preparation.py ADDED
@@ -0,0 +1,241 @@
1
+ import os
2
+ from sklearn.model_selection import train_test_split
3
+ import monai
4
+ from monai.data import Dataset, DataLoader
5
+ from data_transforms import define_transforms, define_transforms_loadonly
6
+ import torch
7
+ import numpy as np
8
+ from visualization import visualize_patient
9
+ from monai.data import list_data_collate
10
+ import pandas as pd
11
+
12
+
13
+ def prepare_clinical_data(data_file, predictors):
14
+
15
+ # read data file
16
+ info = pd.read_excel(data_file, sheet_name=0)
17
+
18
+ # convert to numerical
19
+ info['CPS'] = info['CPS'].map({'A': 1, 'B': 2, 'C': 3})
20
+ info['T_involvment'] = info['T_involvment'].map({'< or = 50%': 1, '>50%': 2})
21
+ info['CLIP_Score'] = info['CLIP_Score'].map({'Stage_0': 0, 'Stage_1': 1, 'Stage_2': 2, 'Stage_3': 3, 'Stage_4': 4, 'Stage_5': 5, 'Stage_6': 6})
22
+ info['Okuda'] = info['Okuda'].map({'Stage I': 1, 'Stage II': 2, 'Stage III': 3})
23
+ info['TNM'] = info['TNM'].map({'Stage-I': 1, 'Stage-II': 2, 'Stage-IIIA': 3, 'Stage-IIIB': 4, 'Stage-IIIC': 5, 'Stage-IVA': 6, 'Stage-IVB': 7})
24
+ info['BCLC'] = info['BCLC'].map({'0': 0, 'Stage-A': 1, 'Stage-B': 2, 'Stage-C': 3, 'Stage-D': 4})
25
+
26
+ # remove duplicates (keep the first record per patient)
+ info = info.groupby("TCIA_ID", as_index=False).first()
28
+
29
+ # select columns
30
+ info = info[['TCIA_ID'] + predictors].rename(columns={'TCIA_ID': "patient_id"})
31
+
32
+
33
+ return info
34
+
35
+
36
+
37
+ def preparare_train_test_txt(data_dir, test_patient_ratio=0.2, seed=1):
38
+ """
39
+ From a list of patients, split them into train and test sets and export the lists to .txt files
40
+ """
41
+
42
+ # split based on seed, write to txt files
43
+ patients = os.listdir(data_dir)
44
+ patients.remove("HCC-TACE-Seg_clinical_data-V2.xlsx")
45
+ patients = list(set(patients))
46
+
47
+ # remove one patient with wrong labels
48
+ try:
49
+ patients.remove("HCC_017")
50
+ print("The patient HCC_017 is removed due to label issues including necrosis.")
51
+ except Exception as e:
52
+ pass
53
+
54
+ print("Total patients:", len(patients))
55
+ patients_train, patients_test = train_test_split(patients, test_size=test_patient_ratio, random_state=seed)
56
+ print(" There are", len(patients_train), "patients in training")
57
+ print(" There are", len(patients_test), "patients in test")
58
+
59
+ # export a copy
60
+ if not os.path.exists('train-test-split-seed' + str(seed)):
61
+ os.makedirs('train-test-split-seed' + str(seed))
62
+ with open(r'train-test-split-seed' + str(seed) + '/train.txt', 'w') as f:
63
+ f.write(','.join(patient for patient in patients_train))
64
+ with open(r'train-test-split-seed' + str(seed) + '/test.txt', 'w') as f:
65
+ f.write(','.join(patient for patient in patients_test))
66
+
67
+ print("Files saved to", 'train-test-split-seed' + str(seed) + '/train.txt and train-test-split-seed' + str(seed) + '/test.txt')
68
+ return
69
+
70
+
71
+
72
+
73
+ def extract_file_path(patient_id, data_folder):
74
+ """
75
+ Given one patient's ID, obtain the file path of the image and mask data.
76
+ If patient has multiple images, they are labeled as pre1, pre2, etc.
77
+ """
78
+ path = os.path.join(data_folder, patient_id)
79
+ files = os.listdir(path)
80
+ patient_files = {}
81
+ count = 1
82
+ for file in files:
83
+ if "seg" in file or "Segmentation" in file:
84
+ patient_files["mask"] = os.path.join(path, file)
85
+ else:
86
+ patient_files["pre_" + str(count)] = os.path.join(path, file)
87
+ count += 1
88
+ return patient_files
89
+
90
+
91
+
92
+ def get_patient_dictionaries(txt_file, data_dir):
93
+ """
94
+ From the .txt file that stores the list of patients, look through the data folders and build a list of patient data dictionaries
95
+ """
96
+ assert os.path.isfile(txt_file), "The file " + txt_file + " was not found. Please check your file directory."
97
+
98
+ file = open(txt_file, "r")
99
+ patients = file.read().split(',')
100
+
101
+ data_dict = []
102
+
103
+ for patient_id in patients:
104
+
105
+ # get directories for mask and images
106
+ patient_files = extract_file_path(patient_id, data_dir)
107
+
108
+ # pair up each image with the mask
109
+ for key, value in patient_files.items():
110
+ if key != "mask":
111
+ data_dict.append(
112
+ {
113
+ "patient_id": patient_id,
114
+ "image": patient_files[key],
115
+ "mask": patient_files["mask"]
116
+ }
117
+ )
118
+
119
+ print(" There are", len(data_dict), "image-masks in this dataset.")
120
+ return data_dict
121
+
122
+
123
+
124
+
125
+ def build_dataset(config, get_clinical=False):
126
+
127
+ def custom_collate_fn(batch):
128
+ """
129
+ Custom collate function to stack samples along the first dimension.
130
+
131
+ Args:
132
+ batch (list): List of dictionaries with keys "image" and "mask",
133
+ where values are tensors of shape (N, 1, 512, 512).
134
+
135
+ Returns:
136
+ tuple: Tuple containing two tensors:
137
+ - Stacked images of shape (B, 1, 512, 512)
138
+ - Stacked masks of shape (B, 1, 512, 512)
139
+ where B is the total number of samples in the batch.
140
+ """
141
+ # torch.manual_seed(1)
142
+ num_samples_to_select = config['BATCH_SIZE']
143
+
144
+ # Extract images and masks from the batch
145
+ images, masks = [], []
146
+ for sample in batch:
147
+ num_samples = min(sample["image"].shape[0], sample["mask"].shape[0])
148
+ random_indices = torch.randperm(num_samples)[:num_samples_to_select]
149
+ if "3D" in config['MODEL_NAME']: # 3D image
150
+ images.append(sample["image"][:,:512,:512,:]) # ensure image and mask same size
151
+ masks.append(sample["mask"][:,:512,:512,:])
152
+ else:
153
+ images.append(sample["image"][random_indices,:,:512,:512]) # ensure image and mask same size
154
+ masks.append(sample["mask"][random_indices,:,:512,:512])
155
+ #images.append(sample["image"][:,:,:512,:512]) # ensure image and mask same size
156
+ #masks.append(sample["mask"][:,:,:512,:512])
157
+
158
+ # Stack images and masks along the first dimension
159
+ try:
160
+ if "3D" not in config['MODEL_NAME']: # 3D image
161
+ concatenated_images = torch.cat(images, dim=0)
162
+ concatenated_masks = torch.cat(masks, dim=0)
163
+ else:
164
+ concatenated_images = torch.stack(images, dim=0)
165
+ concatenated_masks = torch.stack(masks, dim=0)
166
+ except Exception as e:
167
+ print("WARNING: not all images/masks are 512 by 512. Please check. ", images[0].shape, images[1].shape, masks[0].shape, masks[1].shape)
168
+ return None, None
169
+
170
+ # Return stacked images and masks as tensors
171
+ return {"image": concatenated_images, "mask": concatenated_masks}
172
+
173
+ # get list of training and test patient files
174
+ train_data_dict = get_patient_dictionaries(config['TRAIN_PATIENTS_FILE'], config['DATA_DIR'])
175
+ test_data_dict = get_patient_dictionaries(config['TEST_PATIENTS_FILE'], config['DATA_DIR'])
176
+ if config['ONESAMPLETESTRUN']: train_data_dict = train_data_dict[:2]
177
+ ttrain_data_dict, valid_data_dict = train_test_split(train_data_dict, test_size=config['VALID_PATIENT_RATIO'], shuffle=False, random_state=1) # shuffle must be False so the split lines up with the clinical data
178
+ print(" Training patients:", len(ttrain_data_dict), " Validation patients:", len(valid_data_dict))
179
+ print(" Test patients:", len(test_data_dict))
180
+
181
+ # define data transformations
182
+ preprocessing_transforms_train, preprocessing_transforms_test, postprocessing_transforms = define_transforms(config)
183
+
184
+ # create data loaders
185
+ train_ds = Dataset(ttrain_data_dict, transform=preprocessing_transforms_train)
186
+ valid_ds = Dataset(valid_data_dict, transform=preprocessing_transforms_test)
187
+ test_ds = Dataset(test_data_dict, transform=preprocessing_transforms_test)
188
+
189
+ if "3D" in config['MODEL_NAME']:
190
+ train_loader = DataLoader(train_ds, batch_size=config['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
191
+ valid_loader = DataLoader(valid_ds, batch_size=config['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
192
+ test_loader = DataLoader(test_ds, batch_size=config['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
193
+ else:
194
+ train_loader = DataLoader(train_ds, batch_size=1, shuffle=False, collate_fn=custom_collate_fn, num_workers=config['NUM_WORKERS']) #, pin_memory=torch.cuda.is_available())
195
+ valid_loader = DataLoader(valid_ds, batch_size=1, shuffle=False, collate_fn=custom_collate_fn, num_workers=config['NUM_WORKERS']) #, pin_memory=torch.cuda.is_available())
196
+ test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=custom_collate_fn, num_workers=config['NUM_WORKERS']) #, pin_memory=torch.cuda.is_available())
197
+
198
+ # get clinical data
199
+ df_clinical_train = pd.DataFrame()
200
+ if get_clinical:
201
+ # define transforms
202
+ simple_transforms = define_transforms_loadonly()
203
+ simple_train_ds = Dataset(train_data_dict, transform=simple_transforms)
204
+ simple_train_loader = DataLoader(simple_train_ds, batch_size=config['BATCH_SIZE'], collate_fn=list_data_collate, shuffle=False, num_workers=config['NUM_WORKERS']) #, pin_memory=torch.cuda.is_available())
205
+
206
+ # compute tumor ratio within liver
207
+ df_clinical_train['patient_id'] = [p["patient_id"] for p in train_data_dict]
208
+ ratios_train, ratios_test = [], []
209
+ for batch_data in simple_train_loader:
210
+ labels = batch_data["mask"]
211
+ ratio = torch.sum(labels == 2, dim=(1, 2, 3, 4)) / torch.sum(labels > 0, dim=(1, 2, 3, 4))
212
+ ratios_train.append(ratio.cpu().numpy()[0]) # [metatensor()]
213
+ df_clinical_train['tumor_ratio'] = ratios_train
214
+
215
+ # get clinical features
216
+ info = prepare_clinical_data(config['CLINICAL_DATA_FILE'], config['CLINICAL_PREDICTORS'])
217
+ df_clinical_train = pd.merge(df_clinical_train, info, on='patient_id', how="left")
218
+ df_clinical_train.fillna(df_clinical_train.median(), inplace=True)
219
+ df_clinical_train.set_index("patient_id", inplace=True)
220
+
221
+ # visualize the data loader for one image to ensure correct formatting
222
+ print("Example data transformations:")
223
+ while True:
224
+ sample = preprocessing_transforms_train(train_data_dict[0])
225
+ if isinstance(sample, list): # depending on preprocessing, one sample may be [sample] or sample
226
+ sample = sample[0]
227
+ if torch.sum(sample['mask'][-1]) == 0: continue
228
+ print(f" image shape: {sample['image'].shape}")
229
+ print(f" mask shape: {sample['mask'].shape}")
230
+ print(f" mask values: {np.unique(sample['mask'])}")
231
+ #print(f" image affine:\n{sample['image'].meta['affine']}")
232
+ print(f" image min max: {np.min(sample['image']), np.max(sample['image'])}")
233
+ visualize_patient(sample['image'], sample['mask'], n_slices=3, z_dim_last="3D" in config['MODEL_NAME'], mask_channel=-1)
234
+ break
235
+
236
+ temp = monai.utils.first(test_loader)
237
+ print("Test loader shapes:", temp['image'].shape, temp['mask'].shape)
238
+
239
+ return train_loader, valid_loader, test_loader, postprocessing_transforms, df_clinical_train
240
+
241
+
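A minimal usage sketch for this module (not part of the uploaded files): the config keys below are the ones read by build_dataset and define_transforms, but the concrete values and paths are assumptions.

    from data_preparation import preparare_train_test_txt, build_dataset

    config = {
        'MODEL_NAME': 'unet2d',                  # no "3D" in the name -> 2D slice loaders
        'DATA_DIR': 'data/HCC-TACE-Seg',         # assumed dataset location
        'TRAIN_PATIENTS_FILE': 'train-test-split-seed1/train.txt',
        'TEST_PATIENTS_FILE': 'train-test-split-seed1/test.txt',
        'VALID_PATIENT_RATIO': 0.2,
        'ONESAMPLETESTRUN': False,
        'BATCH_SIZE': 8,
        'NUM_WORKERS': 2,
        # keys consumed by define_transforms(config); values here are placeholders
        'KEEP_CLASSES': ['normal', 'liver', 'tumor'],
        'HU_RANGE': (-100, 200),
        'MASKNONLIVER': False,
        'PREPROCESSING': 'none',
        'DATA_AUGMENTATION': True,
        'ROI_SIZE': (256, 256),
        # clinical options, only used when get_clinical=True
        'CLINICAL_DATA_FILE': 'data/HCC-TACE-Seg/HCC-TACE-Seg_clinical_data-V2.xlsx',
        'CLINICAL_PREDICTORS': ['CPS', 'CLIP_Score', 'Okuda', 'TNM', 'BCLC'],
    }

    # write the train/test patient lists, then build the MONAI loaders
    preparare_train_test_txt(config['DATA_DIR'], test_patient_ratio=0.2, seed=1)
    train_loader, valid_loader, test_loader, post_tf, df_clinical_train = build_dataset(config, get_clinical=False)

The notebook included in this commit presumably sets the values actually used for training; the ones above only illustrate the expected dictionary layout.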
utils/data_transforms.py ADDED
@@ -0,0 +1,267 @@
1
+ import monai
2
+ import cv2
3
+ from monai.transforms import MapTransform
4
+ import math
5
+ import numpy as np
6
+ import torch
7
+ import morphsnakes as ms
+ import nrrd
10
+ import torchvision.transforms as transforms
11
+ from monai.transforms import (
12
+ Activations, AsDiscreteD, AsDiscrete, Compose, CastToTypeD, RandSpatialCropd,
13
+ ToTensorD, CropForegroundD, Resized, GaussianSmoothD,
14
+ LoadImageD, TransposeD, OrientationD, ScaleIntensityRangeD,
15
+ RandAffineD, ResizeWithPadOrCropd, ToTensor,
16
+ FillHoles, KeepLargestConnectedComponent, HistogramNormalizeD, NormalizeIntensityD
17
+ )
18
+
19
+
20
+
21
+ def define_transforms_loadonly():
22
+ transformations = Compose([
23
+ LoadImageD(keys=["mask"], reader="NrrdReader", ensure_channel_first=True),
24
+ ConvertMaskValues(keys=["mask"], keep_classes=["liver", "tumor"]),
25
+ ToTensor()
26
+ ])
27
+ return transformations
28
+
29
+
30
+ def define_post_processing(config):
31
+ # Post-processing transforms
32
+ post_processing = [
33
+ # Apply softmax activation to convert logits to probabilities
34
+ Activations(sigmoid=True),
35
+ # Convert predicted probabilities to discrete values (0 or 1)
36
+ AsDiscrete(argmax=True, to_onehot=None if len(config['KEEP_CLASSES']) <= 2 else len(config['KEEP_CLASSES'])),
37
+ # Remove small connected components for 1=liver and 2=tumor
38
+ KeepLargestConnectedComponent(applied_labels=[1]),
39
+ # Fill holes in the binary mask for 1=liver and 2=tumor
40
+ FillHoles(applied_labels=[1]),
41
+ ToTensor()
42
+ ]
43
+
44
+ return Compose(post_processing)
45
+
46
+ def define_transforms(config):
47
+
48
+ transformations_test = [
49
+ LoadImageD(keys=["image", "mask"], reader="NrrdReader", ensure_channel_first=True),
50
+ # Orient up and down
51
+ OrientationD(keys=["image", "mask"], axcodes="PLI"),
52
+ ToTensorD(keys=["image", "mask"])
53
+ # histogram equalization or normalization
54
+ # HistogramNormalizeD(keys=["image"], num_bins=256, min=0, max=1),
55
+ # Intensity normalization
56
+ # NormalizeIntensityD(keys=["image"]),
57
+ #CastToTypeD(keys=["image"], dtype=torch.float32),
58
+ #CastToTypeD(keys=["mask"], dtype=torch.int32),
59
+ ]
60
+
61
+ if config['MASKNONLIVER']:
62
+ transformations_test.extend(
63
+ [
64
+ MaskOutNonliver(mask_key="mask"),
65
+ CropForegroundD(keys=["image", "mask"], source_key="image", allow_smaller=True),
66
+ ]
67
+ )
68
+
69
+ transformations_test.append(
70
+ # Windowing based on liver parameters
71
+ ScaleIntensityRangeD(keys=["image"],
72
+ a_min=config['HU_RANGE'][0],
73
+ a_max=config['HU_RANGE'][1],
74
+ b_min=0.0, b_max=1.0, clip=True
75
+ )
76
+ )
77
+
78
+ if config['PREPROCESSING'] == "clihe":
79
+ transformations_test.append(CLIHE(keys=["image"]))
80
+
81
+ elif config['PREPROCESSING'] == "gaussian":
82
+ transformations_test.append(GaussianSmoothD(keys=["image"], sigma=0.5))
83
+
84
+ # convert labels to 0,1,2 instead of 0,1,2,3,4
85
+ transformations_test.append(ConvertMaskValues(keys=["mask"], keep_classes=config['KEEP_CLASSES']))
86
+
87
+ if len(config['KEEP_CLASSES']) > 2: # NEEDED FOR MULTICLASS https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_brats21_segmentation_3d.ipynb
88
+ transformations_test.append(AsDiscreteD(keys=["mask"], to_onehot=len(config['KEEP_CLASSES']))) # (N, C, H, W) 2d; (1, C, H, W, Z)
89
+
90
+ if "3D" not in config['MODEL_NAME']:
91
+ transformations_test.append(TransposeD(keys=["image", "mask"], indices=(3,0,1,2)))
92
+
93
+ # training transforms include data augmentation
94
+ transformations_train = transformations_test.copy()
95
+ if config['MASKNONLIVER']: transformations_test = transformations_test[:4] + transformations_test[5:] # for the test set, do not crop to the liver foreground
96
+
97
+ if config['DATA_AUGMENTATION']:
98
+ if "3D" in config["MODEL_NAME"]:
99
+ transformations_train.append(
100
+ RandAffineD(keys=["image", "mask"], prob=0.2, padding_mode="border",
101
+ mode="bilinear", spatial_size=config['ROI_SIZE'],
102
+ rotate_range=(0.15,0.15,0.15), #translate_range=(30,30,30),
103
+ scale_range=(0.1,0.1,0.1)))
104
+ else:
105
+ transformations_train.append(
106
+ RandAffineD(keys=["image", "mask"], prob=0.2, padding_mode="border",
107
+ mode="bilinear", #spatial_size=(512, 512),
108
+ rotate_range=(0.15,0.15), #translate_range=(30,30),
109
+ scale_range=(0.1,0.1)))
110
+
111
+ transformations_train.extend(
112
+ [
113
+ RandSpatialCropd(keys=["image", "mask"], roi_size=config['ROI_SIZE'], random_size=False),
114
+ ResizeWithPadOrCropd(keys=["image", "mask"], spatial_size=config['ROI_SIZE'], method="end", mode='constant', value=0)
115
+ ]
116
+ )
117
+
118
+ postprocessing_transforms = define_post_processing(config)
119
+ preprocessing_transforms_test = Compose(transformations_test)
120
+ preprocessing_transforms_train = Compose(transformations_train)
121
+ preprocessing_transforms_train.set_random_state(seed=1)
122
+ preprocessing_transforms_test.set_random_state(seed=1)
123
+
124
+ return preprocessing_transforms_train, preprocessing_transforms_test, postprocessing_transforms
125
+
126
+
127
+
128
+ class CLIHE(MapTransform):
129
+ def __init__(self, keys, allow_missing_keys=False):
130
+ super().__init__(allow_missing_keys)
131
+ self.keys = keys
132
+
133
+ def __call__(self, data):
134
+ for key in self.keys:
135
+ if len(data['image'].shape) > 3: # 3D image
136
+ data[key] = self.apply_clahe_3d(data[key]) # [B, 1, H, W, Z]
137
+ else:
138
+ data[key] = self.apply_clahe_2d(data[key]) # [B, 1, H, W, Z]
139
+ return data
140
+
141
+ def apply_clahe_3d(self, image):
142
+ image = np.asarray(image)
143
+ clahe_slices = []
144
+ for slice_idx in range(image.shape[-1]):
145
+ # Extract the current slice
146
+ slice_2d = image[0, :, :, slice_idx]
147
+
148
+ # Apply CLAHE to the current slice
149
+ # slice_2d = cv2.medianBlur(slice_2d, 5)
150
+ # slice_2d = cv2.anisotropicDiffusion(slice_2d, alpha=0.1, K=1, iterations=50)
151
+ # slice_2d = anisotropic_diffusion(slice_2d)
152
+ # slice_2d = cv2.Sobel(slice_2d, cv2.CV_64F, dx=1, dy=1, ksize=5)
153
+ clahe = cv2.createCLAHE(clipLimit=1, tileGridSize=(16,16))
154
+ slice_2d = clahe.apply(slice_2d.astype(np.uint8))
155
+ #cv2.threshold(clahe_slice, 155, 255, cv2.THRESH_BINARY)
156
+ kernel = np.ones((2,2), np.float32)/4
157
+ slice_2d = cv2.filter2D(slice_2d, -1, kernel)
158
+ #t = anisodiff2D(delta_t=0.2,kappa=50)
159
+ #slice_2d = t.fit(slice_2d)
160
+
161
+ # Append the CLAHE enhanced slice to the list
162
+ clahe_slices.append(slice_2d)
163
+
164
+ # Stack the CLAHE enhanced slices along the slice axis to form the 3D image
165
+ clahe_image = np.stack(clahe_slices, axis=-1)
166
+
167
+ return torch.from_numpy(clahe_image[None,:])
168
+
169
+ def apply_clahe_2d(self, image):
170
+ image = np.asarray(image)
171
+
172
+ clahe = cv2.createCLAHE(clipLimit=5)
173
+ clahe_slice = clahe.apply(image[0].astype(np.uint8))
174
+
175
+ return torch.from_numpy(clahe_slice)
176
+
177
+
178
+
179
+ class GaussianFilter(MapTransform):
180
+ def __init__(self, keys, allow_missing_keys=False):
181
+ super().__init__(allow_missing_keys)
182
+ self.keys = keys
183
+
184
+ def __call__(self, data):
185
+ for key in self.keys:
186
+ if len(data['image'].shape) > 3: # 3D image
187
+ data[key] = self.apply_clahe_3d(data[key]) # [B, 1, H, W, Z]
188
+ else:
189
+ data[key] = self.apply_clahe_2d(data[key]) # [B, 1, H, W, Z]
190
+ return data
191
+
192
+ def apply_clahe_3d(self, image):
193
+ image = np.asarray(image)
194
+ clahe_slices = []
195
+ for slice_idx in range(image.shape[-1]):
196
+ # Extract the current slice
197
+ slice_2d = image[0, :, :, slice_idx]
198
+
199
+ # Apply a 3x3 mean (box) filter to the current slice
+ kernel = np.ones((3,3), np.float32)/9
+ slice_2d = cv2.filter2D(slice_2d, -1, kernel)
+
+ # Append the smoothed slice to the list
+ clahe_slices.append(slice_2d)
+
+ # Stack the smoothed slices along the slice axis to form the 3D image
+ clahe_image = np.stack(clahe_slices, axis=-1)
208
+
209
+ return torch.from_numpy(clahe_image[None,:])
210
+
211
+ def apply_clahe_2d(self, image):
212
+ image = np.asarray(image)
213
+
214
+ kernel = np.ones((3,3), np.float32)/9
215
+ slice_2d = cv2.filter2D(image, -1, kernel)
216
+
217
+ return torch.from_numpy(slice_2d)
218
+
219
+
220
+ class Morphsnakes(MapTransform):
221
+ # https://github.com/pmneila/morphsnakes/blob/master/morphsnakes.py
222
+ def __init__(self, allow_missing_keys=False):
223
+ super().__init__(allow_missing_keys)
224
+
225
+ def __call__(self, data):
226
+ if np.sum(data['mask'][-1]) > 0:
227
+ res = ms.morphological_chan_vese(data['image'][0], iterations=2, init_level_set=data['mask'][-1])
228
+ data['mask'] = res
229
+ return data
230
+
231
+
232
+ class MaskOutNonliver(MapTransform):
233
+ def __init__(self, allow_missing_keys=False, mask_key="mask"):
234
+ super().__init__(allow_missing_keys)
235
+ self.mask_key = mask_key
236
+
237
+ def __call__(self, data):
238
+ # mask out non-liver regions of an image
+ # liver-related labels (1=liver, 2=tumor, 3=portal vein) are kept; everything else is set to air (-1000 HU)
240
+ if data[self.mask_key].shape != data['image'].shape:
241
+ return data
242
+ data['image'][data[self.mask_key] >= 4] = -1000
243
+ data['image'][data[self.mask_key] <= 0] = -1000
244
+ return data
245
+
246
+
247
+ class ConvertMaskValues(MapTransform):
248
+ def __init__(self, keys, allow_missing_keys=False, keep_classes=["normal", "liver", "tumor"]):
249
+ super().__init__(keys, allow_missing_keys)
250
+ self.keep_classes = keep_classes
251
+
252
+ def __call__(self, data):
253
+ # original labels: 0 for normal region, 1 for liver, 2 for tumor mass, 3 for portal vein, and 4 for abdominal aorta.
254
+ # converted labels: 0 for normal region and abdominal aorta, 1 for liver and portal vein, 2 for tumor mass
255
+
256
+ for key in self.keys:
257
+ data[key][data[key] > 4] = 4 # one patient had class label = 5, converted to 4
258
+ if key in data:
259
+ if "liver" not in self.keep_classes:
260
+ data[key][data[key] == 1] = 0
261
+ if "tumor" not in self.keep_classes:
262
+ data[key][data[key] == 2] = 1
263
+ if "portal vein" not in self.keep_classes:
264
+ data[key][data[key] == 3] = 1
265
+ if "abdominal aorta" not in self.keep_classes:
266
+ data[key][data[key] >= 4] = 0
267
+ return data
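As a small, hedged illustration of the load-only pipeline defined above (the patient ID and file path are hypothetical):

    from data_transforms import define_transforms_loadonly

    sample = {
        "patient_id": "HCC_001",                                  # hypothetical ID
        "mask": "data/HCC-TACE-Seg/HCC_001/Segmentation.nrrd",    # hypothetical path
    }

    load_only = define_transforms_loadonly()   # LoadImageD -> ConvertMaskValues -> ToTensor
    out = load_only(sample)
    # out["mask"] is a channel-first tensor whose labels have been remapped to 0/1/2
    print(out["mask"].shape, out["mask"].unique())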
utils/inference.py ADDED
@@ -0,0 +1,155 @@
1
+ import numpy as np
2
+ import torch
3
+ from monai.transforms import (
4
+ Activations, AsDiscreteD, AsDiscrete, Compose, ToTensorD,
5
+ GaussianSmoothD, LoadImageD, TransposeD, OrientationD, ScaleIntensityRangeD,
6
+ ToTensor, FillHoles, KeepLargestConnectedComponent, NormalizeIntensityD
7
+ )
8
+ from nrrd import read
9
+ from visualization import visualize_results
10
+ from data_preparation import get_patient_dictionaries
11
+ from monai.data import Dataset, DataLoader
12
+ import os
13
+ from data_transforms import ConvertMaskValues, MaskOutNonliver
14
+ from pipeline import build_model, evaluate
15
+
16
+ def run_sequential_inference(txt_file, config_liver, config_tumor, eval_metrics, output_dir, only_tumor=False, export=True):
17
+
18
+ def custom_collate_fn(batch):
19
+ num_samples_to_select = config_liver['BATCH_SIZE']
20
+
21
+ # Extract images and masks from the batch, ensure image and mask same size
22
+ images, masks, pred_liver = [], [], []
23
+ for sample in batch:
24
+ num_samples = min(sample["image"].shape[0], sample["mask"].shape[0])
25
+ random_indices = torch.randperm(num_samples)[:num_samples_to_select]
26
+ images.append(sample["image"][:,:512,:512,:])
27
+ masks.append(sample["mask"][:,:512,:512,:])
28
+
29
+ # Stack images and masks along the first dimension
30
+ try:
31
+ concatenated_images = torch.stack(images, dim=0)
32
+ concatenated_masks = torch.stack(masks, dim=0)
33
+ except Exception as e:
34
+ print("WARNING: not all images/masks are 512 by 512. Please check. ", images[0].shape, images[1].shape, masks[0].shape, masks[1].shape)
35
+ return None, None
36
+
37
+ # Return stacked images and masks as tensors
38
+ if "pred_liver" in sample.keys():
39
+ return {"image": concatenated_images, "mask": concatenated_masks, "pred_liver": sample["pred_liver"]}
40
+ else:
41
+ return {"image": concatenated_images, "mask": concatenated_masks}
42
+
43
+ ### Model preparation
44
+ print("")
45
+ print("Loading models....")
46
+ liver_model = build_model(config_liver)
47
+ tumor_model = build_model(config_tumor)
48
+
49
+ #### Data preparation
50
+ print("")
51
+ print("Loading test data....")
52
+ test_data_dict = get_patient_dictionaries(txt_file=txt_file, data_dir=config_liver['DATA_DIR'])
53
+ print(" Number of test patients:", len(test_data_dict))
54
+
55
+ # assign output file names and paths
56
+ export_file_metadata = []
57
+ if not os.path.exists(output_dir): os.makedirs(output_dir)
58
+ for patient_dict in test_data_dict:
59
+ patient_folder = os.path.join(output_dir, patient_dict['patient_id'])
60
+ if not os.path.exists(patient_folder): os.makedirs(patient_folder)
61
+ patient_dict['pred_liver'] = os.path.join(patient_folder, "liver_segmentation.nrrd")
62
+ patient_dict['pred_tumor'] = os.path.join(patient_folder, "tumor_segmentation.nrrd")
63
+ export_file_metadata.append(read(patient_dict['image'])[1])
64
+
65
+ #### Liver segmentation
66
+ # define liver data loading and preprocessing
67
+ if not only_tumor:
68
+ print("")
69
+ print("Producing liver segmentations....")
70
+ liver_preprocessing = Compose([
71
+ LoadImageD(keys=["image", "mask"], reader="NrrdReader", ensure_channel_first=True),
72
+ OrientationD(keys=["image", "mask"], axcodes="PLI"),
73
+ ScaleIntensityRangeD(keys=["image"],
74
+ a_min=config_liver['HU_RANGE'][0],
75
+ a_max=config_liver['HU_RANGE'][1],
76
+ b_min=0.0, b_max=1.0, clip=True
77
+ ),
78
+ ConvertMaskValues(keys=["mask"], keep_classes=["liver"]),
79
+ ToTensorD(keys=["image", "mask"])
80
+ ])
81
+
82
+ liver_postprocessing = Compose([
83
+ Activations(sigmoid=True),
84
+ AsDiscrete(argmax=True, to_onehot=None),
85
+ KeepLargestConnectedComponent(applied_labels=[1]),
86
+ FillHoles(applied_labels=[1]),
87
+ ToTensor()
88
+ ])
89
+ test_ds_liver = Dataset(test_data_dict, transform=liver_preprocessing)
90
+ test_ds_liver = DataLoader(test_ds_liver, batch_size=config_liver['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config_liver['NUM_WORKERS'])
91
+
92
+ # produce liver model results
93
+ test_metrics_liver, sample_output_liver = evaluate(liver_model, test_ds_liver, eval_metrics, config_liver, postprocessing_transforms=liver_postprocessing, export_filenames = [p['pred_liver'] for p in test_data_dict], export_file_metadata=export_file_metadata)
94
+
95
+ print("")
96
+ print("==============================")
97
+ print("Liver segmentation test performance ....")
98
+ for key, value in test_metrics_liver.items():
99
+ print(f' {key.replace("_avg", "_liver")}: {value:.3f}')
100
+ print("==============================")
101
+
102
+ ##### Tumor segmentation
103
+ print("")
104
+ print("Producing tumor segmentations....")
105
+
106
+ # define tumor loading and preprocessing
107
+ tumor_preprocessing = Compose([
108
+ LoadImageD(keys=["image", "mask", "pred_liver"], reader="NrrdReader", ensure_channel_first=True),
109
+ OrientationD(keys=["image", "mask"], axcodes="PLI"),
110
+ MaskOutNonliver(mask_key="pred_liver"), # note that liver's predicted segmentation is used to crop to the liver region
111
+ ScaleIntensityRangeD(keys=["image"],
112
+ a_min=config_tumor['HU_RANGE'][0],
113
+ a_max=config_tumor['HU_RANGE'][1],
114
+ b_min=0.0, b_max=1.0, clip=True
115
+ ),
116
+ ConvertMaskValues(keys=["mask"], keep_classes=["liver", "tumor"]), # format mask for measuring test performance
117
+ AsDiscreteD(keys=["mask"], to_onehot=3), # format mask for measuring test performance
118
+ ToTensorD(keys=["image", "mask", "pred_liver"])
119
+ ])
120
+
121
+ tumor_postprocessing = Compose([
122
+ Activations(sigmoid=True),
123
+ AsDiscrete(argmax=True, to_onehot=3),
124
+ ToTensor()
125
+ ])
126
+
127
+ test_ds_tumor = Dataset(test_data_dict, transform=tumor_preprocessing)
128
+ test_ds_tumor = DataLoader(test_ds_tumor, batch_size=config_tumor['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config_tumor['NUM_WORKERS'])
129
+
130
+ test_metrics_tumor, sample_output_tumor = evaluate(tumor_model, test_ds_tumor, eval_metrics, config_tumor, tumor_postprocessing, use_liver_seg = True, export_filenames = [p['pred_tumor'] for p in test_data_dict] if export else [], export_file_metadata=export_file_metadata)
131
+
132
+ print("")
133
+ print("==============================")
134
+ print("Tumor segmentation test performance ....")
135
+ for key, value in test_metrics_tumor.items():
136
+ if "class2" in key:
137
+ print(f' {key.replace("_class2", "_tumor")}: {value:.3f}')
138
+ print("==============================")
139
+ print("")
140
+
141
+ #### Visualization
142
+
143
+ # combine liver and tumor segmentations into one segmentation output
144
+ if not only_tumor: sample_output_tumor[2][0][1] = sample_output_liver[2][0][0]
145
+
146
+ # visualization
147
+ print("")
148
+ if not only_tumor:
149
+ visualize_results(sample_output_liver[0][0].cpu(), sample_output_tumor[1][0].cpu(), sample_output_tumor[2][0].cpu(), n_slices=5, title="")
150
+ else:
151
+ visualize_results(sample_output_tumor[0][0].cpu(), sample_output_tumor[1][0].cpu(), sample_output_tumor[2][0].cpu(), n_slices=5, title="")
152
+
153
+ return
154
+
155
+
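A rough driver sketch for the two-stage inference above. The metric dictionary layout is an assumption, since evaluate() lives in pipeline.py; config_liver and config_tumor stand for config dictionaries like the one sketched after data_preparation.py, one tuned for the liver model and one for the tumor model.

    from monai.metrics import DiceMetric
    from inference import run_sequential_inference

    eval_metrics = {"dice": DiceMetric(include_background=False, reduction="mean")}

    run_sequential_inference(
        txt_file="train-test-split-seed1/test.txt",
        config_liver=config_liver,      # liver-model settings (DATA_DIR, HU_RANGE, BATCH_SIZE, ...)
        config_tumor=config_tumor,      # tumor-model settings
        eval_metrics=eval_metrics,
        output_dir="predictions",       # per-patient liver/tumor .nrrd files are written here
        only_tumor=False,
        export=True,
    )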
utils/loss.py ADDED
@@ -0,0 +1,153 @@
1
+ import warnings
2
+ from collections.abc import Callable, Sequence
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn.modules.loss import _Loss
10
+
11
+ from monai.losses.dice import DiceLoss
12
+ from monai.losses.focal_loss import FocalLoss
13
+ from monai.networks import one_hot
14
+ from monai.utils import DiceCEReduction, LossReduction, Weight, deprecated_arg, look_up_option, pytorch_after
15
+
16
+
17
+
18
+ ##### Adapted from Monai DiceFocalLoss
19
+ class WeaklyDiceFocalLoss(_Loss):
20
+ """
21
+ Compute the Dice loss, the Focal loss, and a weakly supervised loss driven by a clinical target ratio, and return the weighted sum of these three losses.
22
+
23
+ ``gamma`` and ``lambda_focal`` are only used for the focal loss.
24
+ ``include_background``, ``weight`` and ``reduction`` are used for both losses
25
+ and other parameters are only used for dice loss.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ include_background: bool = True,
31
+ to_onehot_y: bool = False,
32
+ sigmoid: bool = False,
33
+ softmax: bool = False,
34
+ other_act: Callable | None = None,
35
+ squared_pred: bool = False,
36
+ jaccard: bool = False,
37
+ reduction: str = "mean",
38
+ smooth_nr: float = 1e-5,
39
+ smooth_dr: float = 1e-5,
40
+ batch: bool = False,
41
+ gamma: float = 2.0,
42
+ focal_weight: Sequence[float] | float | int | torch.Tensor | None = None,
43
+ weight: Sequence[float] | float | int | torch.Tensor | None = None,
44
+ lambda_dice: float = 1.0,
45
+ lambda_focal: float = 1.0,
46
+ lambda_weak: float = 1.0,
47
+ ) -> None:
48
+ """
49
+ Args:
50
+ include_background: if False channel index 0 (background category) is excluded from the calculation.
51
+ to_onehot_y: whether to convert the ``target`` into the one-hot format,
52
+ using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
53
+ sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
54
+ don't need to specify activation function for `FocalLoss`.
55
+ softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
56
+ don't need to specify activation function for `FocalLoss`.
57
+ other_act: callable function to execute other activation layers, Defaults to ``None``.
58
+ for example: `other_act = torch.tanh`. only used by the `DiceLoss`, not for `FocalLoss`.
59
+ squared_pred: use squared versions of targets and predictions in the denominator or not.
60
+ jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
61
+ reduction: {``"none"``, ``"mean"``, ``"sum"``}
62
+ Specifies the reduction to apply to the output. Defaults to ``"mean"``.
63
+
64
+ - ``"none"``: no reduction will be applied.
65
+ - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
66
+ - ``"sum"``: the output will be summed.
67
+
68
+ smooth_nr: a small constant added to the numerator to avoid zero.
69
+ smooth_dr: a small constant added to the denominator to avoid nan.
70
+ batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
71
+ Defaults to False, a Dice loss value is computed independently from each item in the batch
72
+ before any `reduction`.
73
+ gamma: value of the exponent gamma in the definition of the Focal loss.
74
+ weight: weights to apply to the voxels of each class. If None no weights are applied.
75
+ The input can be a single value (same weight for all classes), a sequence of values (the length
76
+ of the sequence should be the same as the number of classes).
77
+ lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
78
+ Defaults to 1.0.
79
+ lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.
80
+ Defaults to 1.0.
81
+ lambda_weak: the trade-off weight value for the weakly supervised loss. The value should be no less than 0.0.
+ Defaults to 1.0.
83
+
84
+ """
85
+ super().__init__()
86
+ weight = focal_weight if focal_weight is not None else weight
87
+ self.dice = DiceLoss(
88
+ include_background=include_background,
89
+ to_onehot_y=False,
90
+ sigmoid=sigmoid,
91
+ softmax=softmax,
92
+ other_act=other_act,
93
+ squared_pred=squared_pred,
94
+ jaccard=jaccard,
95
+ reduction=reduction,
96
+ smooth_nr=smooth_nr,
97
+ smooth_dr=smooth_dr,
98
+ batch=batch,
99
+ weight=weight,
100
+ )
101
+ self.focal = FocalLoss(
102
+ include_background=include_background, to_onehot_y=False, gamma=gamma, weight=weight, reduction=reduction
103
+ )
104
+ if lambda_dice < 0.0:
105
+ raise ValueError("lambda_dice should be no less than 0.0.")
106
+ if lambda_focal < 0.0:
107
+ raise ValueError("lambda_focal should be no less than 0.0.")
108
+ if lambda_weak < 0.0:
109
+ raise ValueError("lambda_weak should be no less than 0.0.")
110
+ self.lambda_dice = lambda_dice
111
+ self.lambda_focal = lambda_focal
112
+ self.to_onehot_y = to_onehot_y
113
+ self.lambda_weak = lambda_weak
114
+
115
+
116
+ def compute_weakly_supervised_loss(self, input: torch.Tensor, weaktarget: torch.Tensor) -> torch.Tensor:
117
+ # compute ratio of tumor/liver in the predicted mask
118
+ tumor_pixels = torch.sum(input[:, -1, ...], dim=(1, 2, 3))
119
+ liver_pixels = torch.sum(input[:, -2, ...], dim=(1, 2, 3)) + tumor_pixels
120
+ predicted_ratio = tumor_pixels / liver_pixels
121
+ loss = torch.mean((predicted_ratio - weaktarget) ** 2)
122
+ return loss
123
+
124
+
125
+
126
+ def forward(self, input: torch.Tensor, target: torch.Tensor, weaktarget: torch.Tensor) -> torch.Tensor:
127
+ """
128
+ Args:
129
+ input: the shape should be BNH[WD]. The input should be the original logits
130
+ due to the restriction of ``monai.losses.FocalLoss``.
131
+ target: the shape should be BNH[WD] or B1H[WD].
+ weaktarget: per-sample target tumor-to-liver ratio used by the weakly supervised term (shape B).
132
+
133
+ Raises:
134
+ ValueError: When number of dimensions for input and target are different.
135
+ ValueError: When number of channels for target is neither 1 nor the same as input.
136
+
137
+ """
138
+ if len(input.shape) != len(target.shape):
139
+ raise ValueError(
140
+ "the number of dimensions for input and target should be the same, "
141
+ f"got shape {input.shape} and {target.shape}."
142
+ )
143
+ if self.to_onehot_y:
144
+ n_pred_ch = input.shape[1]
145
+ if n_pred_ch == 1:
146
+ warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
147
+ else:
148
+ target = one_hot(target, num_classes=n_pred_ch)
149
+ dice_loss = self.dice(input, target)
150
+ focal_loss = self.focal(input, target)
151
+ weak_loss = self.compute_weakly_supervised_loss(input, weaktarget)
152
+ total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss + self.lambda_weak * weak_loss
153
+ return total_loss
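A minimal sketch of calling this loss (the shapes are assumptions; note that compute_weakly_supervised_loss sums over three spatial dimensions, so the tensors here are volumetric, B x C x H x W x D):

    import torch

    criterion = WeaklyDiceFocalLoss(softmax=True, lambda_dice=1.0, lambda_focal=1.0, lambda_weak=0.2)

    logits = torch.randn(2, 3, 32, 32, 8, requires_grad=True)   # background / liver / tumor channels
    target = (torch.rand(2, 3, 32, 32, 8) > 0.5).float()        # stand-in one-hot style mask
    weak_target = torch.tensor([0.10, 0.25])                    # target tumor-to-liver ratio per sample

    loss = criterion(logits, target, weak_target)
    loss.backward()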
utils/models.py ADDED
@@ -0,0 +1,670 @@
1
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # optional dependency, only needed when use_batchnorm="inplace"
+ try:
+ from inplace_abn import InPlaceABN
+ except ImportError:
+ InPlaceABN = None
5
+
6
+
7
+ # 2D: net = UNet2D(1,2,pab_channels=64,use_batchnorm=True)
8
+ # 3D: net = UNet3D(1,2,pab_channels=32,use_batchnorm=True)
9
+
10
+ class _NonLocalBlockND(nn.Module):
11
+ def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
12
+ super(_NonLocalBlockND, self).__init__()
13
+
14
+ assert dimension in [1, 2, 3]
15
+
16
+ self.dimension = dimension
17
+ self.sub_sample = sub_sample
18
+
19
+ self.in_channels = in_channels
20
+ self.inter_channels = inter_channels
21
+
22
+ if self.inter_channels is None:
23
+ self.inter_channels = in_channels // 2
24
+ if self.inter_channels == 0:
25
+ self.inter_channels = 1
26
+
27
+ if dimension == 3:
28
+ conv_nd = nn.Conv3d
29
+ max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
30
+ bn = nn.BatchNorm3d
31
+ elif dimension == 2:
32
+ conv_nd = nn.Conv2d
33
+ max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
34
+ bn = nn.BatchNorm2d
35
+ else:
36
+ conv_nd = nn.Conv1d
37
+ max_pool_layer = nn.MaxPool1d(kernel_size=(2))
38
+ bn = nn.BatchNorm1d
39
+
40
+ self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
41
+ kernel_size=1, stride=1, padding=0)
42
+
43
+ if bn_layer:
44
+ self.W = nn.Sequential(
45
+ conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
46
+ kernel_size=1, stride=1, padding=0),
47
+ bn(self.in_channels)
48
+ )
49
+ nn.init.constant_(self.W[1].weight, 0)
50
+ nn.init.constant_(self.W[1].bias, 0)
51
+ else:
52
+ self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
53
+ kernel_size=1, stride=1, padding=0)
54
+ nn.init.constant_(self.W.weight, 0)
55
+ nn.init.constant_(self.W.bias, 0)
56
+
57
+ self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
58
+ kernel_size=1, stride=1, padding=0)
59
+
60
+ self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
61
+ kernel_size=1, stride=1, padding=0)
62
+
63
+ if sub_sample:
64
+ self.g = nn.Sequential(self.g, max_pool_layer)
65
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
66
+
67
+ def forward(self, x):
68
+ '''
69
+ :param x: (b, c, t, h, w)
70
+ :return:
71
+ '''
72
+
73
+ batch_size = x.size(0)
74
+
75
+ g_x = self.g(x).view(batch_size, self.inter_channels, -1)
76
+ g_x = g_x.permute(0, 2, 1)
77
+
78
+ theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
79
+ theta_x = theta_x.permute(0, 2, 1)
80
+ phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
81
+ f = torch.matmul(theta_x, phi_x)
82
+ N = f.size(-1)
83
+ f_div_C = f / N
84
+
85
+ y = torch.matmul(f_div_C, g_x)
86
+ y = y.permute(0, 2, 1).contiguous()
87
+ y = y.view(batch_size, self.inter_channels, *x.size()[2:])
88
+ W_y = self.W(y)
89
+ z = W_y + x
90
+
91
+ return z
92
+
93
+
94
+ class NONLocalBlock1D(_NonLocalBlockND):
95
+ def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
96
+ super(NONLocalBlock1D, self).__init__(in_channels,
97
+ inter_channels=inter_channels,
98
+ dimension=1, sub_sample=sub_sample,
99
+ bn_layer=bn_layer)
100
+
101
+
102
+ class NONLocalBlock2D(_NonLocalBlockND):
103
+ def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
104
+ super(NONLocalBlock2D, self).__init__(in_channels,
105
+ inter_channels=inter_channels,
106
+ dimension=2, sub_sample=sub_sample,
107
+ bn_layer=bn_layer)
108
+
109
+
110
+ class NONLocalBlock3D(_NonLocalBlockND):
111
+ def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
112
+ super(NONLocalBlock3D, self).__init__(in_channels,
113
+ inter_channels=inter_channels,
114
+ dimension=3, sub_sample=sub_sample,
115
+ bn_layer=bn_layer)
116
+
117
+
118
+
119
+ class Conv2dReLU(nn.Sequential):
120
+ def __init__(
121
+ self,
122
+ in_channels,
123
+ out_channels,
124
+ kernel_size,
125
+ padding=0,
126
+ stride=1,
127
+ use_batchnorm=True,
128
+ ):
129
+
130
+ if use_batchnorm == "inplace" and InPlaceABN is None:
131
+ raise RuntimeError(
132
+ "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
133
+ + "To install see: https://github.com/mapillary/inplace_abn"
134
+ )
135
+
136
+ conv = nn.Conv2d(
137
+ in_channels,
138
+ out_channels,
139
+ kernel_size,
140
+ stride=stride,
141
+ padding=padding,
142
+ bias=not (use_batchnorm),
143
+ )
144
+ relu = nn.ReLU(inplace=True)
145
+
146
+ if use_batchnorm == "inplace":
147
+ bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
148
+ relu = nn.Identity()
149
+
150
+ elif use_batchnorm and use_batchnorm != "inplace":
151
+ bn = nn.BatchNorm2d(out_channels)
152
+
153
+ else:
154
+ bn = nn.Identity()
155
+
156
+ super(Conv2dReLU, self).__init__(conv, bn, relu)
157
+
158
+ class Conv3dReLU(nn.Sequential):
159
+ def __init__(
160
+ self,
161
+ in_channels,
162
+ out_channels,
163
+ kernel_size,
164
+ padding=0,
165
+ stride=1,
166
+ use_batchnorm=True,
167
+ ):
168
+
169
+ if use_batchnorm == "inplace" and InPlaceABN is None:
170
+ raise RuntimeError(
171
+ "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
172
+ + "To install see: https://github.com/mapillary/inplace_abn"
173
+ )
174
+
175
+ conv = nn.Conv3d(
176
+ in_channels,
177
+ out_channels,
178
+ kernel_size,
179
+ stride=stride,
180
+ padding=padding,
181
+ bias=not (use_batchnorm),
182
+ )
183
+ relu = nn.ReLU(inplace=True)
184
+
185
+ if use_batchnorm == "inplace":
186
+ bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
187
+ relu = nn.Identity()
188
+
189
+ elif use_batchnorm and use_batchnorm != "inplace":
190
+ bn = nn.BatchNorm3d(out_channels)
191
+
192
+ else:
193
+ bn = nn.Identity()
194
+
195
+ super(Conv3dReLU, self).__init__(conv, bn, relu)
196
+ class PAB2D(nn.Module):
197
+ def __init__(self, in_channels, out_channels, pab_channels=64):
198
+ super(PAB2D, self).__init__()
199
+ # Series of 1x1 conv to generate attention feature maps
200
+ self.pab_channels = pab_channels
201
+ self.in_channels = in_channels
202
+ self.top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
203
+ self.center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
204
+ self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
205
+ self.map_softmax = nn.Softmax(dim=1)
206
+ self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
207
+
208
+ def forward(self, x):
209
+ bsize = x.size()[0]
210
+ h = x.size()[2]
211
+ w = x.size()[3]
212
+ x_top = self.top_conv(x)
213
+ x_center = self.center_conv(x)
214
+ x_bottom = self.bottom_conv(x)
215
+
216
+ x_top = x_top.flatten(2)
217
+ x_center = x_center.flatten(2).transpose(1, 2)
218
+ x_bottom = x_bottom.flatten(2).transpose(1, 2)
219
+
220
+ sp_map = torch.matmul(x_center, x_top)
221
+ sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h*w, h*w)
222
+ sp_map = torch.matmul(sp_map, x_bottom)
223
+ sp_map = sp_map.reshape(bsize, self.in_channels, h, w)
224
+ x = x + sp_map
225
+ x = self.out_conv(x)
226
+ # print('x_top',x_top.shape,'x_center',x_center.shape,'x_bottom',x_bottom.shape,'x',x.shape,'sp_map',sp_map.shape)
227
+ return x
228
+
229
+ class MFAB2D(nn.Module):
230
+ def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16):
231
+ # MFAB is just a modified version of SE-blocks, one for skip, one for input
232
+ super(MFAB2D, self).__init__()
233
+ self.hl_conv = nn.Sequential(
234
+ Conv2dReLU(
235
+ in_channels,
236
+ in_channels,
237
+ kernel_size=3,
238
+ padding=1,
239
+ use_batchnorm=use_batchnorm,
240
+ ),
241
+ Conv2dReLU(
242
+ in_channels,
243
+ skip_channels,
244
+ kernel_size=1,
245
+ use_batchnorm=use_batchnorm,
246
+ )
247
+ )
248
+ self.SE_ll = nn.Sequential(
249
+ nn.AdaptiveAvgPool2d(1),
250
+ nn.Conv2d(skip_channels, skip_channels // reduction, 1),
251
+ nn.ReLU(inplace=True),
252
+ nn.Conv2d(skip_channels // reduction, skip_channels, 1),
253
+ nn.Sigmoid(),
254
+ )
255
+ self.SE_hl = nn.Sequential(
256
+ nn.AdaptiveAvgPool2d(1),
257
+ nn.Conv2d(skip_channels, skip_channels // reduction, 1),
258
+ nn.ReLU(inplace=True),
259
+ nn.Conv2d(skip_channels // reduction, skip_channels, 1),
260
+ nn.Sigmoid(),
261
+ )
262
+ self.conv1 = Conv2dReLU(
263
+ skip_channels + skip_channels, # we transform C-prime form high level to C from skip connection
264
+ out_channels,
265
+ kernel_size=3,
266
+ padding=1,
267
+ use_batchnorm=use_batchnorm,
268
+ )
269
+ self.conv2 = Conv2dReLU(
270
+ out_channels,
271
+ out_channels,
272
+ kernel_size=3,
273
+ padding=1,
274
+ use_batchnorm=use_batchnorm,
275
+ )
276
+
277
+ def forward(self, x, skip=None):
278
+ x = self.hl_conv(x)
279
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
280
+ attention_hl = self.SE_hl(x)
281
+ if skip is not None:
282
+ attention_ll = self.SE_ll(skip)
283
+ attention_hl = attention_hl + attention_ll
284
+ x = x * attention_hl
285
+ x = torch.cat([x, skip], dim=1)
286
+ x = self.conv1(x)
287
+ x = self.conv2(x)
288
+ return x
289
+
290
+ class PAB3D(nn.Module):
291
+ def __init__(self, in_channels, out_channels, pab_channels=64):
292
+ super(PAB3D, self).__init__()
293
+ # Series of 1x1 conv to generate attention feature maps
294
+ self.pab_channels = pab_channels
295
+ self.in_channels = in_channels
296
+ self.top_conv = nn.Conv3d(in_channels, pab_channels, kernel_size=1)
297
+ self.center_conv = nn.Conv3d(in_channels, pab_channels, kernel_size=1)
298
+ self.bottom_conv = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
299
+ self.map_softmax = nn.Softmax(dim=1)
300
+ self.out_conv = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
301
+
302
+ def forward(self, x):
303
+ bsize = x.size()[0]
304
+ h = x.size()[2]
305
+ w = x.size()[3]
306
+ d = x.size()[4]
307
+ x_top = self.top_conv(x)
308
+ x_center = self.center_conv(x)
309
+ x_bottom = self.bottom_conv(x)
310
+
311
+ x_top = x_top.flatten(2)
312
+ x_center = x_center.flatten(2).transpose(1, 2)
313
+ x_bottom = x_bottom.flatten(2).transpose(1, 2)
314
+ sp_map = torch.matmul(x_center, x_top)
315
+ sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h*w*d, h*w*d)
316
+ sp_map = torch.matmul(sp_map, x_bottom)
317
+ sp_map = sp_map.reshape(bsize, self.in_channels, h, w, d)
318
+ x = x + sp_map
319
+ x = self.out_conv(x)
320
+ # print('x_top',x_top.shape,'x_center',x_center.shape,'x_bottom',x_bottom.shape,'x',x.shape,'sp_map',sp_map.shape)
321
+ return x
322
+
323
+ class MFAB3D(nn.Module):
324
+ def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16):
325
+ # MFAB is just a modified version of SE-blocks, one for skip, one for input
326
+ super(MFAB3D, self).__init__()
327
+ self.hl_conv = nn.Sequential(
328
+ Conv3dReLU(
329
+ in_channels,
330
+ in_channels,
331
+ kernel_size=3,
332
+ padding=1,
333
+ use_batchnorm=use_batchnorm,
334
+ ),
335
+ Conv3dReLU(
336
+ in_channels,
337
+ skip_channels,
338
+ kernel_size=1,
339
+ use_batchnorm=use_batchnorm,
340
+ )
341
+ )
342
+ self.SE_ll = nn.Sequential(
343
+ nn.AdaptiveAvgPool3d(1),
344
+ nn.Conv3d(skip_channels, skip_channels // reduction, 1),
345
+ nn.ReLU(inplace=True),
346
+ nn.Conv3d(skip_channels // reduction, skip_channels, 1),
347
+ nn.Sigmoid(),
348
+ )
349
+ self.SE_hl = nn.Sequential(
350
+ nn.AdaptiveAvgPool3d(1),
351
+ nn.Conv3d(skip_channels, skip_channels // reduction, 1),
352
+ nn.ReLU(inplace=True),
353
+ nn.Conv3d(skip_channels // reduction, skip_channels, 1),
354
+ nn.Sigmoid(),
355
+ )
356
+ self.conv1 = Conv3dReLU(
357
+ skip_channels + skip_channels, # we transform C-prime form high level to C from skip connection
358
+ out_channels,
359
+ kernel_size=3,
360
+ padding=1,
361
+ use_batchnorm=use_batchnorm,
362
+ )
363
+ self.conv2 = Conv3dReLU(
364
+ out_channels,
365
+ out_channels,
366
+ kernel_size=3,
367
+ padding=1,
368
+ use_batchnorm=use_batchnorm,
369
+ )
370
+
371
+ def forward(self, x, skip=None):
372
+ x = self.hl_conv(x)
373
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
374
+ attention_hl = self.SE_hl(x)
375
+ if skip is not None:
376
+ attention_ll = self.SE_ll(skip)
377
+ attention_hl = attention_hl + attention_ll
378
+ x = x * attention_hl
379
+ x = torch.cat([x, skip], dim=1)
380
+ x = self.conv1(x)
381
+ x = self.conv2(x)
382
+ return x
383
+
384
+ class DoubleConv2D(nn.Module):
385
+ """(convolution => [BN] => ReLU) * 2"""
386
+
387
+ def __init__(self, in_channels, out_channels, mid_channels=None):
388
+ super().__init__()
389
+ if not mid_channels:
390
+ mid_channels = out_channels
391
+ self.double_conv = nn.Sequential(
392
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
393
+ nn.BatchNorm2d(mid_channels),
394
+ nn.ReLU(inplace=True),
395
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
396
+ nn.BatchNorm2d(out_channels),
397
+ nn.ReLU(inplace=True)
398
+ )
399
+
400
+ def forward(self, x):
401
+ return self.double_conv(x)
402
+
403
+ class Down2D(nn.Module):
404
+ """Downscaling with maxpool then double conv"""
405
+
406
+ def __init__(self, in_channels, out_channels):
407
+ super().__init__()
408
+ self.maxpool_conv = nn.Sequential(
409
+ nn.MaxPool2d(2),
410
+ NONLocalBlock2D(in_channels),
411
+ DoubleConv2D(in_channels, out_channels)
412
+ )
413
+
414
+ def forward(self, x):
415
+ return self.maxpool_conv(x)
416
+
417
+
418
+ class Up2D(nn.Module):
419
+ """Upscaling then double conv"""
420
+
421
+ def __init__(self, in_channels, out_channels, bilinear=True):
422
+ super().__init__()
423
+
424
+ # if bilinear, use the normal convolutions to reduce the number of channels
425
+ if bilinear:
426
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
427
+ self.conv = DoubleConv2D(in_channels, out_channels, in_channels // 2)
428
+ else:
429
+ self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
430
+ self.conv = DoubleConv2D(in_channels, out_channels)
431
+
432
+ def forward(self, x1, x2):
433
+ x1 = self.up(x1)
434
+ # input is CHW
435
+ diffY = x2.size()[2] - x1.size()[2]
436
+ diffX = x2.size()[3] - x1.size()[3]
437
+
438
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
439
+ diffY // 2, diffY - diffY // 2])
440
+ # if you have padding issues, see
441
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
442
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
443
+ x = torch.cat([x2, x1], dim=1)
444
+ return self.conv(x)
445
+
446
+ class OutConv2D(nn.Module):
447
+ def __init__(self, in_channels, out_channels):
448
+ super(OutConv2D, self).__init__()
449
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
450
+
451
+ def forward(self, x):
452
+ return self.conv(x)
453
+
454
+ class UNet2D(nn.Module):
455
+ def __init__(self, n_channels, n_classes, bilinear=True, pab_channels=64, use_batchnorm=True, aux_classifier = False):
456
+ super(UNet2D, self).__init__()
457
+ self.n_channels = n_channels
458
+ self.n_classes = n_classes
459
+ self.bilinear = bilinear
460
+ self.inc = DoubleConv2D(n_channels, pab_channels)
461
+ self.down1 = Down2D(pab_channels, 2*pab_channels)
462
+ self.down2 = Down2D(2*pab_channels, 4*pab_channels)
463
+ self.down3 = Down2D(4*pab_channels, 8*pab_channels)
464
+ factor = 2 if bilinear else 1
465
+ self.down4 = Down2D(8*pab_channels, 16*pab_channels // factor)
466
+ self.pab = PAB2D(8*pab_channels,8*pab_channels)
467
+ self.up1 = Up2D(16*pab_channels, 8*pab_channels // factor, bilinear)
468
+ self.up2 = Up2D(8*pab_channels, 4*pab_channels // factor, bilinear)
469
+ self.up3 = Up2D(4*pab_channels, 2*pab_channels // factor, bilinear)
470
+ self.up4 = Up2D(2*pab_channels, pab_channels, bilinear)
471
+
472
+ self.mfab1 = MFAB2D(8*pab_channels,8*pab_channels,4*pab_channels,use_batchnorm)
473
+ self.mfab2 = MFAB2D(4*pab_channels,4*pab_channels,2*pab_channels,use_batchnorm)
474
+ self.mfab3 = MFAB2D(2*pab_channels,2*pab_channels,pab_channels,use_batchnorm)
475
+ self.mfab4 = MFAB2D(pab_channels,pab_channels,pab_channels,use_batchnorm)
476
+ self.outc = OutConv2D(pab_channels, n_classes)
477
+
478
+ if aux_classifier == False:
479
+ self.aux = None
480
+ else:
481
+ # customize the auxiliary classification loss
482
+ # self.aux = nn.Sequential(nn.AdaptiveAvgPool2d(1),
483
+ # nn.Flatten(),
484
+ # nn.Dropout(p=0.1, inplace=True),
485
+ # nn.Linear(8*pab_channels, 16, bias=True),
486
+ # nn.Dropout(p=0.1, inplace=True),
487
+ # nn.Linear(16, n_classes, bias=True),
488
+ # nn.Softmax(1))
489
+
490
+ self.aux = nn.Sequential(
491
+ NONLocalBlock2D(8*pab_channels),
492
+ nn.Conv2d(8*pab_channels,1,1),
493
+ nn.InstanceNorm2d(1),
494
+ nn.ReLU(),
495
+ nn.Flatten(),
496
+ nn.Linear(24*24, 16, bias=True),
497
+ nn.Dropout(p=0.2, inplace=True),
498
+ nn.Linear(16, n_classes, bias=True),
499
+ nn.Softmax(1))
500
+ def forward(self, x):
501
+ x1 = self.inc(x)
502
+ x2 = self.down1(x1)
503
+ x3 = self.down2(x2)
504
+ x4 = self.down3(x3)
505
+ x5 = self.down4(x4)
506
+ x5 = self.pab(x5)
507
+
508
+ x = self.mfab1(x5,x4)
509
+ x = self.mfab2(x,x3)
510
+ x = self.mfab3(x,x2)
511
+ x = self.mfab4(x,x1)
512
+
513
+ # x = self.up1(x5, x4)
514
+ # x = self.up2(x, x3)
515
+ # x = self.up3(x, x2)
516
+ # x = self.up4(x, x1)
517
+ logits = self.outc(x)
518
+ logits = F.softmax(logits,1)
519
+
520
+ if self.aux is None:
521
+ return logits
522
+ else:
523
+ aux = self.aux(x5)
524
+ return logits, aux
525
+
526
+
527
+
528
+
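A quick shape check for the 2D variant above, echoing the usage comment at the top of the file. The 256x256 input size is an assumption; note that the auxiliary head, when enabled, hard-codes Linear(24*24, 16) and therefore expects a 24x24 bottleneck, i.e. 384x384 inputs.

    import torch

    net = UNet2D(n_channels=1, n_classes=2, pab_channels=64, use_batchnorm=True)
    x = torch.randn(2, 1, 256, 256)          # (B, C, H, W)
    with torch.no_grad():
        y = net(x)
    print(y.shape)                           # torch.Size([2, 2, 256, 256]), softmax probabilities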
529
+ class DoubleConv3D(nn.Module):
530
+ """(convolution => [BN] => ReLU) * 2"""
531
+
532
+ def __init__(self, in_channels, out_channels, mid_channels=None):
533
+ super().__init__()
534
+ if not mid_channels:
535
+ mid_channels = out_channels
536
+ self.double_conv = nn.Sequential(
537
+ nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
538
+ nn.BatchNorm3d(mid_channels),
539
+ nn.ReLU(inplace=True),
540
+ nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
541
+ nn.BatchNorm3d(out_channels),
542
+ nn.ReLU(inplace=True)
543
+ )
544
+
545
+ def forward(self, x):
546
+ return self.double_conv(x)
547
+
548
+ class Down3D(nn.Module):
549
+ """Downscaling with maxpool then double conv"""
550
+
551
+ def __init__(self, in_channels, out_channels):
552
+ super().__init__()
553
+ self.maxpool_conv = nn.Sequential(
554
+ nn.MaxPool3d(2),
555
+ # NONLocalBlock3D(in_channels),
556
+ DoubleConv3D(in_channels, out_channels)
557
+ )
558
+
559
+ def forward(self, x):
560
+ return self.maxpool_conv(x)
561
+
562
+ class Up3D(nn.Module):
563
+ """Upscaling then double conv"""
564
+
565
+ def __init__(self, in_channels, out_channels, bilinear=True):
566
+ super().__init__()
567
+
568
+ # if bilinear, use the normal convolutions to reduce the number of channels
569
+ if bilinear:
570
+ self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
571
+ self.conv = DoubleConv3D(in_channels, out_channels, in_channels // 2)
572
+ else:
573
+ self.up = nn.ConvTranspose3d(in_channels , in_channels // 2, kernel_size=2, stride=2)
574
+ self.conv = DoubleConv3D(in_channels, out_channels)
575
+
576
+ def forward(self, x1, x2):
577
+ x1 = self.up(x1)
578
+ # input is NCDHW: pad depth, height and width so x1 matches the skip connection x2
+ diffD = x2.size()[2] - x1.size()[2]
+ diffH = x2.size()[3] - x1.size()[3]
+ diffW = x2.size()[4] - x1.size()[4]
+
+ x1 = F.pad(x1, [diffW // 2, diffW - diffW // 2,
+ diffH // 2, diffH - diffH // 2,
+ diffD // 2, diffD - diffD // 2])
584
+ # if you have padding issues, see
585
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
586
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
587
+ x = torch.cat([x2, x1], dim=1)
588
+ return self.conv(x)
589
+
590
+ class OutConv3D(nn.Module):
591
+ def __init__(self, in_channels, out_channels):
592
+ super(OutConv3D, self).__init__()
593
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)
594
+
595
+ def forward(self, x):
596
+ return self.conv(x)
597
+
598
+ class UNet3D(nn.Module):
599
+ def __init__(self, n_channels, n_classes, bilinear=True, pab_channels=64, use_batchnorm=True, aux_classifier = False):
600
+ super(UNet3D, self).__init__()
601
+ self.n_channels = n_channels
602
+ self.n_classes = n_classes
603
+ self.bilinear = bilinear
604
+
605
+ self.inc = DoubleConv3D(n_channels, pab_channels)
606
+ self.down1 = Down3D(pab_channels, 2*pab_channels)
607
+ self.nnblock2 = NONLocalBlock3D(2*pab_channels)
608
+ self.down2 = Down3D(2*pab_channels, 4*pab_channels)
609
+ self.down3 = Down3D(4*pab_channels, 8*pab_channels)
610
+ factor = 2 if bilinear else 1
611
+ self.down4 = Down3D(8*pab_channels, 16*pab_channels // factor)
612
+ self.pab = PAB3D(8*pab_channels,8*pab_channels)
613
+ self.up1 = Up3D(16*pab_channels, 8*pab_channels // factor, bilinear)
614
+ self.up2 = Up3D(8*pab_channels, 4*pab_channels // factor, bilinear)
615
+ self.up3 = Up3D(4*pab_channels, 2*pab_channels // factor, bilinear)
616
+ self.up4 = Up3D(2*pab_channels, pab_channels, bilinear)
617
+
618
+ self.mfab1 = MFAB3D(8*pab_channels,8*pab_channels,4*pab_channels,use_batchnorm)
619
+ self.mfab2 = MFAB3D(4*pab_channels,4*pab_channels,2*pab_channels,use_batchnorm)
620
+ self.mfab3 = MFAB3D(2*pab_channels,2*pab_channels,pab_channels,use_batchnorm)
621
+ self.mfab4 = MFAB3D(pab_channels,pab_channels,pab_channels,use_batchnorm)
622
+ self.outc = OutConv3D(pab_channels, n_classes)
623
+
624
+ if not aux_classifier:
625
+ self.aux = None
626
+ else:
627
+ # customize the auxiliary classification loss
628
+ # self.aux = nn.Sequential(nn.AdaptiveMaxPool3d(1),
629
+ # nn.Flatten(),
630
+ # nn.Dropout(p=0.1, inplace=True),
631
+ # nn.Linear(8*pab_channels, 16, bias=True),
632
+ # nn.Dropout(p=0.1, inplace=True),
633
+ # nn.Linear(16, n_classes, bias=True),
634
+ # nn.Softmax(1))
635
+
636
+ self.aux = nn.Sequential(nn.Conv3d(8*pab_channels,1,1),
637
+ nn.InstanceNorm3d(1),
638
+ nn.ReLU(),
639
+ nn.Flatten(),
640
+ nn.Linear(16*16*2, 16, bias=True),
641
+ nn.Dropout(p=0.2, inplace=True),
642
+ nn.Linear(16, n_classes, bias=True),
643
+ nn.Softmax(1))
644
+
645
+ def forward(self, x):
646
+ x1 = self.inc(x)
647
+ x2 = self.down1(x1)
648
+ # x2 = self.nnblock2(x2)
649
+ x3 = self.down2(x2)
650
+ x4 = self.down3(x3)
651
+ x5 = self.down4(x4)
652
+ x5 = self.pab(x5)
653
+
654
+ x = self.mfab1(x5,x4)
655
+ x = self.mfab2(x,x3)
656
+ x = self.mfab3(x,x2)
657
+ x = self.mfab4(x,x1)
658
+
659
+ # x = self.up1(x5, x4)
660
+ # x = self.up2(x, x3)
661
+ # x = self.up3(x, x2)
662
+ # x = self.up4(x, x1)
663
+ logits = self.outc(x)
664
+ logits = F.softmax(logits,1)
665
+
666
+ if self.aux is None:
667
+ return logits
668
+ else:
669
+ aux = self.aux(x5)
670
+ return logits, aux
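+
+
+ # Minimal smoke-test sketch (illustrative only; the input sizes below are assumptions, not the ROI
+ # used in the pipeline). It shows how the two networks are instantiated and what a forward pass returns.
+ if __name__ == "__main__":
+     net2d = UNet2D(n_channels=1, n_classes=3)                    # returns per-pixel softmax maps
+     probs2d = net2d(torch.randn(1, 1, 64, 64))                   # (1, 3, 64, 64)
+     net3d = UNet3D(n_channels=1, n_classes=3, pab_channels=32)
+     probs3d = net3d(torch.randn(1, 1, 32, 32, 16))               # (1, 3, 32, 32, 16)
+     print(probs2d.shape, probs3d.shape)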
utils/pipeline.py ADDED
@@ -0,0 +1,501 @@
1
+ import logging
2
+ import sys
3
+ import tempfile
4
+ from glob import glob
5
+ from torchsummary import summary
6
+ import numpy as np
7
+ import pandas as pd
8
+ from tqdm import tqdm
9
+ import torch
10
+ from torch.utils.tensorboard import SummaryWriter
11
+ from torch.cuda.amp import autocast, GradScaler
12
+ import torch.nn as nn
13
+ import torchvision
14
+ import monai
15
+ from monai.metrics import DiceMetric, ConfusionMatrixMetric, MeanIoU
16
+ from monai.visualize import plot_2d_or_3d_image
17
+ from visualization import visualize_patient
18
+ from sliding_window import sw_inference
19
+ from data_preparation import build_dataset
20
+ from models import UNet2D, UNet3D
21
+ from loss import WeaklyDiceFocalLoss
22
+ from sklearn.linear_model import LinearRegression
23
+ from nrrd import write, read
24
+ import morphsnakes as ms
25
+ from monai.data import decollate_batch
26
+
27
+
28
+ def build_optimizer(model, config):
29
+
30
+ if config['LOSS'] == "gdice":
31
+ loss_function = monai.losses.GeneralizedDiceLoss(
32
+ include_background=config['EVAL_INCLUDE_BACKGROUND'],
33
+ reduction="mean", to_onehot_y=True, sigmoid=True) if len(config['KEEP_CLASSES'])<=2 else monai.losses.GeneralizedDiceLoss(
34
+ include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean", to_onehot_y=False, softmax=True)
35
+ elif config['LOSS'] == 'cdice':
36
+ loss_function = monai.losses.DiceCELoss(
37
+ include_background=config['EVAL_INCLUDE_BACKGROUND'],
38
+ reduction="mean", to_onehot_y=True, sigmoid=True) if len(config['KEEP_CLASSES'])<=2 else monai.losses.DiceCELoss(
39
+ include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean", to_onehot_y=False, softmax=True)
40
+ elif config['LOSS'] == 'mdice':
41
+ loss_function = monai.losses.MaskedDiceLoss()
42
+ elif config['LOSS'] == 'wdice':
43
+ # Example with 3 classes (including the background: label 0).
44
+ # The distance between the background class (label 0) and the other classes is the maximum, equal to 1.
45
+ # The distance between class 1 and class 2 is 0.5.
46
+ dist_mat = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.5], [1.0, 0.5, 0.0]], dtype=np.float32)
47
+ loss_function = monai.losses.GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)
48
+ elif config['LOSS'] == "fdice":
49
+ loss_function = monai.losses.DiceFocalLoss(
50
+ include_background=config['EVAL_INCLUDE_BACKGROUND'], to_onehot_y=True, sigmoid=True) if len(config['KEEP_CLASSES'])<=2 else monai.losses.DiceFocalLoss(
51
+ include_background=config['EVAL_INCLUDE_BACKGROUND'], to_onehot_y=False, softmax=True)
52
+ elif config['LOSS'] == "wfdice":
53
+ loss_function = WeaklyDiceFocalLoss(include_background=config['EVAL_INCLUDE_BACKGROUND'], to_onehot_y=True, sigmoid=True, lambda_weak=config['LAMBDA_WEAK']) if len(config['KEEP_CLASSES'])<=2 else WeaklyDiceFocalLoss(include_background=config['EVAL_INCLUDE_BACKGROUND'], to_onehot_y=False, softmax=True, lambda_weak=config['LAMBDA_WEAK'])
54
+ else:
55
+ loss_function = monai.losses.DiceLoss(
56
+ include_background=config['EVAL_INCLUDE_BACKGROUND'],
57
+ reduction="mean", to_onehot_y=True, sigmoid=True, squared_pred=True) if len(config['KEEP_CLASSES'])<=2 else monai.losses.DiceLoss(
58
+ include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean", to_onehot_y=False, softmax=True, squared_pred=True)
59
+
60
+ eval_metrics = [
61
+ ("sensitivity", ConfusionMatrixMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'], metric_name='sensitivity', reduction="mean_batch")),
62
+ ("specificity", ConfusionMatrixMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'], metric_name='specificity', reduction="mean_batch")),
63
+ ("accuracy", ConfusionMatrixMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'], metric_name='accuracy', reduction="mean_batch")),
64
+ ("dice", DiceMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean_batch")),
65
+ ("IoU", MeanIoU(include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean_batch"))
66
+ ]
67
+
68
+ optimizer = torch.optim.Adam(model.parameters(), config['LEARNING_RATE'], weight_decay=1e-5, amsgrad=True)
69
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['MAX_EPOCHS'])
70
+ return loss_function, optimizer, lr_scheduler, eval_metrics
71
+
72
+
73
+
74
+ def load_weights(model, config):
75
+ try:
76
+ model.load_state_dict(torch.load("checkpoints/" + config['PRETRAINED_WEIGHTS'] + ".pth", map_location=torch.device(config['DEVICE'])))
77
+ print("Model weights from", config['PRETRAINED_WEIGHTS'], "have been loaded")
78
+ except Exception as e:
79
+ try:
80
+ model.load_state_dict(torch.load(config['PRETRAINED_WEIGHTS'], map_location=torch.device(config['DEVICE'])))
81
+ print("Model weights from", config['PRETRAINED_WEIGHTS'], "have been loaded")
82
+ except Exception as e: # load
83
+ print("WARNING: weights were not loaded. ", e)
84
+ pass
85
+
86
+ return model
87
+
88
+
89
+ def build_model(config):
90
+
91
+ config = get_defaults(config)
92
+
93
+ dropout_prob = config['DROPOUT']
94
+
95
+ if "SegResNetVAE" in config["MODEL_NAME"]:
96
+ model = monai.networks.nets.SegResNetVAE(
97
+ input_image_size=config['ROI_SIZE'] if "3D" in config['MODEL_NAME'] else (config['ROI_SIZE'][0], config['ROI_SIZE'][1]),
98
+ vae_estimate_std=False,
99
+ vae_default_std=0.3,
100
+ vae_nz=256,
101
+ spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
102
+ blocks_down=[1, 2, 2, 4],
103
+ blocks_up=[1, 1, 1],
104
+ init_filters=16,
105
+ in_channels=1,
106
+ norm='instance',
107
+ out_channels=len(config['KEEP_CLASSES']),
108
+ dropout_prob=dropout_prob,
109
+ ).to(config['DEVICE'])
110
+
111
+ elif "SegResNet" in config["MODEL_NAME"]:
112
+ model = monai.networks.nets.SegResNet(
113
+ spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
114
+ blocks_down=[1, 2, 2, 4],
115
+ blocks_up=[1, 1, 1],
116
+ init_filters=16,
117
+ in_channels=1,
118
+ out_channels=len(config['KEEP_CLASSES']),
119
+ dropout_prob=dropout_prob,
120
+ norm="instance"
121
+ ).to(config['DEVICE'])
122
+
123
+ elif "SwinUNETR" in config["MODEL_NAME"]:
124
+ model = monai.networks.nets.SwinUNETR(
125
+ img_size=config['ROI_SIZE'],
126
+ in_channels=1,
127
+ out_channels=len(config['KEEP_CLASSES']),
128
+ feature_size=48,
129
+ drop_rate=dropout_prob,
130
+ attn_drop_rate=0.0,
131
+ dropout_path_rate=0.0,
132
+ use_checkpoint=True
133
+ ).to(config['DEVICE'])
134
+
135
+ elif "UNETR" in config["MODEL_NAME"]:
136
+ model = monai.networks.nets.UNETR(
137
+ img_size=config['ROI_SIZE'] if "3D" in config['MODEL_NAME'] else (config['ROI_SIZE'][0], config['ROI_SIZE'][1]),
138
+ in_channels=1,
139
+ out_channels=len(config['KEEP_CLASSES']),
140
+ feature_size=16,
141
+ hidden_size=256,
142
+ mlp_dim=3072,
143
+ num_heads=8,
144
+ pos_embed="perceptron",
145
+ norm_name="instance",
146
+ res_block=True,
147
+ dropout_rate=dropout_prob,
148
+ ).to(config['DEVICE'])
149
+
150
+ elif "MANet" in config["MODEL_NAME"]:
151
+ if "2D" in config["MODEL_NAME"]:
152
+ model = UNet2D(
153
+ 1,
154
+ len(config['KEEP_CLASSES']),
155
+ pab_channels=64,
156
+ use_batchnorm=True
157
+ ).to(config['DEVICE'])
158
+ else:
159
+ model = UNet3D(
160
+ 1,
161
+ len(config['KEEP_CLASSES']),
162
+ pab_channels=32,
163
+ use_batchnorm=True
164
+ ).to(config['DEVICE'])
165
+
166
+ elif "UNetPlusPlus" in config["MODEL_NAME"]:
167
+ model = monai.networks.nets.BasicUNetPlusPlus(
168
+ spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
169
+ in_channels=1,
170
+ out_channels=len(config['KEEP_CLASSES']),
171
+ features=(32, 32, 64, 128, 256, 32),
172
+ norm="instance",
173
+ dropout=dropout_prob,
174
+ ).to(config['DEVICE'])
175
+
176
+ elif "UNet1" in config['MODEL_NAME']:
177
+ model = monai.networks.nets.UNet(
178
+ spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
179
+ in_channels=1,
180
+ out_channels=len(config['KEEP_CLASSES']),
181
+ channels=(16, 32, 64, 128, 256),
182
+ strides=(2, 2, 2, 2),
183
+ num_res_units=2,
184
+ norm="instance"
185
+ ).to(config['DEVICE'])
186
+
187
+ elif "UNet2" in config['MODEL_NAME']:
188
+ model = monai.networks.nets.UNet(
189
+ spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
190
+ in_channels=1,
191
+ out_channels=len(config['KEEP_CLASSES']),
192
+ channels=(32, 64, 128, 256),
193
+ strides=(2, 2, 2),
194
+ num_res_units=4,
195
+ norm="instance"
196
+ ).to(config['DEVICE'])
197
+
198
+ else:
199
+ print(config["MODEL_NAME"], "is not a valid model name")
200
+ return None
201
+
202
+ try:
203
+ if "3D" in config['MODEL_NAME']:
204
+ print(summary(model, input_size=(1, config['ROI_SIZE'][0], config['ROI_SIZE'][1], config['ROI_SIZE'][2])))
205
+ else:
206
+ print(summary(model, input_size=(1, config['ROI_SIZE'][0], config['ROI_SIZE'][1])))
207
+ except Exception as e:
208
+ print("could not load model summary:", e)
209
+
210
+ if config['PRETRAINED_WEIGHTS'] is not None and config['PRETRAINED_WEIGHTS']:
211
+ model = load_weights(model, config)
212
+ return model
213
+
214
+
215
+ def train(model, train_loader, val_loader, loss_function, eval_metrics, optimizer, config,
216
+ scheduler=None, writer=None, postprocessing_transforms = None, weak_labels = None):
217
+
218
+ if writer is None: writer = SummaryWriter(log_dir="runs/" + config['EXPORT_FILE_NAME'])
219
+ best_metric, best_metric_epoch = -1, -1
220
+ prev_metric, patience, patience_counter = 1, config['EARLY_STOPPING_PATIENCE'], 0
221
+ if config['AUTOCAST']: scaler = GradScaler() # Initialize GradScaler for mixed precision training
222
+
223
+ for epoch in range(config['MAX_EPOCHS']):
224
+ print("-" * 10)
225
+ model.train()
226
+ epoch_loss, step = 0, 0
227
+ with tqdm(train_loader) as progress_bar:
228
+ for batch_data in progress_bar:
229
+ step += 1
230
+ inputs, labels = batch_data["image"].to(config['DEVICE']), batch_data["mask"].to(config['DEVICE'])
231
+
232
+ # only train with batches that have tumor; skip those without tumor
233
+ if config['TYPE'] == "tumor":
234
+ if torch.sum(labels[:,-1]) == 0:
235
+ continue
236
+
237
+ # check input shapes
238
+ if inputs is None or labels is None:
239
+ continue
240
+ if inputs.shape[-1] != labels.shape[-1] or inputs.shape[0] != labels.shape[0]:
241
+ print("WARNING: Batch skipped. Image and mask shape does not match:", inputs.shape[0], labels.shape[0])
242
+ continue
243
+
244
+ optimizer.zero_grad()
245
+ if not config['AUTOCAST']:
246
+
247
+ # segmentation output
248
+ outputs = model(inputs)
249
+ if "SegResNetVAE" in config["MODEL_NAME"]: outputs = outputs[0]
250
+ if isinstance(outputs, list): outputs = outputs[0]
251
+
252
+ # loss
253
+ if weak_labels is not None:
254
+ weak_label = torch.tensor([weak_labels[step]]).to(config['DEVICE'])
255
+ loss = loss_function(outputs, labels, weak_label) if config['LOSS'] == 'wfdice' else loss_function(outputs, labels)
256
+ loss.backward()
257
+ optimizer.step()
258
+
259
+ else:
260
+ with autocast():
261
+ outputs = model(inputs)
262
+ if "SegResNetVAE" in config["MODEL_NAME"]: outputs = outputs[0]
263
+ if isinstance(outputs, list): outputs = outputs[0]
264
+ loss = loss_function(outputs, labels, torch.tensor([weak_labels[step]]).to(config['DEVICE'])) if config['LOSS'] == 'wfdice' else loss_function(outputs, labels)
265
+
266
+ scaler.scale(loss).backward()
267
+ scaler.unscale_(optimizer)
268
+ if torch.isinf(loss).any():
269
+ print("Detected inf in gradients.")
270
+ else:
271
+ scaler.step(optimizer)
272
+ scaler.update()
273
+
274
+ epoch_loss += loss.item()
275
+ progress_bar.set_description(f'Epoch [{epoch+1}/{config["MAX_EPOCHS"]}], Loss: {epoch_loss/step:.4f}')
276
+
277
+ epoch_loss /= step
278
+ writer.add_scalar("train_loss_epoch", epoch_loss, epoch)
279
+ progress_bar.set_description(f'Epoch [{epoch+1}/{config["MAX_EPOCHS"]}], Loss: {epoch_loss:.4f}')
280
+
281
+ # validation
282
+ if (epoch + 1) % config['VAL_INTERVAL'] == 0:
283
+
284
+ # get a list of validation measures, pick one to be the decision maker
285
+ val_metrics, (val_images, val_labels, val_outputs) = evaluate(model, val_loader, eval_metrics, config, postprocessing_transforms)
286
+ if isinstance(config['EVAL_METRIC'], list):
287
+ cur_metric = np.mean([val_metrics[m] for m in config['EVAL_METRIC']])
288
+ else:
289
+ cur_metric = val_metrics[config['EVAL_METRIC']]
290
+
291
+ # determine if better than previous best validation metric
292
+ if cur_metric > best_metric:
293
+ best_metric, best_metric_epoch = cur_metric, epoch + 1
294
+ torch.save(model.state_dict(), "checkpoints/" + config['EXPORT_FILE_NAME'] + ".pth")
295
+
296
+ # early stopping
297
+ patience_counter = patience_counter + 1 if prev_metric > cur_metric else 0
298
+ if patience_counter == patience or epoch - best_metric_epoch > patience:
299
+ print("Early stopping at epoch", epoch + 1)
300
+ break
301
+ print(f'Current epoch: {epoch + 1} current avg {config["EVAL_METRIC"]}: {cur_metric :.4f} best avg {config["EVAL_METRIC"]}: {best_metric:.4f} at epoch {best_metric_epoch}')
302
+ prev_metric = cur_metric
303
+
304
+ # writer
305
+ for key, value in val_metrics.items():
306
+ writer.add_scalar("val_" + key, value, epoch)
307
+ plot_2d_or_3d_image(val_images, epoch + 1, writer, index=len(val_outputs)//2, tag="image",frame_dim=-1)
308
+ plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=len(val_outputs)//2, tag="label",frame_dim=-1)
309
+ plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=len(val_outputs)//2, tag="output",frame_dim=-1)
310
+
311
+ # update scheduler
312
+ try:
313
+ if scheduler is not None: scheduler.step()
314
+ except:
315
+ pass
316
+
317
+ print(f"Train completed, best {config['EVAL_METRIC']}: {best_metric:.4f} at epoch: {best_metric_epoch}")
318
+ writer.close()
319
+ return model, writer
320
+
321
+
322
+
323
+ def evaluate(model, val_loader, eval_metrics, config, postprocessing_transforms=None, use_liver_seg=False, export_filenames = [], export_file_metadata = []):
324
+
325
+ val_metrics = {}
326
+ model.eval()
327
+ with torch.no_grad():
328
+
329
+ step = 0
330
+ for val_data in val_loader:
331
+ # 3D: val_images has shape (1,C,H,W,Z)
332
+ # 2D: val_images has shape (B,C,H,W)
333
+ val_images, val_labels = val_data["image"].to(config['DEVICE']), val_data["mask"].to(config['DEVICE'])
334
+ if use_liver_seg: val_liver = val_data["pred_liver"].to(config['DEVICE'])
335
+
336
+ if (val_images[0].shape[-1] != val_labels[0].shape[-1]) or (
337
+ "3D" not in config["MODEL_NAME"] and val_images.shape[0] != val_labels.shape[0]):
338
+ print("WARNING: Batch skipped. Image and mask shape does not match:", val_images.shape, val_labels.shape)
339
+ continue
340
+
341
+ # convert outputs to probability
342
+ if "3D" in config["MODEL_NAME"]:
343
+ val_outputs = sw_inference(model, val_images, config['ROI_SIZE'], config['AUTOCAST'], discard_second_output='SegResNetVAE' in config['MODEL_NAME'])
344
+ else:
345
+ if "SegResNetVAE" in config["MODEL_NAME"]: val_outputs, _ = model(val_images)
346
+ else: val_outputs = model(val_images)
347
+
348
+ # post-processing
349
+ if postprocessing_transforms is not None:
350
+ val_outputs = [postprocessing_transforms(i) for i in decollate_batch(val_outputs)]
351
+
352
+ # remove tumor predictions outside liver
353
+ for i in range(len(val_outputs)):
354
+ val_outputs[i][-1][torch.where(val_images[i][0] <= 1e-6)] = 0
355
+
356
+ # apply morphological snakes algorithm
357
+ if config['POSTPROCESSING_MORF']:
358
+ for i in range(len(val_outputs)):
359
+ val_outputs[i][-1] = torch.from_numpy(ms.morphological_chan_vese(val_images[i][0].cpu(), iterations=2, init_level_set=val_outputs[i][-1].cpu())).to(config['DEVICE'])
360
+
361
+ for i in range(len(val_outputs)):
362
+ if use_liver_seg:
363
+ # use liver model outputs for liver channel
364
+ val_outputs[i][1] = val_liver[i]
365
+ # if region is tumor, assign liver prediction to 0
366
+ val_outputs[i][1] -= val_outputs[i][2]
367
+
368
+ # compute metric for current iteration
369
+ for metric_name, metric in eval_metrics:
370
+ if isinstance(val_outputs[0], list):
371
+ val_outputs = val_outputs[0]
372
+ metric(val_outputs, val_labels)
373
+
374
+ # save prediction to local folder
375
+ if len(export_filenames) > 0:
376
+ for i in range(len(val_outputs)):
377
+ numpy_array = val_outputs[i].cpu().detach().numpy()
378
+ write(export_filenames[step], numpy_array[-1], header=export_file_metadata[step])
379
+ print(" Segmentation exported to", export_filenames[step])
380
+ step += 1
381
+
382
+ # aggregate the final mean metric
383
+ for metric_name, metric in eval_metrics:
384
+ if "dice" in metric_name or "IoU" in metric_name: metric_value = metric.aggregate().tolist()
385
+ else: metric_value = metric.aggregate()[0].tolist() # a list of accuracies, one per class
386
+ val_metrics[metric_name + "_avg"] = np.mean(metric_value)
387
+ if config['TYPE'] != "liver":
388
+ for c in range(1, len(metric_value) + 1): # class-wise accuracies
389
+ val_metrics[metric_name + "_class" + str(c)] = metric_value[c-1]
390
+ metric.reset()
391
+
392
+ return val_metrics, (val_images, val_labels, val_outputs)
393
+
394
+
395
+
396
+
397
+ def get_defaults(config):
398
+
399
+ if 'TRAIN' not in config.keys(): config['TRAIN'] = True
400
+ if 'VALID_PATIENT_RATIO' not in config.keys(): config['VALID_PATIENT_RATIO'] = 0.2
401
+ if 'VAL_INTERVAL' not in config.keys(): config['VAL_INTERVAL'] = 1
403
+ if 'EARLY_STOPPING_PATIENCE' not in config.keys(): config['EARLY_STOPPING_PATIENCE'] = 20
404
+ if 'AUTOCAST' not in config.keys(): config['AUTOCAST'] = False
405
+ if 'NUM_WORKERS' not in config.keys(): config['NUM_WORKERS'] = 0
406
+ if 'DROPOUT' not in config.keys(): config['DROPOUT'] = 0.1
407
+ if 'ONESAMPLETESTRUN' not in config.keys(): config['ONESAMPLETESTRUN'] = False
409
+ if 'DATA_AUGMENTATION' not in config.keys(): config['DATA_AUGMENTATION'] = False
410
+ if 'POSTPROCESSING_MORF' not in config.keys(): config['POSTPROCESSING_MORF'] = False
411
+ if 'PREPROCESSING' not in config.keys(): config['PREPROCESSING'] = ""
412
+ if 'PRETRAINED_WEIGHTS' not in config.keys(): config['PRETRAINED_WEIGHTS'] = ""
413
+
414
+ if 'EVAL_INCLUDE_BACKGROUND' not in config.keys():
415
+ if config['TYPE'] == "liver": config['EVAL_INCLUDE_BACKGROUND'] = True
416
+ else: config['EVAL_INCLUDE_BACKGROUND'] = False
417
+ if 'EVAL_METRIC' not in config.keys():
418
+ if config['TYPE'] == "liver": config['EVAL_METRIC'] = ["dice_avg"]
419
+ else: config['EVAL_METRIC'] = ["dice_class2"]
420
+
421
+ if 'CLINICAL_DATA_FILE' not in config.keys(): config['CLINICAL_DATA_FILE'] = "Dataset/HCC-TACE-Seg_clinical_data-V2.xlsx"
422
+ if 'CLINICAL_PREDICTORS' not in config.keys(): config['CLINICAL_PREDICTORS'] = ['T_involvment', 'CLIP_Score','Personal history of cancer', 'TNM', 'Metastasis','fhx_can', 'Alcohol', 'Smoking', 'Evidence_of_cirh', 'AFP', 'age', 'Diabetes', 'Lymphnodes', 'Interval_BL', 'TTP']
423
+ if 'LAMBDA_WEAK' not in config.keys(): config['LAMBDA_WEAK'] = 0.5
424
+ if 'MASKNONLIVER' not in config.keys(): config['MASKNONLIVER'] = False
425
+
426
+ if config['TYPE'] == "liver": config['KEEP_CLASSES']=["normal", "liver"]
427
+ elif config['TYPE'] == "tumor": config['KEEP_CLASSES']=["normal", "liver", "tumor"]
428
+ else: config['KEEP_CLASSES'] = ["normal", "liver", "tumor", "portal vein", "abdominal aorta"]
429
+
430
+ config['DEVICE'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
431
+ config['EXPORT_FILE_NAME'] = config['TYPE']+ "_" + config['MODEL_NAME'] + "_" + config['LOSS'] + "_batchsize" + str(config['BATCH_SIZE']) + "_DA" + str(config['DATA_AUGMENTATION']) + "_HU" + str(config['HU_RANGE'][0]) + "-" + str(config['HU_RANGE'][1]) + "_" + config['PREPROCESSING'] + "_" + str(config['ROI_SIZE'][0]) + "_" + str(config['ROI_SIZE'][1]) + "_" + str(config['ROI_SIZE'][2]) + "_dropout" + str(config['DROPOUT'])
432
+ if config['MASKNONLIVER']: config['EXPORT_FILE_NAME'] += "_wobackground"
433
+ if config['LOSS'] == "wfdice": config['EXPORT_FILE_NAME'] += "_weaklambda" + str(config['LAMBDA_WEAK'])
434
+ if config['PRETRAINED_WEIGHTS'] != "" and config['PRETRAINED_WEIGHTS'] != config['EXPORT_FILE_NAME']: config['EXPORT_FILE_NAME'] += "_pretraining"
435
+ if config['POSTPROCESSING_MORF']: config['EXPORT_FILE_NAME'] += "_wpostmorf"
436
+ if not config['EVAL_INCLUDE_BACKGROUND']: config['EXPORT_FILE_NAME'] += "_evalnobackground"
437
+
438
+ return config
439
+
440
+
441
+ def train_clinical(df_clinical):
442
+
443
+ clinical_model = LinearRegression()
444
+
445
+ # train model
446
+ print("Training model using", df_clinical.loc[:, df_clinical.columns != 'tumor_ratio'].shape[1], "features")
447
+ print(df_clinical.head())
448
+ clinical_model.fit(df_clinical.loc[:, df_clinical.columns != 'tumor_ratio'], df_clinical['tumor_ratio'])
449
+
450
+ # obtain predicted ratios
451
+ pred = clinical_model.predict(df_clinical.loc[:, df_clinical.columns != 'tumor_ratio'])
452
+
453
+ # evaluate
454
+ corr = np.corrcoef(pred, df_clinical['tumor_ratio'])[0][1]
455
+ mae = np.mean(np.abs(pred - df_clinical['tumor_ratio']))
456
+ print(f"The clinical model was fitted. Corr = {corr: .6f} MAE = {mae: .6f}")
457
+
458
+ return pred
459
+
460
+
461
+ def model_pipeline(config=None, plot=True):
462
+
463
+ torch.cuda.empty_cache()
464
+ config = get_defaults(config)
465
+ print(f"You Are Running on a: {config['DEVICE']}")
466
+ print("file name:", config['EXPORT_FILE_NAME'])
467
+
468
+ writer = SummaryWriter(log_dir="runs/" + config['EXPORT_FILE_NAME'])
469
+
470
+ # prepare data
471
+ train_loader, valid_loader, test_loader, postprocessing_transforms, df_clinical_train = build_dataset(config, get_clinical=config['LOSS']=="wfdice")
472
+
473
+ # train clinical model
474
+ if config['LOSS'] == "wfdice": weak_labels = train_clinical(df_clinical_train)
475
+ else: weak_labels = None
476
+
477
+ # train segmentation model
478
+ model = build_model(config)
479
+ loss_function, optimizer, lr_scheduler, eval_metrics = build_optimizer(model, config)
480
+ if config['TRAIN']:
481
+ train(model, train_loader, valid_loader, loss_function, eval_metrics, optimizer, config, lr_scheduler, writer, postprocessing_transforms, weak_labels)
482
+ model.load_state_dict(torch.load("checkpoints/" + config['EXPORT_FILE_NAME'] + ".pth", map_location=torch.device(config['DEVICE'])))
483
+ if config['ONESAMPLETESTRUN']:
484
+ return None, None, None
485
+
486
+ # test segmentation model
487
+ test_metrics, (test_images, test_labels, test_outputs) = evaluate(model, test_loader, eval_metrics, config, postprocessing_transforms)
488
+ print("Test metrics")
489
+ for key, value in test_metrics.items():
490
+ print(f" {key}: {value:.4f}")
491
+
492
+ # visualize
493
+ if plot:
494
+ if "3D" in config['MODEL_NAME']:
495
+ visualize_patient(test_images[0].cpu(), mask=test_labels[0].cpu(), n_slices=9, title="ground truth", z_dim_last="3D" in config['MODEL_NAME'], mask_channel=-1)
496
+ visualize_patient(test_images[0].cpu(), mask=test_outputs[0].cpu(), n_slices=9, title="predicted", z_dim_last="3D" in config['MODEL_NAME'], mask_channel=-1)
497
+ else:
498
+ visualize_patient(test_images.cpu(), mask=test_labels.cpu(), n_slices=9, title="ground truth", z_dim_last="3D" in config['MODEL_NAME'], mask_channel=-1)
499
+ visualize_patient(test_images.cpu(), mask=torch.stack(test_outputs).cpu(), n_slices=9, title="predicted", z_dim_last="3D" in config['MODEL_NAME'], mask_channel=-1)
500
+
501
+ return (test_images, test_labels, test_outputs)
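+
+
+ # Minimal usage sketch (illustrative only): the keys below are ones get_defaults() and build_model()
+ # expect but do not fill in themselves; every value here is an assumed placeholder, not the released configuration.
+ if __name__ == "__main__":
+     example_config = {
+         "TYPE": "tumor",                # -> KEEP_CLASSES = ["normal", "liver", "tumor"]
+         "MODEL_NAME": "MANet3D",        # any of the names handled in build_model()
+         "LOSS": "fdice",
+         "BATCH_SIZE": 2,
+         "ROI_SIZE": (192, 192, 48),
+         "HU_RANGE": (-100, 200),
+         "LEARNING_RATE": 1e-4,
+         "MAX_EPOCHS": 100,
+     }
+     example_config = get_defaults(example_config)
+     print(example_config["EXPORT_FILE_NAME"])   # derived name used for checkpoints, logs and runs/
+     # model_pipeline(example_config) would then run the full train/evaluate loop (requires the dataset).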
utils/sliding_window.py ADDED
@@ -0,0 +1,328 @@
1
+ from collections.abc import Callable, Sequence
2
+ from typing import Any, Iterable
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from monai.data.meta_tensor import MetaTensor
7
+ from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
8
+ from monai.inferers.utils import _create_buffered_slices, _compute_coords, _get_scan_interval, _flatten_struct, _pack_struct
9
+ from monai.utils import (
10
+ BlendMode,
11
+ PytorchPadMode,
12
+ convert_data_type,
13
+ convert_to_dst_type,
14
+ ensure_tuple,
15
+ ensure_tuple_rep,
16
+ fall_back_tuple,
17
+ look_up_option,
18
+ optional_import,
19
+ pytorch_after,
20
+ )
21
+ from tqdm import tqdm
22
+
23
+ # Adapted from monai
24
+ def sliding_window_inference(
25
+ inputs: torch.Tensor | MetaTensor,
26
+ roi_size: Sequence[int] | int,
27
+ sw_batch_size: int,
28
+ predictor: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],
29
+ overlap: Sequence[float] | float = 0.25,
30
+ mode: BlendMode | str = BlendMode.CONSTANT,
31
+ sigma_scale: Sequence[float] | float = 0.125,
32
+ padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT,
33
+ cval: float = 0.0,
34
+ sw_device: torch.device | str | None = None,
35
+ device: torch.device | str | None = None,
36
+ progress: bool = False,
37
+ roi_weight_map: torch.Tensor | None = None,
38
+ process_fn: Callable | None = None,
39
+ buffer_steps: int | None = None,
40
+ buffer_dim: int = -1,
41
+ with_coord: bool = False,
42
+ discard_second_output: bool = False,
43
+ *args: Any,
44
+ **kwargs: Any,
45
+ ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
46
+ """
47
+ Sliding window inference on `inputs` with `predictor`.
48
+
49
+ The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
50
+ Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
51
+ e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
52
+ could be ([128,64,256], [64,32,128]).
53
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
54
+ an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
55
+ so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).
56
+
57
+ When roi_size is larger than the inputs' spatial size, the input image are padded during inference.
58
+ To maintain the same spatial sizes, the output image will be cropped to the original input size.
59
+
60
+ Args:
61
+ inputs: input image to be processed (assuming NCHW[D])
62
+ roi_size: the spatial window size for inferences.
63
+ When its components have None or non-positives, the corresponding inputs dimension will be used.
64
+ if the components of the `roi_size` are non-positive values, the transform will use the
65
+ corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
66
+ to `(32, 64)` if the second spatial dimension size of img is `64`.
67
+ sw_batch_size: the batch size to run window slices.
68
+ predictor: given input tensor ``patch_data`` in shape NCHW[D],
69
+ The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
70
+ with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
71
+ where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
72
+ N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
73
+ the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
74
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
75
+ to ensure the scaled output ROI sizes are still integers.
76
+ If the `predictor`'s input and output spatial sizes are different,
77
+ we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
78
+ overlap: Amount of overlap between scans along each spatial dimension, defaults to ``0.25``.
79
+ mode: {``"constant"``, ``"gaussian"``}
80
+ How to blend output of overlapping windows. Defaults to ``"constant"``.
81
+
82
+ - ``"constant``": gives equal weight to all predictions.
83
+ - ``"gaussian``": gives less weight to predictions on edges of windows.
84
+
85
+ sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
86
+ Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
87
+ When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
88
+ spatial dimensions.
89
+ padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
90
+ Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
91
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
92
+ cval: fill value for 'constant' padding mode. Default: 0
93
+ sw_device: device for the window data.
94
+ By default the device (and accordingly the memory) of the `inputs` is used.
95
+ Normally `sw_device` should be consistent with the device where `predictor` is defined.
96
+ device: device for the stitched output prediction.
97
+ By default the device (and accordingly the memory) of the `inputs` is used. If for example
98
+ set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
99
+ `inputs` and `roi_size`. Output is on the `device`.
100
+ progress: whether to print a `tqdm` progress bar.
101
+ roi_weight_map: pre-computed (non-negative) weight map for each ROI.
102
+ If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
103
+ process_fn: process inference output and adjust the importance map per window
104
+ buffer_steps: the number of sliding window iterations along the ``buffer_dim``
105
+ to be buffered on ``sw_device`` before writing to ``device``.
106
+ (Typically, ``sw_device`` is ``cuda`` and ``device`` is ``cpu``.)
107
+ default is None, no buffering. For the buffer dim, when spatial size is divisible by buffer_steps*roi_size,
108
+ (i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency.
109
+ buffer_dim: the spatial dimension along which the buffers are created.
110
+ 0 indicates the first spatial dimension. Default is -1, the last spatial dimension.
111
+ with_coord: whether to pass the window coordinates to ``predictor``. Default is False.
112
+ If True, the signature of ``predictor`` should be ``predictor(patch_data, patch_coord, ...)``.
113
+ args: optional args to be passed to ``predictor``.
114
+ kwargs: optional keyword args to be passed to ``predictor``.
115
+
116
+ Note:
117
+ - input must be channel-first and have a batch dim, supports N-D sliding window.
118
+
119
+ """
120
+ buffered = buffer_steps is not None and buffer_steps > 0
121
+ num_spatial_dims = len(inputs.shape) - 2
122
+ if buffered:
123
+ if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims:
124
+ raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.")
125
+ if buffer_dim < 0:
126
+ buffer_dim += num_spatial_dims
127
+ overlap = ensure_tuple_rep(overlap, num_spatial_dims)
128
+ for o in overlap:
129
+ if o < 0 or o >= 1:
130
+ raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.")
131
+ compute_dtype = inputs.dtype
132
+
133
+ # determine image spatial size and batch size
134
+ # Note: all input images must have the same image size and batch size
135
+ batch_size, _, *image_size_ = inputs.shape
136
+ device = device or inputs.device
137
+ sw_device = sw_device or inputs.device
138
+
139
+ temp_meta = None
140
+ if isinstance(inputs, MetaTensor):
141
+ temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False)
142
+ inputs = convert_data_type(inputs, torch.Tensor, wrap_sequence=True)[0]
143
+ roi_size = fall_back_tuple(roi_size, image_size_)
144
+
145
+ # in case that image size is smaller than roi size
146
+ image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
147
+ pad_size = []
148
+ for k in range(len(inputs.shape) - 1, 1, -1):
149
+ diff = max(roi_size[k - 2] - inputs.shape[k], 0)
150
+ half = diff // 2
151
+ pad_size.extend([half, diff - half])
152
+ if any(pad_size):
153
+ inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
154
+
155
+ # Store all slices
156
+ scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
157
+ slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=not buffered)
158
+
159
+ num_win = len(slices) # number of windows per image
160
+ total_slices = num_win * batch_size # total number of windows
161
+ windows_range: Iterable
162
+ if not buffered:
163
+ non_blocking = False
164
+ windows_range = range(0, total_slices, sw_batch_size)
165
+ else:
166
+ slices, n_per_batch, b_slices, windows_range = _create_buffered_slices(
167
+ slices, batch_size, sw_batch_size, buffer_dim, buffer_steps
168
+ )
169
+ non_blocking, _ss = torch.cuda.is_available(), -1
170
+ for x in b_slices[:n_per_batch]:
171
+ if x[1] < _ss: # detect overlapping slices
172
+ non_blocking = False
173
+ break
174
+ _ss = x[2]
175
+
176
+ # Create window-level importance map
177
+ valid_patch_size = get_valid_patch_size(image_size, roi_size)
178
+ if valid_patch_size == roi_size and (roi_weight_map is not None):
179
+ importance_map_ = roi_weight_map
180
+ else:
181
+ try:
182
+ valid_p_size = ensure_tuple(valid_patch_size)
183
+ importance_map_ = compute_importance_map(
184
+ valid_p_size, mode=mode, sigma_scale=sigma_scale, device=sw_device, dtype=compute_dtype
185
+ )
186
+ if len(importance_map_.shape) == num_spatial_dims and not process_fn:
187
+ importance_map_ = importance_map_[None, None] # adds batch, channel dimensions
188
+ except Exception as e:
189
+ raise RuntimeError(
190
+ f"patch size {valid_p_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n"
191
+ "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
192
+ ) from e
193
+ importance_map_ = convert_data_type(importance_map_, torch.Tensor, device=sw_device, dtype=compute_dtype)[0]
194
+
195
+ # stores output and count map
196
+ output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0 # type: ignore
197
+ # for each patch
198
+ for slice_g in tqdm(windows_range) if progress else windows_range:
199
+ slice_range = range(slice_g, min(slice_g + sw_batch_size, b_slices[b_s][0] if buffered else total_slices))
200
+ unravel_slice = [
201
+ [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win])
202
+ for idx in slice_range
203
+ ]
204
+ if sw_batch_size > 1:
205
+ win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
206
+ else:
207
+ win_data = inputs[unravel_slice[0]].to(sw_device)
208
+ if with_coord:
209
+ seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)
210
+ if discard_second_output and seg_prob_out is not None: seg_prob_out = seg_prob_out[0]
211
+ else:
212
+ seg_prob_out = predictor(win_data, *args, **kwargs)
213
+ if discard_second_output and seg_prob_out is not None: seg_prob_out = seg_prob_out[0]
214
+
215
+ # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.
216
+ dict_keys, seg_tuple = _flatten_struct(seg_prob_out)
217
+ if process_fn:
218
+ seg_tuple, w_t = process_fn(seg_tuple, win_data, importance_map_)
219
+ else:
220
+ w_t = importance_map_
221
+ if len(w_t.shape) == num_spatial_dims:
222
+ w_t = w_t[None, None]
223
+ w_t = w_t.to(dtype=compute_dtype, device=sw_device)
224
+ if buffered:
225
+ c_start, c_end = b_slices[b_s][1:]
226
+ if not sw_device_buffer:
227
+ k = seg_tuple[0].shape[1] # len(seg_tuple) > 1 is currently ignored
228
+ sp_size = list(image_size)
229
+ sp_size[buffer_dim] = c_end - c_start
230
+ sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)]
231
+ for p, s in zip(seg_tuple[0], unravel_slice):
232
+ offset = s[buffer_dim + 2].start - c_start
233
+ s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim])
234
+ s[0] = slice(0, 1)
235
+ sw_device_buffer[0][s] += p * w_t
236
+ b_i += len(unravel_slice)
237
+ if b_i < b_slices[b_s][0]:
238
+ continue
239
+ else:
240
+ sw_device_buffer = list(seg_tuple)
241
+
242
+ for ss in range(len(sw_device_buffer)):
243
+ b_shape = sw_device_buffer[ss].shape
244
+ seg_chns, seg_shape = b_shape[1], b_shape[2:]
245
+ z_scale = None
246
+ if not buffered and seg_shape != roi_size:
247
+ z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)]
248
+ w_t = F.interpolate(w_t, seg_shape, mode="nearest-exact" if pytorch_after(1, 11) else "nearest")
249
+ if len(output_image_list) <= ss:
250
+ output_shape = [batch_size, seg_chns]
251
+ output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size)
252
+ # allocate memory to store the full output and the count for overlapping parts
253
+ new_tensor: Callable = torch.empty if non_blocking else torch.zeros # type: ignore
254
+ output_image_list.append(new_tensor(output_shape, dtype=compute_dtype, device=device))
255
+ count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
256
+ w_t_ = w_t.to(device)
257
+ for __s in slices:
258
+ if z_scale is not None:
259
+ __s = tuple(slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale))
260
+ count_map_list[-1][(slice(None), slice(None), *__s)] += w_t_
261
+ if buffered:
262
+ o_slice = [slice(None)] * len(inputs.shape)
263
+ o_slice[buffer_dim + 2] = slice(c_start, c_end)
264
+ img_b = b_s // n_per_batch # image batch index
265
+ o_slice[0] = slice(img_b, img_b + 1)
266
+ if non_blocking:
267
+ output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking)
268
+ else:
269
+ output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device)
270
+ else:
271
+ sw_device_buffer[ss] *= w_t
272
+ sw_device_buffer[ss] = sw_device_buffer[ss].to(device)
273
+ _compute_coords(unravel_slice, z_scale, output_image_list[ss], sw_device_buffer[ss])
274
+ sw_device_buffer = []
275
+ if buffered:
276
+ b_s += 1
277
+
278
+ if non_blocking:
279
+ torch.cuda.current_stream().synchronize()
280
+
281
+ # account for any overlapping sections
282
+ for ss in range(len(output_image_list)):
283
+ output_image_list[ss] /= count_map_list.pop(0)
284
+
285
+ # remove padding if image_size smaller than roi_size
286
+ if any(pad_size):
287
+ for ss, output_i in enumerate(output_image_list):
288
+ zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)]
289
+ final_slicing: list[slice] = []
290
+ for sp in range(num_spatial_dims):
291
+ si = num_spatial_dims - sp - 1
292
+ slice_dim = slice(
293
+ int(round(pad_size[sp * 2] * zoom_scale[si])),
294
+ int(round((pad_size[sp * 2] + image_size_[si]) * zoom_scale[si])),
295
+ )
296
+ final_slicing.insert(0, slice_dim)
297
+ output_image_list[ss] = output_i[(slice(None), slice(None), *final_slicing)]
298
+
299
+ final_output = _pack_struct(output_image_list, dict_keys)
300
+ if temp_meta is not None:
301
+ final_output = convert_to_dst_type(final_output, temp_meta, device=device)[0]
302
+ else:
303
+ final_output = convert_to_dst_type(final_output, inputs, device=device)[0]
304
+
305
+ return final_output # type: ignore
306
+
307
+
308
+ def sw_inference(model, input, roi_size, autocast_on, discard_second_output, overlap=0.8):
309
+ def _compute(input):
310
+ return sliding_window_inference(
311
+ inputs=input,
312
+ roi_size=roi_size,
313
+ sw_batch_size=1,
314
+ predictor=model,
315
+ overlap=overlap,
316
+ progress=False,
317
+ mode="constant",
318
+ discard_second_output=discard_second_output
319
+ )
320
+
321
+ if autocast_on:
322
+ with torch.cuda.amp.autocast():
323
+ return _compute(input)
324
+ else:
325
+ return _compute(input)
326
+
327
+
328
+
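+ # Minimal usage sketch (illustrative only; the volume size and dummy predictor are assumptions).
+ # sw_inference tiles the full volume into roi_size patches, runs the predictor on each patch and
+ # blends the overlapping patch outputs back into a full-size prediction.
+ if __name__ == "__main__":
+     def dummy_predictor(patch):
+         # stand-in for a segmentation network: a 2-channel "probability" map per patch
+         return torch.softmax(torch.cat([patch, -patch], dim=1), dim=1)
+
+     volume = torch.rand(1, 1, 96, 96, 40)   # (N, C, H, W, D)
+     pred = sw_inference(dummy_predictor, volume, roi_size=(64, 64, 32), autocast_on=False,
+                         discard_second_output=False)
+     print(pred.shape)                        # torch.Size([1, 2, 96, 96, 40])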
utils/tumor_features.py ADDED
@@ -0,0 +1,55 @@
1
+ from scipy.ndimage import label, find_objects
2
+ import numpy as np
3
+ import cv2
4
+
5
+
6
+ IMAGE_SPACING_X = 0.7031
7
+ IMAGE_SPACING_Y = 0.7031
8
+ IMAGE_SPACING_Z = 2.5
9
+
10
+
11
+
12
+ def compute_largest_diameter(binary_mask):
13
+
14
+ # Label connected components in the binary mask
15
+ labeled_array, num_features = label(binary_mask)
16
+
17
+ # Find the objects (tumors) in the labeled array
18
+ tumor_objects = find_objects(labeled_array)
19
+
20
+ # Initialize the largest diameter variable
21
+ largest_diameter = 0
22
+
23
+ # Iterate through each tumor object
24
+ for obj in tumor_objects:
25
+ # Calculate the dimensions of the tumor object
26
+ z_dim = obj[2].stop - obj[2].start
27
+ y_dim = obj[1].stop - obj[1].start
28
+ x_dim = obj[0].stop - obj[0].start
29
+
30
+ # Calculate the diameter using the longest dimension
31
+ diameter = max(z_dim * IMAGE_SPACING_Z, y_dim * IMAGE_SPACING_Y, x_dim * IMAGE_SPACING_X)
32
+
33
+ # Update the largest diameter if necessary
34
+ if diameter > largest_diameter:
35
+ largest_diameter = diameter
36
+
37
+ return largest_diameter / 10 # IN CM
38
+
39
+
40
+
41
+
42
+ def generate_features(img, liver, tumor):
43
+
44
+ # contour extraction needs a 2D uint8 mask; use the axial slice with the largest tumor area
+ # (the contours are not used further yet; lesion shape is reported as a fixed label below)
+ tumor_slice = tumor[:, :, int(np.argmax(tumor.sum(axis=(0, 1))))] if tumor.ndim == 3 else tumor
+ contours, _ = cv2.findContours(tumor_slice.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
45
+
46
+
47
+ features = {
48
+ "lesion size (cm)": compute_largest_diameter(tumor),
49
+ "lesion shape": "irregular",
50
+ "lesion density (HU)": np.mean(img[tumor==1]),
51
+ "involvement of adjacent organs:": "Yes" if np.sum(np.multiply(liver==0, tumor)) > 0 else "No"
52
+ }
53
+
54
+
55
+ return features
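+
+
+ # Minimal usage sketch (synthetic mask; shapes are assumptions). compute_largest_diameter labels the
+ # connected tumor components, converts each bounding box to millimetres with the spacing constants
+ # above, and returns the longest extent in centimetres.
+ if __name__ == "__main__":
+     demo_tumor = np.zeros((64, 64, 20), dtype=np.uint8)
+     demo_tumor[20:40, 25:35, 5:9] = 1   # one synthetic lesion: 20 x 10 voxels in-plane, 4 slices thick
+     print(compute_largest_diameter(demo_tumor))   # max(20*0.7031, 10*0.7031, 4*2.5) / 10 ≈ 1.41 cm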
utils/visualization.py ADDED
@@ -0,0 +1,109 @@
1
+ from matplotlib import pyplot as plt
2
+ import math
3
+ import numpy as np
4
+
5
+
6
+
7
+ def visualize_results(img, mask, pred, n_slices: int=3, slices: list=None, title: str=""):
8
+ """
9
+ img: tensor [C, H, W, Z]
10
+ mask: tensor [C, H, W, Z]
11
+ pred: tensor [C, H, W, Z]
12
+ n_slices: number of slices to visualize
13
+ slices: list of slices to visualize
14
+ title: title of the plot
15
+ """
16
+ if slices is not None:
17
+ n_slices = len(slices)
18
+
19
+ fig, ax = plt.subplots(n_slices, 3, figsize=(14, 5*n_slices))
20
+ inc = img.shape[-1] // n_slices
21
+ mask_masked = np.ma.masked_where(mask == 0, mask)
22
+ pred_masked = np.ma.masked_where(pred == 0, pred)
23
+
24
+ for i in range(n_slices):
25
+ slice_num = i*inc if slices is None else slices[i]
26
+
27
+ # image
28
+ for c in range(3):
29
+ ax[i,c].imshow(img[0,:,:,slice_num], cmap="gray")
30
+ ax[i,c].axis("off")
31
+ ax[i,c].set_title(f'image')
32
+
33
+ # ground truth
34
+ ax[i,1].imshow(mask_masked[1,:,:,slice_num], cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.5)
35
+ ax[i,1].imshow(mask_masked[2,:,:,slice_num], cmap='Reds', vmin=0, vmax=1.3, interpolation='none', alpha=0.8)
36
+ ax[i,1].set_title(f'ground truth')
37
+
38
+ # predicted
39
+ ax[i,2].imshow(pred_masked[1,:,:,slice_num], cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.5)
40
+ ax[i,2].imshow(pred_masked[2,:,:,slice_num], cmap='Reds', vmin=0, vmax=1.3, interpolation='none', alpha=0.8)
41
+ ax[i,2].set_title(f'predicted')
42
+
43
+ plt.suptitle(title, size=14)
44
+ plt.tight_layout()
45
+ plt.show()
46
+
47
+
48
+ def visualize_patient(img, mask=None, n_slices: int=3, slices: list=None, z_dim_last=True, mask_channel=0, title: str=""):
49
+ """
50
+ img: tensor [C, H, W, Z]
51
+ mask: tensor [C, H, W, Z]
52
+ n_slices: number of slices to visualize
53
+ """
54
+ if slices is not None:
55
+ n_slices = len(slices)
56
+
57
+ fig, ax = plt.subplots(math.ceil(n_slices/3), 3, figsize=(14, 5*math.ceil(n_slices/3)))
58
+ if z_dim_last: inc = img.shape[-1] // n_slices
59
+ else: inc = img.shape[0] // n_slices
60
+ masked = np.ma.masked_where(mask == 0, mask)
61
+
62
+ for i in range(n_slices):
63
+ r, c = divmod(i, 3)
64
+ slice_num = i*inc if slices is None else slices[i]
65
+ if n_slices <= 3:
66
+ if z_dim_last: ax[c].imshow(img[0,:,:,slice_num], cmap="gray")
67
+ else: ax[c].imshow(img[slice_num,0,:,:], cmap="gray")
68
+ ax[c].axis("off")
69
+ ax[c].set_title(f'slice {slice_num}')
70
+ if mask is not None:
71
+ if z_dim_last: mask_overlay = ax[c].imshow(masked[mask_channel,:,:,slice_num], cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.4)
72
+ else: mask_overlay = ax[c].imshow(masked[slice_num,mask_channel,:,:], cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.4)
73
+ else:
74
+ if z_dim_last: ax[r][c].imshow(img[0,:,:,slice_num], cmap="gray")
75
+ else: ax[r][c].imshow(img[slice_num,0,:,:], cmap="gray")
76
+ ax[r][c].axis("off")
77
+ ax[r][c].set_title(f'slice {slice_num}')
78
+ if mask is not None:
79
+ if z_dim_last: mask_overlay = ax[r][c].imshow(masked[mask_channel,:,:,slice_num], cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.4)
80
+ else: mask_overlay = ax[r][c].imshow(masked[slice_num,mask_channel,:,:], cmap='jet', vmin=1, vmax=4, interpolation='none', alpha=0.4)
81
+
82
+ plt.suptitle(title, size=14)
83
+ #if mask is not None:
84
+ # cbar = fig.colorbar(mask_overlay, extend='both')
85
+ plt.tight_layout()
86
+ plt.show()
87
+
88
+ fig, ax = plt.subplots(math.ceil(n_slices/3), 3, figsize=(14, 5*math.ceil(n_slices/3)))
89
+ if z_dim_last: inc = img.shape[-1] // n_slices
90
+ else: inc = img.shape[0] // n_slices
91
+
92
+ for i in range(n_slices):
93
+ r, c = divmod(i, 3)
94
+ slice_num = i*inc if slices is None else slices[i]
95
+ if n_slices <= 3:
96
+ if z_dim_last: ax[c].imshow(img[0,:,:,slice_num], cmap="gray")
97
+ else: ax[c].imshow(img[slice_num,0,:,:], cmap="gray")
98
+ ax[c].axis("off")
99
+ ax[c].set_title(f'slice {slice_num}')
100
+ else:
101
+ if z_dim_last: ax[r][c].imshow(img[0,:,:,slice_num], cmap="gray")
102
+ else: ax[r][c].imshow(img[slice_num,0,:,:], cmap="gray")
103
+ ax[r][c].axis("off")
104
+ ax[r][c].set_title(f'slice {slice_num}')
105
+
106
+ plt.suptitle(title, size=14)
107
+
108
+ plt.tight_layout()
109
+ plt.show()
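+
+
+ # Minimal usage sketch (synthetic data; shapes are assumptions). visualize_patient expects a
+ # channel-first volume and, optionally, a label mask with the same spatial shape; with the default
+ # z_dim_last=True the last axis is the slice axis and mask_channel selects which channel to overlay.
+ if __name__ == "__main__":
+     demo_img = np.random.rand(1, 128, 128, 30)                        # (C, H, W, Z)
+     demo_mask = (np.random.rand(3, 128, 128, 30) > 0.95).astype(int)  # e.g. background/liver/tumor channels
+     visualize_patient(demo_img, mask=demo_mask, n_slices=6, mask_channel=1, title="demo overlay")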