maryann-gitonga commited on
Commit
e2e0e81
·
1 Parent(s): 9ad4976

Upload 3D_Brain_Tumor_Segmentation_Attention_UNet.ipynb

Browse files
3D_Brain_Tumor_Segmentation_Attention_UNet.ipynb ADDED
@@ -0,0 +1,1105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "TdEse3Kwq3JD"
7
+ },
8
+ "source": [
9
+ "# Import Necessary Libraries"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "id": "WRKzuv_5owuz"
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "import numpy as np\n",
21
+ "import nibabel as nib\n",
22
+ "import glob\n",
23
+ "from tensorflow.keras.utils import to_categorical # multiclass semantic segmentation, therefore the volumes to categorical\n",
24
+ "import matplotlib.pyplot as plt\n",
25
+ "from tifffile import imsave\n",
26
+ "from sklearn.preprocessing import MinMaxScaler #scale values\n",
27
+ "import tensorflow as tf\n",
28
+ "import random\n",
29
+ "import os.path\n",
30
+ "!pip install split-folders\n",
31
+ "!pip3 install -U segmentation-models-3D\n",
32
+ "import splitfolders\n",
33
+ "!pip install -q -U keras-tuner"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "metadata": {
40
+ "id": "vEtRg2vutWru"
41
+ },
42
+ "outputs": [],
43
+ "source": [
44
+ "# To always ensure that the GPU is available\n",
45
+ "import tensorflow as tf\n",
46
+ "device_name = tf.test.gpu_device_name()\n",
47
+ "if device_name != '/device:GPU:0':\n",
48
+ " raise SystemError('GPU device not found')\n",
49
+ "print('Found GPU at: {}'.format(device_name))"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "metadata": {
55
+ "id": "L5yBxROtvDAI"
56
+ },
57
+ "source": [
58
+ "# Define the MinMax Scaler + Mount Drive to access Dataset\n",
59
+ "\n",
60
+ "* The MinMax scaler is necessary for transforming the scans' features to a range between 0 and 1"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "metadata": {
67
+ "id": "sqMRiba8q-30"
68
+ },
69
+ "outputs": [],
70
+ "source": [
71
+ "scaler = MinMaxScaler()\n",
72
+ "\n",
73
+ "from google.colab import drive\n",
74
+ "drive.mount('/content/drive')"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "markdown",
79
+ "metadata": {
80
+ "id": "XH4_Z5f2sfxZ"
81
+ },
82
+ "source": [
83
+ "# Load sample images and visualize\n",
84
+ "\n"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "metadata": {
91
+ "id": "SvfI9iTrrZuN"
92
+ },
93
+ "outputs": [],
94
+ "source": [
95
+ "DATASET_PATH = ''\n",
96
+ "\n",
97
+ "test_image_flair = nib.load(DATASET_PATH + 'flair.nii').get_fdata()\n",
98
+ "print(test_image_flair[156][98][78])\n",
99
+ "test_image_flair = scaler.fit_transform(test_image_flair.reshape(-1, test_image_flair.shape[-1])).reshape(test_image_flair.shape)\n",
100
+ "print(test_image_flair[156][98][78])\n",
101
+ "\n",
102
+ "test_image_t1 = nib.load(DATASET_PATH + 't1.nii').get_fdata()\n",
103
+ "test_image_t1 = scaler.fit_transform(test_image_t1.reshape(-1, test_image_t1.shape[-1])).reshape(test_image_t1.shape)\n",
104
+ "\n",
105
+ "test_image_t1ce = nib.load(DATASET_PATH + 't1ce.nii').get_fdata()\n",
106
+ "test_image_t1ce = scaler.fit_transform(test_image_t1ce.reshape(-1, test_image_t1ce.shape[-1])).reshape(test_image_t1ce.shape)\n",
107
+ "\n",
108
+ "test_image_t2 = nib.load(DATASET_PATH + 't2.nii').get_fdata()\n",
109
+ "test_image_t2 = scaler.fit_transform(test_image_t2.reshape(-1, test_image_t2.shape[-1])).reshape(test_image_t2.shape)\n",
110
+ "\n",
111
+ "test_mask = nib.load(DATASET_PATH + 'seg.nii').get_fdata()\n",
112
+ "test_mask = test_mask.astype(np.uint8)\n",
113
+ "\n",
114
+ "print(np.unique(test_mask))\n",
115
+ "# Reassign label value 4 to 3\n",
116
+ "test_mask[test_mask==4] = 3\n",
117
+ "print(np.unique(test_mask))"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "metadata": {
124
+ "id": "aTkjA-mgwecE"
125
+ },
126
+ "outputs": [],
127
+ "source": [
128
+ "n_slice = random.randint(0, test_mask.shape[2])\n",
129
+ "\n",
130
+ "plt.figure(figsize=(12,8))\n",
131
+ "plt.subplot(231)\n",
132
+ "plt.imshow(test_image_flair[:, :, n_slice], cmap='gray')\n",
133
+ "plt.title('Flair Scan')\n",
134
+ "\n",
135
+ "plt.subplot(232)\n",
136
+ "plt.imshow(test_image_t1[:, :, n_slice], cmap='gray')\n",
137
+ "plt.title('T1 Scan')\n",
138
+ "\n",
139
+ "plt.subplot(233)\n",
140
+ "plt.imshow(test_image_t1ce[:, :, n_slice], cmap='gray')\n",
141
+ "plt.title('T1ce Scan')\n",
142
+ "\n",
143
+ "plt.subplot(234)\n",
144
+ "plt.imshow(test_image_t2[:, :, n_slice], cmap='gray')\n",
145
+ "plt.title('T2 Scan')\n",
146
+ "\n",
147
+ "plt.subplot(235)\n",
148
+ "plt.imshow(test_mask[:, :, n_slice])\n",
149
+ "plt.title('Mask')\n",
150
+ "\n",
151
+ "plt.show()\n",
152
+ "\n"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "markdown",
157
+ "metadata": {
158
+ "id": "EORoZoj7yPfW"
159
+ },
160
+ "source": [
161
+ "# Data Processing: Combining the volumes of scans to one + Cropping the scans and masks\n",
162
+ "\n",
163
+ "* The numpy array is reshaped to 2D, the dimensions the scaler can take as input, the array is transformed and then reshaped back to 3D\n",
164
+ "* Result: the feature at position [156][98][78] of the loaded FLAIR scan numpy array is transformed from 1920.0 to 0.7683...\n",
165
+ "* The three scans to be used are stacked together to forme a combined scan.\n",
166
+ "* Result: A FLAIR scan, a T1CE scan and a T2 scan, all of dimensions 255 x 255 x 155 are stacked to form a combined scan of dimensions 255 x 255 x 155 x 3\n",
167
+ "* The combined scan is cropped to 128 x 128 x 128 x 3\n",
168
+ "* Label 4 in the dataset is reassigned to label 3 resulting to a continuous list of labels: 0, 1, 2, 3"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "metadata": {
175
+ "id": "-3u91yIqybn-"
176
+ },
177
+ "outputs": [],
178
+ "source": [
179
+ "combined_x = np.stack([test_image_flair, test_image_t1ce, test_image_t2], axis=3)\n",
180
+ "combined_x = combined_x[56:184, 56:184, 13:141] #crop to 128 x 128 x 128 X 3\n",
181
+ "\n",
182
+ "test_mask = test_mask[56:184, 56:184, 13:141]\n",
183
+ "n_slice = random.randint(0, test_mask.shape[1])\n",
184
+ "plt.figure(figsize=(12, 8))\n",
185
+ "\n",
186
+ "plt.subplot(231)\n",
187
+ "plt.imshow(combined_x[:, :, n_slice, 0], cmap='gray')\n",
188
+ "plt.title('Flair Scan')\n",
189
+ "\n",
190
+ "plt.subplot(232)\n",
191
+ "plt.imshow(combined_x[:, :, n_slice, 1], cmap='gray')\n",
192
+ "plt.title('T1ce Scan')\n",
193
+ "\n",
194
+ "plt.subplot(233)\n",
195
+ "plt.imshow(combined_x[:, :, n_slice, 2], cmap='gray')\n",
196
+ "plt.title('T2 Scan')\n",
197
+ "\n",
198
+ "plt.subplot(234)\n",
199
+ "plt.imshow(test_mask[:, :, n_slice])\n",
200
+ "plt.title('Mask')\n",
201
+ "\n",
202
+ "plt.show()"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "metadata": {
209
+ "id": "T8r7sy4QND41"
210
+ },
211
+ "outputs": [],
212
+ "source": [
213
+ "from tensorflow.keras import backend as K\n",
214
+ "\n",
215
+ "print(K.int_shape(test_image_flair))\n",
216
+ "\n",
217
+ "print(K.int_shape(combined_x))"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": null,
223
+ "metadata": {
224
+ "id": "WeD_PqCv6Vww"
225
+ },
226
+ "outputs": [],
227
+ "source": [
228
+ "flair_list = sorted(glob.glob(DATASET_PATH + '*/flair.nii'))\n",
229
+ "t1_list = sorted(glob.glob(DATASET_PATH + '*/t1.nii'))\n",
230
+ "t1ce_list = sorted(glob.glob(DATASET_PATH + '*/t1ce.nii'))\n",
231
+ "t2_list = sorted(glob.glob(DATASET_PATH + '*/t2.nii'))\n",
232
+ "mask_list = sorted(glob.glob(DATASET_PATH + '*/seg.nii'))\n",
233
+ "\n",
234
+ "\n",
235
+ "for img in range(len(flair_list)):\n",
236
+ " print('Now processing image and masks no: ', img)\n",
237
+ "\n",
238
+ " temp_image_flair = nib.load(flair_list[img]).get_fdata()\n",
239
+ " temp_image_flair = scaler.fit_transform(temp_image_flair.reshape(-1, temp_image_flair.shape[-1])).reshape(temp_image_flair.shape)\n",
240
+ "\n",
241
+ " temp_image_t1 = nib.load(t1_list[img]).get_fdata()\n",
242
+ " temp_image_t1 = scaler.fit_transform(temp_image_t1.reshape(-1, temp_image_t1.shape[-1])).reshape(temp_image_t1.shape)\n",
243
+ "\n",
244
+ " temp_image_t1ce = nib.load(t1ce_list[img]).get_fdata()\n",
245
+ " temp_image_t1ce = scaler.fit_transform(temp_image_t1ce.reshape(-1, temp_image_t1ce.shape[-1])).reshape(temp_image_t1ce.shape)\n",
246
+ "\n",
247
+ " temp_image_t2 = nib.load(t2_list[img]).get_fdata()\n",
248
+ " temp_image_t2 = scaler.fit_transform(temp_image_t2.reshape(-1, temp_image_t2.shape[-1])).reshape(temp_image_t2.shape)\n",
249
+ "\n",
250
+ " temp_mask = nib.load(mask_list[img]).get_fdata()\n",
251
+ " temp_mask = temp_mask.astype(np.uint8)\n",
252
+ " temp_mask[temp_mask == 4] = 3\n",
253
+ "\n",
254
+ " temp_combined_images = np.stack([temp_image_flair, temp_image_t1, temp_image_t1ce, temp_image_t2], axis = 3)\n",
255
+ " temp_combined_images = temp_combined_images[56:184, 56:184, 13:141]\n",
256
+ " temp_mask = temp_mask[56:184, 56:184, 13:141]\n",
257
+ "\n",
258
+ " val, counts = np.unique(temp_mask, return_counts=True)\n",
259
+ "\n",
260
+ " if(1 - (counts[0]/counts.sum())) > 0.01:\n",
261
+ " temp_mask = to_categorical(temp_mask, num_classes=4)\n",
262
+ " np.save(DATASET_PATH + 'final_dataset/scans/image_' + str(img) + '.npy', temp_combined_images)\n",
263
+ " np.save(DATASET_PATH + 'final_dataset/masks/image_' + str(img) + '.npy', temp_mask)\n",
264
+ " print(\"Saved\")\n",
265
+ " else:\n",
266
+ " print(\"Not saved\")"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "markdown",
271
+ "metadata": {
272
+ "id": "-wICUx56ugDz"
273
+ },
274
+ "source": [
275
+ "# Dataset Splitting: 60:20:20 for train, val and test"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": null,
281
+ "metadata": {
282
+ "id": "Oi_g5D01HSnq"
283
+ },
284
+ "outputs": [],
285
+ "source": [
286
+ "input_folder = DATASET_PATH + 'final_dataset/'\n",
287
+ "output_folder = DATASET_PATH + 'split_dataset/'\n",
288
+ "splitfolders.ratio(input_folder, output=output_folder, seed=42, ratio=(.6, .2, .2), group_prefix=None)"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "markdown",
293
+ "metadata": {
294
+ "id": "RtaRf0B4kPkM"
295
+ },
296
+ "source": [
297
+ "# Data Generator\n",
298
+ "\n",
299
+ "\n"
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "code",
304
+ "execution_count": null,
305
+ "metadata": {
306
+ "id": "UMfHysy2ixc8"
307
+ },
308
+ "outputs": [],
309
+ "source": [
310
+ "import os\n",
311
+ "import numpy as np\n",
312
+ "\n",
313
+ "def load_img(img_dir, img_list):\n",
314
+ " images=[]\n",
315
+ " for i, image_name in enumerate(img_list):\n",
316
+ " if(image_name.split('.')[1] == 'npy'):\n",
317
+ " image = np.load(img_dir + image_name)\n",
318
+ " images.append(image)\n",
319
+ " images = np.array(images)\n",
320
+ " return images\n",
321
+ "\n",
322
+ "def imageLoader(img_dir, img_list, mask_dir, mask_list, batch_size):\n",
323
+ " L = len(img_list)\n",
324
+ " # keras needs the generator infinite, so use while True\n",
325
+ " while True:\n",
326
+ " batch_start = 0\n",
327
+ " batch_end = batch_size\n",
328
+ "\n",
329
+ " while batch_start < L:\n",
330
+ " limit = min(batch_end, L)\n",
331
+ " X = load_img(img_dir, img_list[batch_start:limit])\n",
332
+ " Y = load_img(mask_dir, mask_list[batch_start:limit])\n",
333
+ "\n",
334
+ " yield(X, Y) # a tuple with two numpy arrays with batch_size samples\n",
335
+ "\n",
336
+ " batch_start += batch_size\n",
337
+ " batch_end += batch_size\n",
338
+ "\n",
339
+ "\n",
340
+ "# Test the generator\n",
341
+ "TRAIN_DATASET_PATH = ''\n",
342
+ "train_img_dir = TRAIN_DATASET_PATH + 'scans/'\n",
343
+ "train_mask_dir = TRAIN_DATASET_PATH + 'masks/'\n",
344
+ "\n",
345
+ "train_img_list = os.listdir(train_img_dir)\n",
346
+ "train_mask_list = os.listdir(train_mask_dir)\n",
347
+ "\n",
348
+ "batch_size = 2\n",
349
+ "\n",
350
+ "train_img_datagen = imageLoader(train_img_dir, train_img_list,\n",
351
+ " train_mask_dir, train_mask_list, batch_size)\n",
352
+ "\n",
353
+ "# Verify generator - In python 3 next() is renamed as __next__()\n",
354
+ "img, msk = train_img_datagen.__next__()\n",
355
+ "\n",
356
+ "img_num = random.randint(0, img.shape[0]-1)\n",
357
+ "\n",
358
+ "test_img = img[img_num]\n",
359
+ "test_mask = msk[img_num]\n",
360
+ "test_mask = np.argmax(test_mask, axis=3)\n",
361
+ "\n",
362
+ "n_slice = random.randint(0, test_mask.shape[2])\n",
363
+ "plt.figure(figsize=(12,8))\n",
364
+ "\n",
365
+ "plt.subplot(221)\n",
366
+ "plt.imshow(test_img[:, :, n_slice, 0], cmap='gray')\n",
367
+ "plt.title('Flair Scan')\n",
368
+ "\n",
369
+ "plt.subplot(222)\n",
370
+ "plt.imshow(test_img[:, :, n_slice, 1], cmap='gray')\n",
371
+ "plt.title('T1ce Scan')\n",
372
+ "\n",
373
+ "plt.subplot(223)\n",
374
+ "plt.imshow(test_img[:, :, n_slice, 2], cmap='gray')\n",
375
+ "plt.title('T2 Scan')\n",
376
+ "\n",
377
+ "plt.subplot(224)\n",
378
+ "plt.imshow(test_mask[:, :, n_slice])\n",
379
+ "plt.title('Mask')\n",
380
+ "\n",
381
+ "plt.show()"
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "markdown",
386
+ "metadata": {
387
+ "id": "ReTmFPr0QV17"
388
+ },
389
+ "source": [
390
+ "# Define image generators for training, validation and testing"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "code",
395
+ "execution_count": null,
396
+ "metadata": {
397
+ "id": "HS9Dihs_QbqU"
398
+ },
399
+ "outputs": [],
400
+ "source": [
401
+ "DATASET_PATH = ''\n",
402
+ "train_img_dir = DATASET_PATH + 'train/scans/'\n",
403
+ "train_mask_dir = DATASET_PATH + 'train/masks/'\n",
404
+ "\n",
405
+ "val_img_dir = DATASET_PATH + 'val/scans/'\n",
406
+ "val_mask_dir = DATASET_PATH + 'val/masks/'\n",
407
+ "\n",
408
+ "test_img_dir = DATASET_PATH + 'test/scans/'\n",
409
+ "test_mask_dir = DATASET_PATH + 'test/masks/'\n",
410
+ "\n",
411
+ "train_img_list = os.listdir(train_img_dir)\n",
412
+ "train_mask_list = os.listdir(train_mask_dir)\n",
413
+ "\n",
414
+ "val_img_list = os.listdir(val_img_dir)\n",
415
+ "val_mask_list = os.listdir(val_mask_dir)\n",
416
+ "\n",
417
+ "test_img_list = os.listdir(test_img_dir)\n",
418
+ "test_mask_list = os.listdir(test_mask_dir)\n",
419
+ "\n",
420
+ "batch_size = 2\n",
421
+ "train_img_datagen = imageLoader(train_img_dir, train_img_list,\n",
422
+ " train_mask_dir, train_mask_list, batch_size)\n",
423
+ "\n",
424
+ "val_img_datagen = imageLoader(val_img_dir, val_img_list,\n",
425
+ " val_mask_dir, val_mask_list, batch_size)\n",
426
+ "\n",
427
+ "test_img_datagen = imageLoader(test_img_dir, test_img_list,\n",
428
+ " test_mask_dir, test_mask_list, batch_size)\n"
429
+ ]
430
+ },
431
+ {
432
+ "cell_type": "markdown",
433
+ "metadata": {
434
+ "id": "dBKMHMn96Z3c"
435
+ },
436
+ "source": [
437
+ "# Losses and metrics\n",
438
+ "* These losses and metrics best handle the problem of class imbalance\n",
439
+ "* Used: dice_coef as a metric, tversky_loss as a loss"
440
+ ]
441
+ },
442
+ {
443
+ "cell_type": "code",
444
+ "execution_count": null,
445
+ "metadata": {
446
+ "id": "pshixCsr6eyt"
447
+ },
448
+ "outputs": [],
449
+ "source": [
450
+ "import tensorflow.keras.backend as K\n",
451
+ "\n",
452
+ "\n",
453
+ "def dice_coef(y_true, y_pred, smooth=1):\n",
454
+ " y_true_f = K.flatten(y_true)\n",
455
+ " y_pred_f = K.flatten(y_pred)\n",
456
+ " intersection = K.sum(y_true_f * y_pred_f)\n",
457
+ " return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) +\n",
458
+ " smooth)\n",
459
+ "\n",
460
+ "\n",
461
+ "def dice_coef_loss(y_true, y_pred):\n",
462
+ " return 1 - dice_coef(y_true, y_pred)\n",
463
+ "\n",
464
+ "\n",
465
+ "def tversky(y_true, y_pred, smooth=1, alpha=0.7):\n",
466
+ " y_true_pos = K.flatten(y_true)\n",
467
+ " y_pred_pos = K.flatten(y_pred)\n",
468
+ " true_pos = K.sum(y_true_pos * y_pred_pos)\n",
469
+ " false_neg = K.sum(y_true_pos * (1 - y_pred_pos))\n",
470
+ " false_pos = K.sum((1 - y_true_pos) * y_pred_pos)\n",
471
+ " return (true_pos + smooth) / (true_pos + alpha * false_neg +\n",
472
+ " (1 - alpha) * false_pos + smooth)\n",
473
+ "\n",
474
+ "\n",
475
+ "def tversky_loss(y_true, y_pred):\n",
476
+ " return 1 - tversky(y_true, y_pred)\n",
477
+ "\n",
478
+ "\n",
479
+ "def focal_tversky_loss(y_true, y_pred, gamma=0.75):\n",
480
+ " tv = tversky(y_true, y_pred)\n",
481
+ " return K.pow((1 - tv), gamma)"
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "markdown",
486
+ "metadata": {
487
+ "id": "2o2WuIhaW5ff"
488
+ },
489
+ "source": [
490
+ "# Define loss, metrics and optimizer to be used for training"
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "execution_count": null,
496
+ "metadata": {
497
+ "id": "WxiJ1eUQXJ4I"
498
+ },
499
+ "outputs": [],
500
+ "source": [
501
+ "from keras.models import Model\n",
502
+ "from keras.layers import Input, Conv3D, MaxPooling3D, Activation, add, concatenate, Conv3DTranspose, BatchNormalization, Dropout, UpSampling3D, multiply\n",
503
+ "from tensorflow.keras.optimizers import Adam\n",
504
+ "from keras import layers\n",
505
+ "\n",
506
+ "kernel_initializer = 'he_uniform'\n",
507
+ "\n",
508
+ "import segmentation_models_3D as sm\n",
509
+ "\n",
510
+ "metrics = [dice_coef]\n",
511
+ "\n",
512
+ "LR = 0.0001\n",
513
+ "optim = Adam(LR)\n",
514
+ "\n",
515
+ "steps_per_epoch = len(train_img_list) // batch_size\n",
516
+ "val_steps_per_epoch = len(val_img_list) // batch_size"
517
+ ]
518
+ },
519
+ {
520
+ "cell_type": "markdown",
521
+ "metadata": {
522
+ "id": "PR2Ugre0YP-v"
523
+ },
524
+ "source": [
525
+ "# 3D UNet Model"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "code",
530
+ "execution_count": null,
531
+ "metadata": {
532
+ "id": "N0VyhdjCYVuZ"
533
+ },
534
+ "outputs": [],
535
+ "source": [
536
+ "def UNet(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS, num_classes):\n",
537
+ " inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS))\n",
538
+ "\n",
539
+ " # Downsampling\n",
540
+ " c1 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(inputs)\n",
541
+ " c1 = Dropout(0.1)(c1)\n",
542
+ " c1 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c1)\n",
543
+ " p1 = MaxPooling3D((2, 2, 2))(c1)\n",
544
+ "\n",
545
+ " c2 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(p1)\n",
546
+ " c2 = Dropout(0.1)(c2)\n",
547
+ " c2 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c2)\n",
548
+ " p2 = MaxPooling3D((2, 2, 2))(c2)\n",
549
+ "\n",
550
+ " c3 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(p2)\n",
551
+ " c3 = Dropout(0.2)(c3)\n",
552
+ " c3 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c3)\n",
553
+ " p3 = MaxPooling3D((2, 2, 2))(c3)\n",
554
+ "\n",
555
+ " c4 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(p3)\n",
556
+ " c4 = Dropout(0.2)(c4)\n",
557
+ " c4 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c4)\n",
558
+ " p4 = MaxPooling3D((2, 2, 2))(c4)\n",
559
+ "\n",
560
+ " c5 = Conv3D(filters = 256, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(p4)\n",
561
+ " c5 = Dropout(0.3)(c5)\n",
562
+ " c5 = Conv3D(filters = 256, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c5)\n",
563
+ " \n",
564
+ " # Upsampling part\n",
565
+ " u6 = Conv3DTranspose(128, (2, 2, 2), strides=(2, 2, 2), padding='same')(c5)\n",
566
+ " u6 = concatenate([u6, c4])\n",
567
+ " c6 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(u6)\n",
568
+ " c6 = Dropout(0.2)(c6)\n",
569
+ " c6 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c6) \n",
570
+ " \n",
571
+ " u7 = Conv3DTranspose(64, (2, 2, 2), strides=(2, 2, 2), padding='same')(c6)\n",
572
+ " u7 = concatenate([u7, c3])\n",
573
+ " c7 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(u7)\n",
574
+ " c7 = Dropout(0.2)(c7)\n",
575
+ " c7 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c7) \n",
576
+ " \n",
577
+ " u8 = Conv3DTranspose(32, (2, 2, 2), strides=(2, 2, 2), padding='same')(c7)\n",
578
+ " u8 = concatenate([u8, c2])\n",
579
+ " c8 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(u8)\n",
580
+ " c8 = Dropout(0.1)(c8)\n",
581
+ " c8 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c8) \n",
582
+ "\n",
583
+ " u9 = Conv3DTranspose(16, (2, 2, 2), strides=(2, 2, 2), padding='same')(c8)\n",
584
+ " u9 = concatenate([u9, c1])\n",
585
+ " c9 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(u9)\n",
586
+ " c9 = Dropout(0.1)(c9)\n",
587
+ " c9 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c9) \n",
588
+ "\n",
589
+ " outputs = Conv3D(num_classes, (1, 1, 1), activation='softmax')(c9)\n",
590
+ "\n",
591
+ " model = Model(inputs=[inputs], outputs=[outputs])\n",
592
+ " model.summary()\n",
593
+ "\n",
594
+ " return model"
595
+ ]
596
+ },
597
+ {
598
+ "cell_type": "markdown",
599
+ "metadata": {
600
+ "id": "-Aw_Peb9iJYb"
601
+ },
602
+ "source": [
603
+ "# Test the working of the 3D UNet model"
604
+ ]
605
+ },
606
+ {
607
+ "cell_type": "code",
608
+ "execution_count": null,
609
+ "metadata": {
610
+ "id": "fjdzCTisiMLI"
611
+ },
612
+ "outputs": [],
613
+ "source": [
614
+ "steps_per_epoch = len(train_img_list)//batch_size\n",
615
+ "val_steps_per_epoch = len(val_img_list)//batch_size\n",
616
+ "\n",
617
+ "model = UNet(IMG_HEIGHT = 128,\n",
618
+ " IMG_WIDTH = 128,\n",
619
+ " IMG_DEPTH = 128,\n",
620
+ " IMG_CHANNELS = 3,\n",
621
+ " num_classes = 4)\n",
622
+ "\n",
623
+ "model.compile(optimizer = optim, loss = tversky_loss, metrics = metrics)\n",
624
+ "\n",
625
+ "print(model.summary)\n",
626
+ "\n",
627
+ "print(model.input_shape)\n",
628
+ "print(model.output_shape)"
629
+ ]
630
+ },
631
+ {
632
+ "cell_type": "markdown",
633
+ "metadata": {
634
+ "id": "e6Cvn6hWvars"
635
+ },
636
+ "source": [
637
+ "# 3D Attention UNet Model"
638
+ ]
639
+ },
640
+ {
641
+ "cell_type": "code",
642
+ "execution_count": null,
643
+ "metadata": {
644
+ "id": "JBcFdz80v2mL"
645
+ },
646
+ "outputs": [],
647
+ "source": [
648
+ "from keras.layers.core.activation import Activation\n",
649
+ "from tensorflow.keras import backend as K\n",
650
+ "from keras.layers import LeakyReLU\n",
651
+ "\n",
652
+ "def repeat_elem(tensor, rep):\n",
653
+ " # lambda function to repeat Repeats the elements of a tensor along an axis\n",
654
+ " #by a factor of rep.\n",
655
+ " # If tensor has shape (None, 128,128,3), lambda will return a tensor of shape \n",
656
+ " #(None, 128,128,6), if specified axis=3 and rep=2.\n",
657
+ "\n",
658
+ " return layers.Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=4),\n",
659
+ " arguments={'repnum': rep})(tensor)\n",
660
+ "\n",
661
+ "def attention_block(x, gating, inter_shape):\n",
662
+ " shape_x = K.int_shape(x)\n",
663
+ " shape_g = K.int_shape(gating)\n",
664
+ "\n",
665
+ " # Getting the gating signal to the same number of filters as the inter_shape\n",
666
+ " phi_g = Conv3D(filters=inter_shape, kernel_size=1, strides=1, padding='same')(gating)\n",
667
+ "\n",
668
+ " # Geting the x signal to the same shape as the gating signal\n",
669
+ " theta_x = Conv3D(filters=inter_shape, kernel_size=3, strides=(\n",
670
+ " shape_x[1] // shape_g[1],\n",
671
+ " shape_x[2] // shape_g[2],\n",
672
+ " shape_x[3] // shape_g[3]\n",
673
+ " ), padding='same')(x)\n",
674
+ " shape_theta_x = K.int_shape(theta_x)\n",
675
+ "\n",
676
+ " print(shape_theta_x, shape_g)\n",
677
+ "\n",
678
+ " # Elemet-wise addition of the gating and x signals\n",
679
+ " xg_sum = add([phi_g, theta_x])\n",
680
+ " xg_sum = Activation('relu')(xg_sum)\n",
681
+ "\n",
682
+ " # 1x1x1 convolution\n",
683
+ " psi = Conv3D(filters=1, kernel_size=1, padding='same')(xg_sum)\n",
684
+ " sigmoid_psi = Activation('sigmoid')(psi)\n",
685
+ " shape_sigmoid = K.int_shape(sigmoid_psi)\n",
686
+ "\n",
687
+ " # Upsampling psi back to the original dimensions of x signal to enable \n",
688
+ " # element-wise multiplication with the signal\n",
689
+ "\n",
690
+ " upsampled_sigmoid_psi = UpSampling3D(size=(\n",
691
+ " shape_x[1] // shape_sigmoid[1], \n",
692
+ " shape_x[2] // shape_sigmoid[2],\n",
693
+ " shape_x[3] // shape_sigmoid[3]\n",
694
+ " ))(sigmoid_psi)\n",
695
+ "\n",
696
+ " # Expand the filter axis to the number of filters in the original x signal\n",
697
+ " upsampled_sigmoid_psi = repeat_elem(upsampled_sigmoid_psi, shape_x[4])\n",
698
+ "\n",
699
+ " # Element-wise multiplication of attention coefficients back onto original x signal\n",
700
+ " attention_coeffs = multiply([upsampled_sigmoid_psi, x])\n",
701
+ "\n",
702
+ " # Final 1x1x1 convolution to consolidate attention signal to original x dimensions\n",
703
+ " output = Conv3D(filters=shape_x[3], kernel_size=1, strides=1, padding='same')(attention_coeffs)\n",
704
+ " output = BatchNormalization()(output)\n",
705
+ " return output\n",
706
+ "\n",
707
+ "\n",
708
+ "# Gating signal\n",
709
+ "def gating_signal(input, output_size, batch_norm=False):\n",
710
+ " # Resize the down layer feature map into the same dimensions as the up layer feature map using 1x1 conv\n",
711
+ " # Return: the gating feature map with the same dimension of the up layer feature map\n",
712
+ " x = Conv3D(output_size, (1, 1, 1), padding='same')(input)\n",
713
+ " if batch_norm:\n",
714
+ " x = BatchNormalization()(x)\n",
715
+ " x = Activation('relu')(x)\n",
716
+ " return x\n",
717
+ "\n",
718
+ "\n",
719
+ "# Attention UNet\n",
720
+ "def attention_unet(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS, num_classes, batch_norm = True):\n",
721
+ " inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS))\n",
722
+ " FILTER_NUM = 64 #\n",
723
+ " FILTER_SIZE = 3 #\n",
724
+ " UP_SAMPLING_SIZE = 2 # \n",
725
+ "\n",
726
+ " c1 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(inputs)\n",
727
+ " c1 = Dropout(0.1)(c1)\n",
728
+ " c1 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c1)\n",
729
+ " p1 = MaxPooling3D((2, 2, 2))(c1)\n",
730
+ "\n",
731
+ " c2 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(p1)\n",
732
+ " c2 = Dropout(0.1)(c2)\n",
733
+ " c2 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c2)\n",
734
+ " p2 = MaxPooling3D((2, 2, 2))(c2)\n",
735
+ "\n",
736
+ " c3 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(p2)\n",
737
+ " c3 = Dropout(0.2)(c3)\n",
738
+ " c3 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c3)\n",
739
+ " p3 = MaxPooling3D((2, 2, 2))(c3)\n",
740
+ "\n",
741
+ " c4 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(p3)\n",
742
+ " c4 = Dropout(0.2)(c4)\n",
743
+ " c4 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c4)\n",
744
+ " p4 = MaxPooling3D((2, 2, 2))(c4)\n",
745
+ "\n",
746
+ " c5 = Conv3D(filters = 256, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(p4)\n",
747
+ " c5 = Dropout(0.3)(c5)\n",
748
+ " c5 = Conv3D(filters = 256, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c5)\n",
749
+ " \n",
750
+ "\n",
751
+ " gating_6 = gating_signal(c5, 128, batch_norm)\n",
752
+ " att_6 = attention_block(c4, gating_6, 128)\n",
753
+ " u6 = UpSampling3D((2, 2, 2), data_format='channels_last')(c5)\n",
754
+ " u6 = concatenate([u6, att_6])\n",
755
+ " c6 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(u6)\n",
756
+ " c6 = Dropout(0.2)(c6)\n",
757
+ " c6 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c6) \n",
758
+ " \n",
759
+ " gating_7 = gating_signal(c6, 64, batch_norm)\n",
760
+ " att_7 = attention_block(c3, gating_6, 64)\n",
761
+ " u7 = UpSampling3D((2, 2, 2), data_format='channels_last')(c6)\n",
762
+ " u7 = concatenate([u7, att_7])\n",
763
+ " c7 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(u7)\n",
764
+ " c7 = Dropout(0.2)(c7)\n",
765
+ " c7 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c7) \n",
766
+ " \n",
767
+ " gating_8 = gating_signal(c7, 64, batch_norm)\n",
768
+ " att_8 = attention_block(c2, gating_6, 64)\n",
769
+ " u8 = UpSampling3D((2, 2, 2), data_format='channels_last')(c7)\n",
770
+ " u8 = concatenate([u8, att_8])\n",
771
+ " c8 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(u8)\n",
772
+ " c8 = Dropout(0.1)(c8)\n",
773
+ " c8 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c8) \n",
774
+ "\n",
775
+ " gating_9 = gating_signal(c8, 64, batch_norm)\n",
776
+ " att_9 = attention_block(c1, gating_6, 64)\n",
777
+ " u9 = UpSampling3D((2, 2, 2), data_format='channels_last')(c8)\n",
778
+ " u9 = concatenate([u9, att_9])\n",
779
+ " c9 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(u9)\n",
780
+ " c9 = Dropout(0.1)(c9)\n",
781
+ " c9 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c9) \n",
782
+ "\n",
783
+ " outputs = Conv3D(num_classes, (1, 1, 1))(c9)\n",
784
+ " outputs = BatchNormalization()(outputs)\n",
785
+ " outputs = Activation('softmax')(outputs)\n",
786
+ "\n",
787
+ " model = Model(inputs=[inputs], outputs=[outputs], name=\"Attention_UNet\")\n",
788
+ " model.summary()\n",
789
+ "\n",
790
+ " return model"
791
+ ]
792
+ },
793
+ {
794
+ "cell_type": "markdown",
795
+ "metadata": {
796
+ "id": "xndmsEwjVhn7"
797
+ },
798
+ "source": [
799
+ "# Test the working of a 3D Attention UNet Model"
800
+ ]
801
+ },
802
+ {
803
+ "cell_type": "code",
804
+ "execution_count": null,
805
+ "metadata": {
806
+ "id": "pBNjxGbjVn9U"
807
+ },
808
+ "outputs": [],
809
+ "source": [
810
+ "steps_per_epoch = len(train_img_list)//batch_size\n",
811
+ "val_steps_per_epoch = len(val_img_list)//batch_size\n",
812
+ "\n",
813
+ "model = attention_unet(IMG_HEIGHT = 128,\n",
814
+ " IMG_WIDTH = 128,\n",
815
+ " IMG_DEPTH = 128,\n",
816
+ " IMG_CHANNELS = 3,\n",
817
+ " num_classes = 4)\n",
818
+ "\n",
819
+ "model.compile(optimizer = optim, loss = tversky_loss, metrics = metrics)\n",
820
+ "\n",
821
+ "print(model.summary)\n",
822
+ "\n",
823
+ "print(model.input_shape)\n",
824
+ "print(model.output_shape)"
825
+ ]
826
+ },
827
+ {
828
+ "cell_type": "markdown",
829
+ "metadata": {
830
+ "id": "8qnlrlr1YXu4"
831
+ },
832
+ "source": [
833
+ "# Fit the Model"
834
+ ]
835
+ },
836
+ {
837
+ "cell_type": "code",
838
+ "execution_count": null,
839
+ "metadata": {
840
+ "id": "UXmCjFvjYaSG"
841
+ },
842
+ "outputs": [],
843
+ "source": [
844
+ "import tensorflow.keras as keras\n",
845
+ "from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger, TerminateOnNaN\n",
846
+ "\n",
847
+ "checkpoint_path = ''\n",
848
+ "log_path = ''\n",
849
+ "\n",
850
+ "callbacks = [\n",
851
+ " EarlyStopping(monitor='val_loss', patience=4, verbose=1),\n",
852
+ " ReduceLROnPlateau(factor=0.1,\n",
853
+ " monitor='val_loss',\n",
854
+ " patience=4,\n",
855
+ " min_lr=0.0001,\n",
856
+ " verbose=1,\n",
857
+ " mode='min'),\n",
858
+ " ModelCheckpoint(checkpoint_path,\n",
859
+ " monitor='val_loss',\n",
860
+ " mode='min',\n",
861
+ " verbose=0,\n",
862
+ " save_best_only=True),\n",
863
+ " CSVLogger(log_path, separator=',', append=True),\n",
864
+ " TerminateOnNaN()\n",
865
+ "]\n",
866
+ "\n",
867
+ "history = model.fit(train_img_datagen,\n",
868
+ " steps_per_epoch=steps_per_epoch,\n",
869
+ " epochs=100,\n",
870
+ " verbose=1,\n",
871
+ " validation_data=val_img_datagen,\n",
872
+ " validation_steps=val_steps_per_epoch,\n",
873
+ " callbacks=callbacks\n",
874
+ " )\n",
875
+ "\n",
876
+ "history_callback = np.save('', history.history)"
877
+ ]
878
+ },
879
+ {
880
+ "cell_type": "markdown",
881
+ "metadata": {
882
+ "id": "pfcKmJv4jP2J"
883
+ },
884
+ "source": [
885
+ "# Load Model for more training"
886
+ ]
887
+ },
888
+ {
889
+ "cell_type": "code",
890
+ "execution_count": null,
891
+ "metadata": {
892
+ "id": "7RXukeY_jiad"
893
+ },
894
+ "outputs": [],
895
+ "source": [
896
+ "import tensorflow.keras.models as load\n",
897
+ "import keras\n",
898
+ "model = load.load_model('', custom_objects={\n",
899
+ " 'tversky_loss': tversky_loss,\n",
900
+ " 'dice_coef': dice_coef\n",
901
+ "})\n",
902
+ "\n",
903
+ "checkpoint_path = ''\n",
904
+ "log_path = ''\n",
905
+ "\n",
906
+ "callbacks = [\n",
907
+ " EarlyStopping(monitor='val_loss', patience=4, verbose=1),\n",
908
+ " ReduceLROnPlateau(factor=0.1,\n",
909
+ " monitor='val_loss',\n",
910
+ " patience=4,\n",
911
+ " min_lr=0.0001,\n",
912
+ " verbose=1,\n",
913
+ " mode='min'),\n",
914
+ " ModelCheckpoint(checkpoint_path,\n",
915
+ " monitor='val_loss',\n",
916
+ " mode='min',\n",
917
+ " verbose=0,\n",
918
+ " save_best_only=True),\n",
919
+ " CSVLogger(log_path, separator=',', append=True),\n",
920
+ " TerminateOnNaN()\n",
921
+ "]\n",
922
+ "\n",
923
+ "history = model.fit(train_img_datagen,\n",
924
+ " steps_per_epoch=steps_per_epoch,\n",
925
+ " epochs=100,\n",
926
+ " verbose=1,\n",
927
+ " validation_data=val_img_datagen,\n",
928
+ " validation_steps=val_steps_per_epoch,\n",
929
+ " callbacks=callbacks\n",
930
+ " )\n",
931
+ "\n",
932
+ "history_callback = np.save('', history.history)"
933
+ ]
934
+ },
935
+ {
936
+ "cell_type": "markdown",
937
+ "metadata": {
938
+ "id": "SPBUC1HIfqDt"
939
+ },
940
+ "source": [
941
+ "# Plot the training and validation loss (tversky) and dice coefficient (metric) at each epoch"
942
+ ]
943
+ },
944
+ {
945
+ "cell_type": "code",
946
+ "execution_count": null,
947
+ "metadata": {
948
+ "id": "I7e4YkM5f1Jg"
949
+ },
950
+ "outputs": [],
951
+ "source": [
952
+ "history = np.load('',allow_pickle='TRUE').item()\n",
953
+ "\n",
954
+ "print(history)\n",
955
+ "loss = history['loss']\n",
956
+ "val_loss = history['val_loss']\n",
957
+ "epochs = range(1, len(loss) + 1)\n",
958
+ "plt.plot(epochs, loss, 'y', label='Training loss')\n",
959
+ "plt.plot(epochs, val_loss, 'r', label='Validation loss')\n",
960
+ "plt.title('Training and Validation Loss')\n",
961
+ "plt.xlabel('Epochs')\n",
962
+ "plt.ylabel('Loss')\n",
963
+ "plt.legend()\n",
964
+ "plt.show()\n",
965
+ "\n",
966
+ "acc = history['dice_coef']\n",
967
+ "val_acc = history['val_dice_coef']\n",
968
+ "\n",
969
+ "plt.plot(epochs, acc, 'y', label='Training accuracy')\n",
970
+ "plt.plot(epochs, val_acc, 'r', label='Validation accuracy')\n",
971
+ "plt.title('Trainign and Validation Accuracy')\n",
972
+ "plt.xlabel('Epochs')\n",
973
+ "plt.ylabel('Accuracy')\n",
974
+ "plt.legend()\n",
975
+ "plt.show()"
976
+ ]
977
+ },
978
+ {
979
+ "cell_type": "markdown",
980
+ "metadata": {
981
+ "id": "XV8kjMkemQ-W"
982
+ },
983
+ "source": [
984
+ "# Model Evaluation"
985
+ ]
986
+ },
987
+ {
988
+ "cell_type": "code",
989
+ "execution_count": null,
990
+ "metadata": {
991
+ "id": "ChhYHB8PmTnK"
992
+ },
993
+ "outputs": [],
994
+ "source": [
995
+ "from tensorflow.keras.models import load_model\n",
996
+ "my_model = load_model('', custom_objects={\n",
997
+ " 'tversky_loss': tversky_loss,\n",
998
+ " 'dice_coef': dice_coef},\n",
999
+ " compile = True)\n",
1000
+ "\n",
1001
+ "# Verify IoU on a batch of images from the test dataset\n",
1002
+ "batch_size = 8\n",
1003
+ "test_img_datagen = imageLoader(val_img_dir, val_img_list,\n",
1004
+ " val_mask_dir, val_mask_list, batch_size)\n",
1005
+ "\n",
1006
+ "test_image_batch, test_mask_batch = test_img_datagen.__next__()\n",
1007
+ "\n",
1008
+ "test_mask_batch_argmax = np.argmax(test_mask_batch, axis=4)\n",
1009
+ "\n",
1010
+ "results = my_model.evaluate(test_image_batch, test_mask_batch, batch_size=batch_size)\n",
1011
+ "print(\"test acc, test loss:\", results)"
1012
+ ]
1013
+ },
1014
+ {
1015
+ "cell_type": "markdown",
1016
+ "metadata": {
1017
+ "id": "xvEqiU6SqY2y"
1018
+ },
1019
+ "source": [
1020
+ "# Predict on a test scan"
1021
+ ]
1022
+ },
1023
+ {
1024
+ "cell_type": "code",
1025
+ "execution_count": null,
1026
+ "metadata": {
1027
+ "id": "8-MUQpCiqcxd"
1028
+ },
1029
+ "outputs": [],
1030
+ "source": [
1031
+ "from tensorflow.keras.models import load_model\n",
1032
+ "my_model = load_model('', compile=False)\n",
1033
+ "\n",
1034
+ "img_num = 53\n",
1035
+ "test_scan = np.load('' + str(img_num) + '.npy')\n",
1036
+ "\n",
1037
+ "test_mask = np.load('' + str(img_num) + '.npy')\n",
1038
+ "test_mask_argmax = np.argmax(test_mask, axis = 3)\n",
1039
+ "\n",
1040
+ "test_scan_input = np.expand_dims(test_scan, axis = 0)\n",
1041
+ "test_prediction = my_model.predict(test_scan_input)\n",
1042
+ "test_prediction_argmax = np.argmax(test_prediction, axis = 4)[0, :, :, :]"
1043
+ ]
1044
+ },
1045
+ {
1046
+ "cell_type": "code",
1047
+ "execution_count": null,
1048
+ "metadata": {
1049
+ "colab": {
1050
+ "background_save": true
1051
+ },
1052
+ "id": "65FmAMNhmX8E"
1053
+ },
1054
+ "outputs": [],
1055
+ "source": [
1056
+ "# n_slice = 55\n",
1057
+ "n_slice = random.randint(0, test_mask_argmax.shape[2])\n",
1058
+ "\n",
1059
+ "plt.figure(figsize=(12,8))\n",
1060
+ "plt.subplot(231)\n",
1061
+ "plt.imshow(test_scan[:, :, n_slice, 1], cmap='gray')\n",
1062
+ "plt.title('Testing Scan')\n",
1063
+ "\n",
1064
+ "plt.subplot(232)\n",
1065
+ "plt.imshow(test_mask_argmax[:, :, n_slice])\n",
1066
+ "plt.title('Testing Label')\n",
1067
+ "\n",
1068
+ "plt.subplot(235)\n",
1069
+ "plt.imshow(test_prediction_argmax[:, :, n_slice])\n",
1070
+ "plt.title('Prediction on test image')\n",
1071
+ "\n",
1072
+ "plt.show()"
1073
+ ]
1074
+ }
1075
+ ],
1076
+ "metadata": {
1077
+ "accelerator": "GPU",
1078
+ "colab": {
1079
+ "collapsed_sections": [
1080
+ "TdEse3Kwq3JD",
1081
+ "L5yBxROtvDAI",
1082
+ "EORoZoj7yPfW",
1083
+ "-wICUx56ugDz",
1084
+ "nq3p80zN2ew2",
1085
+ "dBKMHMn96Z3c",
1086
+ "PR2Ugre0YP-v",
1087
+ "-Aw_Peb9iJYb",
1088
+ "e6Cvn6hWvars",
1089
+ "xndmsEwjVhn7",
1090
+ "pfcKmJv4jP2J"
1091
+ ],
1092
+ "provenance": []
1093
+ },
1094
+ "gpuClass": "standard",
1095
+ "kernelspec": {
1096
+ "display_name": "Python 3",
1097
+ "name": "python3"
1098
+ },
1099
+ "language_info": {
1100
+ "name": "python"
1101
+ }
1102
+ },
1103
+ "nbformat": 4,
1104
+ "nbformat_minor": 0
1105
+ }